diff --git a/Cargo.lock b/Cargo.lock index 89351432c1..588a63b6a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4236,6 +4236,8 @@ name = "pagebench" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "bytes", "camino", "clap", "futures", @@ -4244,12 +4246,15 @@ dependencies = [ "humantime-serde", "pageserver_api", "pageserver_client", + "pageserver_page_api", "rand 0.8.5", "reqwest", "serde", "serde_json", "tokio", + "tokio-stream", "tokio-util", + "tonic 0.13.1", "tracing", "utils", "workspace_hack", @@ -4305,6 +4310,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", + "http 1.1.0", "http-utils", "humantime", "humantime-serde", @@ -4367,6 +4373,7 @@ dependencies = [ "toml_edit", "tonic 0.13.1", "tonic-reflection", + "tower 0.5.2", "tracing", "tracing-utils", "twox-hash", @@ -4463,7 +4470,6 @@ dependencies = [ "pageserver_api", "postgres_ffi", "prost 0.13.5", - "smallvec", "thiserror 1.0.69", "tonic 0.13.1", "tonic-build", diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index 9d4c93e1cd..f97f04968e 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux . 
"$HOME/.cargo/env" && \ cargo --version && rustup --version && \ rustup component add llvm-tools rustfmt clippy && \ - cargo install rustfilt --version ${RUSTFILT_VERSION} && \ - cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \ - cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \ - cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \ - cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \ - cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \ - cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \ + cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \ + cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \ + cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \ + cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \ + cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \ + cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \ + cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \ --features postgres-bundled --no-default-features && \ rm -rf /home/nonroot/.cargo/registry && \ rm -rf /home/nonroot/.cargo/git diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 3459983a34..248f52088b 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -297,6 +297,7 @@ RUN ./autogen.sh && \ ./configure --with-sfcgal=/usr/local/bin/sfcgal-config && \ make -j $(getconf _NPROCESSORS_ONLN) && \ make -j $(getconf _NPROCESSORS_ONLN) install && \ + make staged-install && \ cd extensions/postgis && \ make clean && \ make -j $(getconf _NPROCESSORS_ONLN) install && \ @@ -602,7 +603,7 @@ RUN case "${PG_VERSION:?}" in \ ;; \ esac && \ wget https://github.com/knizhnik/online_advisor/archive/refs/tags/1.0.tar.gz -O online_advisor.tar.gz && \ - echo "059b7d9e5a90013a58bdd22e9505b88406ce05790675eb2d8434e5b215652d54 
online_advisor.tar.gz" | sha256sum --check && \ + echo "37dcadf8f7cc8d6cc1f8831276ee245b44f1b0274f09e511e47a67738ba9ed0f online_advisor.tar.gz" | sha256sum --check && \ mkdir online_advisor-src && cd online_advisor-src && tar xzf ../online_advisor.tar.gz --strip-components=1 -C . FROM pg-build AS online_advisor-build @@ -1180,14 +1181,14 @@ RUN cd exts/rag && \ RUN cd exts/rag_bge_small_en_v15 && \ sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \ - REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/bge_small_en_v15.onnx \ + REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/bge_small_en_v15.onnx \ cargo pgrx install --release --features remote_onnx && \ echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_bge_small_en_v15.control RUN cd exts/rag_jina_reranker_v1_tiny_en && \ sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \ - REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/jina_reranker_v1_tiny_en.onnx \ + REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/jina_reranker_v1_tiny_en.onnx \ cargo pgrx install --release --features remote_onnx && \ echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_jina_reranker_v1_tiny_en.control @@ -1842,10 +1843,25 @@ RUN make PG_VERSION="${PG_VERSION:?}" -C compute FROM pg-build AS extension-tests ARG PG_VERSION +# This is required for the PostGIS test +RUN apt-get update && case $DEBIAN_VERSION in \ + bullseye) \ + apt-get install -y libproj19 libgdal28 time; \ + ;; \ + bookworm) \ + apt-get install -y libgdal32 libproj25 time; \ + ;; \ + *) \ + echo "Unknown Debian version ${DEBIAN_VERSION}" && exit 1 \ + ;; \ + esac + COPY docker-compose/ext-src/ /ext-src/ COPY --from=pg-build /postgres /postgres -#COPY 
--from=postgis-src /ext-src/ /ext-src/ +COPY --from=postgis-build /usr/local/pgsql/ /usr/local/pgsql/ +COPY --from=postgis-build /ext-src/postgis-src /ext-src/postgis-src +COPY --from=postgis-build /sfcgal/* /usr COPY --from=plv8-src /ext-src/ /ext-src/ COPY --from=h3-pg-src /ext-src/h3-pg-src /ext-src/h3-pg-src COPY --from=postgresql-unit-src /ext-src/ /ext-src/ @@ -1886,6 +1902,7 @@ COPY compute/patches/pg_repack.patch /ext-src RUN cd /ext-src/pg_repack-src && patch -p1 /etc/ld.so.conf.d/00-neon.conf && /sbin/ldconfig RUN apt-get update && apt-get install -y libtap-parser-sourcehandler-pgtap-perl jq \ && apt clean && rm -rf /ext-src/*.tar.gz /ext-src/*.patch /var/lib/apt/lists/* ENV PATH=/usr/local/pgsql/bin:$PATH diff --git a/compute/manifest.yaml b/compute/manifest.yaml new file mode 100644 index 0000000000..f1cd20c497 --- /dev/null +++ b/compute/manifest.yaml @@ -0,0 +1,121 @@ +pg_settings: + # Common settings for primaries and replicas of all versions. + common: + # Check for client disconnection every 1 minute. By default, Postgres will detect the + # loss of the connection only at the next interaction with the socket, when it waits + # for, receives or sends data, so it will likely waste resources till the end of the + # query execution. There should be no drawbacks in setting this for everyone, so enable + # it by default. If anyone will complain, we can allow editing it. + # https://www.postgresql.org/docs/16/runtime-config-connection.html#GUC-CLIENT-CONNECTION-CHECK-INTERVAL + client_connection_check_interval: "60000" # 1 minute + # ---- IO ---- + effective_io_concurrency: "20" + maintenance_io_concurrency: "100" + fsync: "off" + hot_standby: "off" + # We allow users to change this if needed, but by default we + # just don't want to see long-lasting idle transactions, as they + # prevent activity monitor from suspending projects. 
+ idle_in_transaction_session_timeout: "300000" # 5 minutes + listen_addresses: "*" + # --- LOGGING ---- helps investigations + log_connections: "on" + log_disconnections: "on" + # 1GB, unit is KB + log_temp_files: "1048576" + # Disable dumping customer data to logs, both to increase data privacy + # and to reduce the amount the logs. + log_error_verbosity: "terse" + log_min_error_statement: "panic" + max_connections: "100" + # --- WAL --- + # - flush lag is the max amount of WAL that has been generated but not yet stored + # to disk in the page server. A smaller value means less delay after a pageserver + # restart, but if you set it too small you might again need to slow down writes if the + # pageserver cannot flush incoming WAL to disk fast enough. This must be larger + # than the pageserver's checkpoint interval, currently 1 GB! Otherwise you get a + # a deadlock where the compute node refuses to generate more WAL before the + # old WAL has been uploaded to S3, but the pageserver is waiting for more WAL + # to be generated before it is uploaded to S3. + max_replication_flush_lag: "10GB" + max_replication_slots: "10" + # Backpressure configuration: + # - write lag is the max amount of WAL that has been generated by Postgres but not yet + # processed by the page server. Making this smaller reduces the worst case latency + # of a GetPage request, if you request a page that was recently modified. On the other + # hand, if this is too small, the compute node might need to wait on a write if there is a + # hiccup in the network or page server so that the page server has temporarily fallen + # behind. + # + # Previously it was set to 500 MB, but it caused compute being unresponsive under load + # https://github.com/neondatabase/neon/issues/2028 + max_replication_write_lag: "500MB" + max_wal_senders: "10" + # A Postgres checkpoint is cheap in storage, as doesn't involve any significant amount + # of real I/O. 
Only the SLRU buffers and some other small files are flushed to disk. + # However, as long as we have full_page_writes=on, page updates after a checkpoint + # include full-page images which bloats the WAL. So may want to bump max_wal_size to + # reduce the WAL bloating, but at the same it will increase pg_wal directory size on + # compute and can lead to out of disk error on k8s nodes. + max_wal_size: "1024" + wal_keep_size: "0" + wal_level: "replica" + # Reduce amount of WAL generated by default. + wal_log_hints: "off" + # - without wal_sender_timeout set we don't get feedback messages, + # required for backpressure. + wal_sender_timeout: "10000" + # We have some experimental extensions, which we don't want users to install unconsciously. + # To install them, users would need to set the `neon.allow_unstable_extensions` setting. + # There are two of them currently: + # - `pgrag` - https://github.com/neondatabase-labs/pgrag - extension is actually called just `rag`, + # and two dependencies: + # - `rag_bge_small_en_v15` + # - `rag_jina_reranker_v1_tiny_en` + # - `pg_mooncake` - https://github.com/Mooncake-Labs/pg_mooncake/ + neon.unstable_extensions: "rag,rag_bge_small_en_v15,rag_jina_reranker_v1_tiny_en,pg_mooncake,anon" + neon.protocol_version: "3" + password_encryption: "scram-sha-256" + # This is important to prevent Postgres from trying to perform + # a local WAL redo after backend crash. It should exit and let + # the systemd or k8s to do a fresh startup with compute_ctl. + restart_after_crash: "off" + # By default 3. We have the following persistent connections in the VM: + # * compute_activity_monitor (from compute_ctl) + # * postgres-exporter (metrics collector; it has 2 connections) + # * sql_exporter (metrics collector; we have 2 instances [1 for us & users; 1 for autoscaling]) + # * vm-monitor (to query & change file cache size) + # i.e. total of 6. Let's reserve 7, so there's still at least one left over. 
+ superuser_reserved_connections: "7" + synchronous_standby_names: "walproposer" + + replica: + hot_standby: "on" + + per_version: + 17: + common: + # PostgreSQL 17 has a new IO system called "read stream", which can combine IOs up to some + # size. It still has some issues with readahead, though, so we default to disabled/ + # "no combining of IOs" to make sure we get the maximum prefetch depth. + # See also: https://github.com/neondatabase/neon/pull/9860 + io_combine_limit: "1" + replica: + # prefetching of blocks referenced in WAL doesn't make sense for us + # Neon hot standby ignores pages that are not in the shared_buffers + recovery_prefetch: "off" + 16: + common: + replica: + # prefetching of blocks referenced in WAL doesn't make sense for us + # Neon hot standby ignores pages that are not in the shared_buffers + recovery_prefetch: "off" + 15: + common: + replica: + # prefetching of blocks referenced in WAL doesn't make sense for us + # Neon hot standby ignores pages that are not in the shared_buffers + recovery_prefetch: "off" + 14: + common: + replica: diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 02339f752c..8b502a058e 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -40,7 +40,7 @@ use std::sync::mpsc; use std::thread; use std::time::Duration; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, bail}; use clap::Parser; use compute_api::responses::ComputeConfig; use compute_tools::compute::{ @@ -57,31 +57,15 @@ use tracing::{error, info}; use url::Url; use utils::failpoint_support; -// Compatibility hack: if the control plane specified any remote-ext-config -// use the default value for extension storage proxy gateway. 
-// Remove this once the control plane is updated to pass the gateway URL -fn parse_remote_ext_base_url(arg: &str) -> Result { - const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str = - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"; - - Ok(if arg.starts_with("http") { - arg - } else { - FALLBACK_PG_EXT_GATEWAY_BASE_URL - } - .to_owned()) -} - -#[derive(Parser)] +#[derive(Debug, Parser)] #[command(rename_all = "kebab-case")] struct Cli { #[arg(short = 'b', long, default_value = "postgres", env = "POSTGRES_PATH")] pub pgbin: String, /// The base URL for the remote extension storage proxy gateway. - /// Should be in the form of `http(s)://[:]`. - #[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")] - pub remote_ext_base_url: Option, + #[arg(short = 'r', long, value_parser = Self::parse_remote_ext_base_url)] + pub remote_ext_base_url: Option, /// The port to bind the external listening HTTP server to. Clients running /// outside the compute will talk to the compute through this port. Keep @@ -142,6 +126,25 @@ struct Cli { pub installed_extensions_collection_interval: u64, } +impl Cli { + /// Parse a URL from an argument. By default, this isn't necessary, but we + /// want to do some sanity checking. + fn parse_remote_ext_base_url(value: &str) -> Result { + // Remove extra trailing slashes, and add one. We use Url::join() later + // when downloading remote extensions. If the base URL is something like + // http://example.com/pg-ext-s3-gateway, and join() is called with + // something like "xyz", the resulting URL is http://example.com/xyz. 
+ let value = value.trim_end_matches('/').to_owned() + "/"; + let url = Url::parse(&value)?; + + if url.query_pairs().count() != 0 { + bail!("parameters detected in remote extensions base URL") + } + + Ok(url) + } +} + fn main() -> Result<()> { let cli = Cli::parse(); @@ -268,7 +271,8 @@ fn handle_exit_signal(sig: i32) { #[cfg(test)] mod test { - use clap::CommandFactory; + use clap::{CommandFactory, Parser}; + use url::Url; use super::Cli; @@ -278,16 +282,41 @@ mod test { } #[test] - fn parse_pg_ext_gateway_base_url() { - let arg = "http://pg-ext-s3-gateway2"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!(result, arg); - - let arg = "pg-ext-s3-gateway"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); + fn verify_remote_ext_base_url() { + let cli = Cli::parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com/subpath", + ]); assert_eq!( - result, - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local" + cli.remote_ext_base_url.unwrap(), + Url::parse("https://example.com/subpath/").unwrap() ); + + let cli = Cli::parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com//", + ]); + assert_eq!( + cli.remote_ext_base_url.unwrap(), + Url::parse("https://example.com").unwrap() + ); + + Cli::try_parse_from([ + "compute_ctl", + "--pgdata=test", + "--connstr=test", + "--compute-id=test", + "--remote-ext-base-url", + "https://example.com?hello=world", + ]) + .expect_err("URL parameters are not allowed"); } } diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index ff49c737f0..bd6ed910be 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -3,7 +3,7 @@ use chrono::{DateTime, Utc}; use compute_api::privilege::Privilege; use compute_api::responses::{ ComputeConfig, ComputeCtlConfig, ComputeMetrics, 
ComputeStatus, LfcOffloadState, - LfcPrewarmState, + LfcPrewarmState, TlsConfig, }; use compute_api::spec::{ ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PgIdent, @@ -31,6 +31,7 @@ use std::time::{Duration, Instant}; use std::{env, fs}; use tokio::spawn; use tracing::{Instrument, debug, error, info, instrument, warn}; +use url::Url; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; use utils::measured_stream::MeasuredReader; @@ -96,7 +97,7 @@ pub struct ComputeNodeParams { pub internal_http_port: u16, /// the address of extension storage proxy gateway - pub remote_ext_base_url: Option, + pub remote_ext_base_url: Option, /// Interval for installed extensions collection pub installed_extensions_collection_interval: u64, @@ -395,7 +396,7 @@ impl ComputeNode { // because QEMU will already have its memory allocated from the host, and // the necessary binaries will already be cached. if cli_spec.is_none() { - this.prewarm_postgres()?; + this.prewarm_postgres_vm_memory()?; } // Set the up metric with Empty status before starting the HTTP server. @@ -602,6 +603,8 @@ impl ComputeNode { }); } + let tls_config = self.tls_config(&pspec.spec); + // If there are any remote extensions in shared_preload_libraries, start downloading them if pspec.spec.remote_extensions.is_some() { let (this, spec) = (self.clone(), pspec.spec.clone()); @@ -658,7 +661,7 @@ impl ComputeNode { info!("tuning pgbouncer"); let pgbouncer_settings = pgbouncer_settings.clone(); - let tls_config = self.compute_ctl_config.tls.clone(); + let tls_config = tls_config.clone(); // Spawn a background task to do the tuning, // so that we don't block the main thread that starts Postgres. @@ -677,7 +680,10 @@ impl ComputeNode { // Spawn a background task to do the configuration, // so that we don't block the main thread that starts Postgres. 
- let local_proxy = local_proxy.clone(); + + let mut local_proxy = local_proxy.clone(); + local_proxy.tls = tls_config.clone(); + let _handle = tokio::spawn(async move { if let Err(err) = local_proxy::configure(&local_proxy) { error!("error while configuring local_proxy: {err:?}"); @@ -778,7 +784,7 @@ impl ComputeNode { // Spawn the extension stats background task self.spawn_extension_stats_task(); - if pspec.spec.prewarm_lfc_on_startup { + if pspec.spec.autoprewarm { self.prewarm_lfc(); } Ok(()) @@ -1204,13 +1210,15 @@ impl ComputeNode { let spec = &pspec.spec; let pgdata_path = Path::new(&self.params.pgdata); + let tls_config = self.tls_config(&pspec.spec); + // Remove/create an empty pgdata directory and put configuration there. self.create_pgdata()?; config::write_postgres_conf( pgdata_path, &pspec.spec, self.params.internal_http_port, - &self.compute_ctl_config.tls, + tls_config, )?; // Syncing safekeepers is only safe with primary nodes: if a primary @@ -1306,8 +1314,8 @@ impl ComputeNode { } /// Start and stop a postgres process to warm up the VM for startup. - pub fn prewarm_postgres(&self) -> Result<()> { - info!("prewarming"); + pub fn prewarm_postgres_vm_memory(&self) -> Result<()> { + info!("prewarming VM memory"); // Create pgdata let pgdata = &format!("{}.warmup", self.params.pgdata); @@ -1349,7 +1357,7 @@ impl ComputeNode { kill(pm_pid, Signal::SIGQUIT)?; info!("sent SIGQUIT signal"); pg.wait()?; - info!("done prewarming"); + info!("done prewarming vm memory"); // clean up let _ok = fs::remove_dir_all(pgdata); @@ -1535,14 +1543,22 @@ impl ComputeNode { .clone(), ); + let mut tls_config = None::; + if spec.features.contains(&ComputeFeature::TlsExperimental) { + tls_config = self.compute_ctl_config.tls.clone(); + } + let max_concurrent_connections = self.max_service_connections(compute_state, &spec); // Merge-apply spec & changes to PostgreSQL state. 
self.apply_spec_sql(spec.clone(), conf.clone(), max_concurrent_connections)?; if let Some(local_proxy) = &spec.clone().local_proxy_config { + let mut local_proxy = local_proxy.clone(); + local_proxy.tls = tls_config.clone(); + info!("configuring local_proxy"); - local_proxy::configure(local_proxy).context("apply_config local_proxy")?; + local_proxy::configure(&local_proxy).context("apply_config local_proxy")?; } // Run migrations separately to not hold up cold starts @@ -1594,11 +1610,13 @@ impl ComputeNode { pub fn reconfigure(&self) -> Result<()> { let spec = self.state.lock().unwrap().pspec.clone().unwrap().spec; + let tls_config = self.tls_config(&spec); + if let Some(ref pgbouncer_settings) = spec.pgbouncer_settings { info!("tuning pgbouncer"); let pgbouncer_settings = pgbouncer_settings.clone(); - let tls_config = self.compute_ctl_config.tls.clone(); + let tls_config = tls_config.clone(); // Spawn a background task to do the tuning, // so that we don't block the main thread that starts Postgres. @@ -1616,7 +1634,7 @@ impl ComputeNode { // Spawn a background task to do the configuration, // so that we don't block the main thread that starts Postgres. let mut local_proxy = local_proxy.clone(); - local_proxy.tls = self.compute_ctl_config.tls.clone(); + local_proxy.tls = tls_config.clone(); tokio::spawn(async move { if let Err(err) = local_proxy::configure(&local_proxy) { error!("error while configuring local_proxy: {err:?}"); @@ -1634,7 +1652,7 @@ impl ComputeNode { pgdata_path, &spec, self.params.internal_http_port, - &self.compute_ctl_config.tls, + tls_config, )?; if !spec.skip_pg_catalog_updates { @@ -1754,6 +1772,14 @@ impl ComputeNode { } } + pub fn tls_config(&self, spec: &ComputeSpec) -> &Option { + if spec.features.contains(&ComputeFeature::TlsExperimental) { + &self.compute_ctl_config.tls + } else { + &None:: + } + } + /// Update the `last_active` in the shared state, but ensure that it's a more recent one. 
pub fn update_last_active(&self, last_active: Option>) { let mut state = self.state.lock().unwrap(); diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 3439383699..3764bc1525 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -83,6 +83,7 @@ use reqwest::StatusCode; use tar::Archive; use tracing::info; use tracing::log::warn; +use url::Url; use zstd::stream::read::Decoder; use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS}; @@ -158,7 +159,7 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion { pub async fn download_extension( ext_name: &str, ext_path: &RemotePath, - remote_ext_base_url: &str, + remote_ext_base_url: &Url, pgbin: &str, ) -> Result { info!("Download extension {:?} from {:?}", ext_name, ext_path); @@ -270,10 +271,14 @@ pub fn create_control_files(remote_extensions: &RemoteExtSpec, pgbin: &str) { } // Do request to extension storage proxy, e.g., -// curl http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst +// curl http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/latest/v15/extensions/anon.tar.zst // using HTTP GET and return the response body as bytes. 
-async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Result { - let uri = format!("{}/{}", remote_ext_base_url, ext_path); +async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Result { + let uri = remote_ext_base_url.join(ext_path).with_context(|| { + format!( + "failed to create the remote extension URI for {ext_path} using {remote_ext_base_url}" + ) + })?; let filename = Path::new(ext_path) .file_name() .unwrap_or_else(|| std::ffi::OsStr::new("unknown")) @@ -283,7 +288,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re info!("Downloading extension file '{}' from uri {}", filename, uri); - match do_extension_server_request(&uri).await { + match do_extension_server_request(uri).await { Ok(resp) => { info!("Successfully downloaded remote extension data {}", ext_path); REMOTE_EXT_REQUESTS_TOTAL @@ -302,7 +307,7 @@ async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Re // Do a single remote extensions server request. // Return result or (error message + stringified status code) in case of any failures. 
-async fn do_extension_server_request(uri: &str) -> Result { +async fn do_extension_server_request(uri: Url) -> Result { let resp = reqwest::get(uri).await.map_err(|e| { ( format!( diff --git a/compute_tools/src/http/mod.rs b/compute_tools/src/http/mod.rs index 9ecc1b0093..9b01def966 100644 --- a/compute_tools/src/http/mod.rs +++ b/compute_tools/src/http/mod.rs @@ -48,11 +48,9 @@ impl JsonResponse { /// Create an error response related to the compute being in an invalid state pub(self) fn invalid_status(status: ComputeStatus) -> Response { - Self::create_response( + Self::error( StatusCode::PRECONDITION_FAILED, - &GenericAPIError { - error: format!("invalid compute status: {status}"), - }, + format!("invalid compute status: {status}"), ) } } diff --git a/compute_tools/src/http/routes/configure.rs b/compute_tools/src/http/routes/configure.rs index f7a19da611..c29e3a97da 100644 --- a/compute_tools/src/http/routes/configure.rs +++ b/compute_tools/src/http/routes/configure.rs @@ -22,7 +22,7 @@ pub(in crate::http) async fn configure( State(compute): State>, request: Json, ) -> Response { - let pspec = match ParsedSpec::try_from(request.spec.clone()) { + let pspec = match ParsedSpec::try_from(request.0.spec) { Ok(p) => p, Err(e) => return JsonResponse::error(StatusCode::BAD_REQUEST, e), }; diff --git a/compute_tools/src/monitor.rs b/compute_tools/src/monitor.rs index 3311ee47b3..bacaf05cd5 100644 --- a/compute_tools/src/monitor.rs +++ b/compute_tools/src/monitor.rs @@ -13,6 +13,12 @@ use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS}; const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500); +/// Struct to store runtime state of the compute monitor thread. +/// In theory, this could be a part of `Compute`, but i) +/// this state is expected to be accessed only by single thread, +/// so we don't need to care about locking; ii) `Compute` is +/// already quite big. 
Thus, it seems to be a good idea to keep +/// all the activity/health monitoring parts here. struct ComputeMonitor { compute: Arc, @@ -70,12 +76,36 @@ impl ComputeMonitor { ) } + /// Check if compute is in some terminal or soon-to-be-terminal + /// state, then return `true`, signalling the caller that it + /// should exit gracefully. Otherwise, return `false`. + fn check_interrupts(&mut self) -> bool { + let compute_status = self.compute.get_status(); + if matches!( + compute_status, + ComputeStatus::Terminated | ComputeStatus::TerminationPending | ComputeStatus::Failed + ) { + info!( + "compute is in {} status, stopping compute monitor", + compute_status + ); + return true; + } + + false + } + /// Spin in a loop and figure out the last activity time in the Postgres. - /// Then update it in the shared state. This function never errors out. + /// Then update it in the shared state. This function currently never + /// errors out explicitly, but there is a graceful termination path. + /// Every time we receive an error trying to check Postgres, we use + /// [`ComputeMonitor::check_interrupts()`] because it could be that + /// compute is being terminated already, then we can exit gracefully + /// to not produce errors' noise in the log. /// NB: the only expected panic is at `Mutex` unwrap(), all other errors /// should be handled gracefully. 
#[instrument(skip_all)] - pub fn run(&mut self) { + pub fn run(&mut self) -> anyhow::Result<()> { // Suppose that `connstr` doesn't change let connstr = self.compute.params.connstr.clone(); let conf = self @@ -93,6 +123,10 @@ impl ComputeMonitor { info!("starting compute monitor for {}", connstr); loop { + if self.check_interrupts() { + break; + } + match &mut client { Ok(cli) => { if cli.is_closed() { @@ -100,6 +134,10 @@ impl ComputeMonitor { downtime_info = self.downtime_info(), "connection to Postgres is closed, trying to reconnect" ); + if self.check_interrupts() { + break; + } + self.report_down(); // Connection is closed, reconnect and try again. @@ -111,15 +149,19 @@ impl ComputeMonitor { self.compute.update_last_active(self.last_active); } Err(e) => { + error!( + downtime_info = self.downtime_info(), + "could not check Postgres: {}", e + ); + if self.check_interrupts() { + break; + } + // Although we have many places where we can return errors in `check()`, // normally it shouldn't happen. I.e., we will likely return error if // connection got broken, query timed out, Postgres returned invalid data, etc. // In all such cases it's suspicious, so let's report this as downtime. self.report_down(); - error!( - downtime_info = self.downtime_info(), - "could not check Postgres: {}", e - ); // Reconnect to Postgres just in case. During tests, I noticed // that queries in `check()` can fail with `connection closed`, @@ -136,6 +178,10 @@ impl ComputeMonitor { downtime_info = self.downtime_info(), "could not connect to Postgres: {}, retrying", e ); + if self.check_interrupts() { + break; + } + self.report_down(); // Establish a new connection and try again. 
@@ -147,6 +193,9 @@ impl ComputeMonitor { self.last_checked = Utc::now(); thread::sleep(MONITOR_CHECK_INTERVAL); } + + // Graceful termination path + Ok(()) } #[instrument(skip_all)] @@ -429,7 +478,10 @@ pub fn launch_monitor(compute: &Arc) -> thread::JoinHandle<()> { .spawn(move || { let span = span!(Level::INFO, "compute_monitor"); let _enter = span.enter(); - monitor.run(); + match monitor.run() { + Ok(_) => info!("compute monitor thread terminated gracefully"), + Err(err) => error!("compute monitor thread terminated abnormally {:?}", err), + } }) .expect("cannot launch compute monitor thread") } diff --git a/compute_tools/tests/pg_helpers_tests.rs b/compute_tools/tests/pg_helpers_tests.rs index 04b6ed2256..2b865c75d0 100644 --- a/compute_tools/tests/pg_helpers_tests.rs +++ b/compute_tools/tests/pg_helpers_tests.rs @@ -30,7 +30,7 @@ mod pg_helpers_tests { r#"fsync = off wal_level = logical hot_standby = on -prewarm_lfc_on_startup = off +autoprewarm = off neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501' wal_log_hints = on log_connections = on diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 708745446d..774a0053f8 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -747,7 +747,7 @@ impl Endpoint { logs_export_host: None::, endpoint_storage_addr: Some(endpoint_storage_addr), endpoint_storage_token: Some(endpoint_storage_token), - prewarm_lfc_on_startup: false, + autoprewarm: false, }; // this strange code is needed to support respec() in tests diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index 29314dab9e..0cf7ca184d 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -513,11 +513,6 @@ impl PageServerNode { .map(|x| x.parse::()) .transpose() .context("Failed to parse 'timeline_offloading' as bool")?, - wal_receiver_protocol_override: settings - .remove("wal_receiver_protocol_override") - .map(serde_json::from_str) - 
.transpose() - .context("parse `wal_receiver_protocol_override` from json")?, rel_size_v2_enabled: settings .remove("rel_size_v2_enabled") .map(|x| x.parse::()) diff --git a/docker-compose/compute_wrapper/Dockerfile b/docker-compose/compute_wrapper/Dockerfile index 9ef831a9cd..b89e69c650 100644 --- a/docker-compose/compute_wrapper/Dockerfile +++ b/docker-compose/compute_wrapper/Dockerfile @@ -13,6 +13,6 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \ jq \ netcat-openbsd #This is required for the pg_hintplan test -RUN mkdir -p /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw && chown postgres /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw +RUN mkdir -p /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw /ext-src/postgis-src/ && chown postgres /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw /ext-src/postgis-src USER postgres diff --git a/docker-compose/compute_wrapper/shell/compute.sh b/docker-compose/compute_wrapper/shell/compute.sh index ab8d74d355..c8ca812bf9 100755 --- a/docker-compose/compute_wrapper/shell/compute.sh +++ b/docker-compose/compute_wrapper/shell/compute.sh @@ -1,18 +1,18 @@ -#!/bin/bash +#!/usr/bin/env bash set -eux # Generate a random tenant or timeline ID # # Takes a variable name as argument. The result is stored in that variable. 
generate_id() { - local -n resvar=$1 - printf -v resvar '%08x%08x%08x%08x' $SRANDOM $SRANDOM $SRANDOM $SRANDOM + local -n resvar=${1} + printf -v resvar '%08x%08x%08x%08x' ${SRANDOM} ${SRANDOM} ${SRANDOM} ${SRANDOM} } PG_VERSION=${PG_VERSION:-14} -CONFIG_FILE_ORG=/var/db/postgres/configs/config.json -CONFIG_FILE=/tmp/config.json +readonly CONFIG_FILE_ORG=/var/db/postgres/configs/config.json +readonly CONFIG_FILE=/tmp/config.json # Test that the first library path that the dynamic loader looks in is the path # that we use for custom compiled software @@ -20,17 +20,17 @@ first_path="$(ldconfig --verbose 2>/dev/null \ | grep --invert-match ^$'\t' \ | cut --delimiter=: --fields=1 \ | head --lines=1)" -test "$first_path" == '/usr/local/lib' +test "${first_path}" = '/usr/local/lib' echo "Waiting pageserver become ready." while ! nc -z pageserver 6400; do - sleep 1; + sleep 1 done echo "Page server is ready." -cp ${CONFIG_FILE_ORG} ${CONFIG_FILE} +cp "${CONFIG_FILE_ORG}" "${CONFIG_FILE}" - if [ -n "${TENANT_ID:-}" ] && [ -n "${TIMELINE_ID:-}" ]; then + if [[ -n "${TENANT_ID:-}" && -n "${TIMELINE_ID:-}" ]]; then tenant_id=${TENANT_ID} timeline_id=${TIMELINE_ID} else @@ -41,7 +41,7 @@ else "http://pageserver:9898/v1/tenant" ) tenant_id=$(curl "${PARAMS[@]}" | jq -r .[0].id) - if [ -z "${tenant_id}" ] || [ "${tenant_id}" = null ]; then + if [[ -z "${tenant_id}" || "${tenant_id}" = null ]]; then echo "Create a tenant" generate_id tenant_id PARAMS=( @@ -51,7 +51,7 @@ else "http://pageserver:9898/v1/tenant/${tenant_id}/location_config" ) result=$(curl "${PARAMS[@]}") - echo $result | jq . + printf '%s\n' "${result}" | jq . 
fi echo "Check if a timeline present" @@ -61,7 +61,7 @@ else "http://pageserver:9898/v1/tenant/${tenant_id}/timeline" ) timeline_id=$(curl "${PARAMS[@]}" | jq -r .[0].timeline_id) - if [ -z "${timeline_id}" ] || [ "${timeline_id}" = null ]; then + if [[ -z "${timeline_id}" || "${timeline_id}" = null ]]; then generate_id timeline_id PARAMS=( -sbf @@ -71,7 +71,7 @@ else "http://pageserver:9898/v1/tenant/${tenant_id}/timeline/" ) result=$(curl "${PARAMS[@]}") - echo $result | jq . + printf '%s\n' "${result}" | jq . fi fi @@ -82,10 +82,10 @@ else fi echo "Adding pgx_ulid" shared_libraries=$(jq -r '.spec.cluster.settings[] | select(.name=="shared_preload_libraries").value' ${CONFIG_FILE}) -sed -i "s/${shared_libraries}/${shared_libraries},${ulid_extension}/" ${CONFIG_FILE} +sed -i "s|${shared_libraries}|${shared_libraries},${ulid_extension}|" ${CONFIG_FILE} echo "Overwrite tenant id and timeline id in spec file" -sed -i "s/TENANT_ID/${tenant_id}/" ${CONFIG_FILE} -sed -i "s/TIMELINE_ID/${timeline_id}/" ${CONFIG_FILE} +sed -i "s|TENANT_ID|${tenant_id}|" ${CONFIG_FILE} +sed -i "s|TIMELINE_ID|${timeline_id}|" ${CONFIG_FILE} cat ${CONFIG_FILE} @@ -93,5 +93,5 @@ echo "Start compute node" /usr/local/bin/compute_ctl --pgdata /var/db/postgres/compute \ -C "postgresql://cloud_admin@localhost:55433/postgres" \ -b /usr/local/bin/postgres \ - --compute-id "compute-$RANDOM" \ - --config "$CONFIG_FILE" + --compute-id "compute-${RANDOM}" \ + --config "${CONFIG_FILE}" diff --git a/docker-compose/docker-compose.yml b/docker-compose/docker-compose.yml index fd3ad1fffc..2519b75c7f 100644 --- a/docker-compose/docker-compose.yml +++ b/docker-compose/docker-compose.yml @@ -186,13 +186,14 @@ services: neon-test-extensions: profiles: ["test-extensions"] - image: ${REPOSITORY:-ghcr.io/neondatabase}/neon-test-extensions-v${PG_TEST_VERSION:-16}:${TEST_EXTENSIONS_TAG:-${TAG:-latest}} + image: 
${REPOSITORY:-ghcr.io/neondatabase}/neon-test-extensions-v${PG_TEST_VERSION:-${PG_VERSION:-16}}:${TEST_EXTENSIONS_TAG:-${TAG:-latest}} environment: - - PGPASSWORD=cloud_admin + - PGUSER=${PGUSER:-cloud_admin} + - PGPASSWORD=${PGPASSWORD:-cloud_admin} entrypoint: - "/bin/bash" - "-c" command: - - sleep 1800 + - sleep 3600 depends_on: - compute diff --git a/docker-compose/docker_compose_test.sh b/docker-compose/docker_compose_test.sh index 2645a49883..6edf90ca8d 100755 --- a/docker-compose/docker_compose_test.sh +++ b/docker-compose/docker_compose_test.sh @@ -54,6 +54,15 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do # It cannot be moved to Dockerfile now because the database directory is created after the start of the container echo Adding dummy config docker compose exec compute touch /var/db/postgres/compute/compute_ctl_temp_override.conf + # Prepare for the PostGIS test + docker compose exec compute mkdir -p /tmp/pgis_reg/pgis_reg_tmp + TMPDIR=$(mktemp -d) + docker compose cp neon-test-extensions:/ext-src/postgis-src/raster/test "${TMPDIR}" + docker compose cp neon-test-extensions:/ext-src/postgis-src/regress/00-regress-install "${TMPDIR}" + docker compose exec compute mkdir -p /ext-src/postgis-src/raster /ext-src/postgis-src/regress /ext-src/postgis-src/regress/00-regress-install + docker compose cp "${TMPDIR}/test" compute:/ext-src/postgis-src/raster/test + docker compose cp "${TMPDIR}/00-regress-install" compute:/ext-src/postgis-src/regress + rm -rf "${TMPDIR}" # The following block copies the files for the pg_hintplan test to the compute node for the extension test in an isolated docker-compose environment TMPDIR=$(mktemp -d) docker compose cp neon-test-extensions:/ext-src/pg_hint_plan-src/data "${TMPDIR}/data" @@ -68,7 +77,7 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do docker compose exec -T neon-test-extensions bash -c "(cd /postgres && patch -p1)" <"../compute/patches/contrib_pg${pg_version}.patch" # We are running tests now rm -f 
testout.txt testout_contrib.txt - docker compose exec -e USE_PGXS=1 -e SKIP=timescaledb-src,rdkit-src,postgis-src,pg_jsonschema-src,kq_imcx-src,wal2json_2_5-src,rag_jina_reranker_v1_tiny_en-src,rag_bge_small_en_v15-src \ + docker compose exec -e USE_PGXS=1 -e SKIP=timescaledb-src,rdkit-src,pg_jsonschema-src,kq_imcx-src,wal2json_2_5-src,rag_jina_reranker_v1_tiny_en-src,rag_bge_small_en_v15-src \ neon-test-extensions /run-tests.sh /ext-src | tee testout.txt && EXT_SUCCESS=1 || EXT_SUCCESS=0 docker compose exec -e SKIP=start-scripts,postgres_fdw,ltree_plpython,jsonb_plpython,jsonb_plperl,hstore_plpython,hstore_plperl,dblink,bool_plperl \ neon-test-extensions /run-tests.sh /postgres/contrib | tee testout_contrib.txt && CONTRIB_SUCCESS=1 || CONTRIB_SUCCESS=0 diff --git a/docker-compose/ext-src/postgis-src/README-Neon.md b/docker-compose/ext-src/postgis-src/README-Neon.md new file mode 100644 index 0000000000..5937fc782b --- /dev/null +++ b/docker-compose/ext-src/postgis-src/README-Neon.md @@ -0,0 +1,70 @@ +# PostGIS Testing in Neon + +This directory contains configuration files and patches for running PostGIS tests in the Neon database environment. + +## Overview + +PostGIS is a spatial database extension for PostgreSQL that adds support for geographic objects. Testing PostGIS compatibility ensures that Neon's modifications to PostgreSQL don't break compatibility with this critical extension. 
+ +## PostGIS Versions + +- PostgreSQL v17: PostGIS 3.5.0 +- PostgreSQL v14/v15/v16: PostGIS 3.3.3 + +## Test Configuration + +The test setup includes: + +- `postgis-no-upgrade-test.patch`: Disables upgrade tests by removing the upgrade test section from regress/runtest.mk +- `postgis-regular-v16.patch`: Version-specific patch for PostgreSQL v16 +- `postgis-regular-v17.patch`: Version-specific patch for PostgreSQL v17 +- `regular-test.sh`: Script to run PostGIS tests as a regular user +- `neon-test.sh`: Script to handle version-specific test configurations +- `raster_outdb_template.sql`: Template for raster tests with explicit file paths + +## Excluded Tests + +**Important Note:** The test exclusions listed below are specifically for regular-user tests against staging instances. These exclusions are necessary because staging instances run with limited privileges and cannot perform operations requiring superuser access. Docker-compose based tests are not affected by these exclusions. + +### Tests Requiring Superuser Permissions + +These tests cannot be run as a regular user: +- `estimatedextent` +- `regress/core/legacy` +- `regress/core/typmod` +- `regress/loader/TestSkipANALYZE` +- `regress/loader/TestANALYZE` + +### Tests Requiring Filesystem Access + +These tests need direct filesystem access that is only possible for superusers: +- `loader/load_outdb` + +### Tests with Flaky Results + +These tests have assumptions that don't always hold true: +- `regress/core/computed_columns` - Assumes computed columns always outperform alternatives, which is not consistently true + +### Tests Requiring Tunable Parameter Modifications + +These tests attempt to modify the `postgis.gdal_enabled_drivers` parameter, which is only accessible to superusers: +- `raster/test/regress/rt_wkb` +- `raster/test/regress/rt_addband` +- `raster/test/regress/rt_setbandpath` +- `raster/test/regress/rt_fromgdalraster` +- `raster/test/regress/rt_asgdalraster` +- `raster/test/regress/rt_astiff` +- 
`raster/test/regress/rt_asjpeg` +- `raster/test/regress/rt_aspng` +- `raster/test/regress/permitted_gdal_drivers` +- Loader tests: `BasicOutDB`, `Tiled10x10`, `Tiled10x10Copy`, `Tiled8x8`, `TiledAuto`, `TiledAutoSkipNoData`, `TiledAutoCopyn` + +### Topology Tests (v17 only) +- `populate_topology_layer` +- `renametopogeometrycolumn` + +## Other Modifications + +- Binary.sql tests are modified to use explicit file paths +- Server-side SQL COPY commands (which require superuser privileges) are converted to client-side `\copy` commands +- Upgrade tests are disabled diff --git a/docker-compose/ext-src/postgis-src/neon-test.sh b/docker-compose/ext-src/postgis-src/neon-test.sh new file mode 100755 index 0000000000..2866649a1b --- /dev/null +++ b/docker-compose/ext-src/postgis-src/neon-test.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -ex +cd "$(dirname "$0")" +if [[ ${PG_VERSION} = v17 ]]; then + sed -i '/computed_columns/d' regress/core/tests.mk +fi +patch -p1 =" 120),1) +- TESTS += \ +- $(top_srcdir)/regress/core/computed_columns +-endif +- + ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1) + # GEOS-3.7 adds: + # ST_FrechetDistance +diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk +index 1fc77ac..c3cb9de 100644 +--- a/regress/loader/tests.mk ++++ b/regress/loader/tests.mk +@@ -38,7 +38,5 @@ TESTS += \ + $(top_srcdir)/regress/loader/Latin1 \ + $(top_srcdir)/regress/loader/Latin1-implicit \ + $(top_srcdir)/regress/loader/mfile \ +- $(top_srcdir)/regress/loader/TestSkipANALYZE \ +- $(top_srcdir)/regress/loader/TestANALYZE \ + $(top_srcdir)/regress/loader/CharNoWidth + +diff --git a/regress/run_test.pl b/regress/run_test.pl +index 0ec5b2d..1c331f4 100755 +--- a/regress/run_test.pl ++++ b/regress/run_test.pl +@@ -147,7 +147,6 @@ $ENV{"LANG"} = "C"; + # Add locale info to the psql options + # Add pg12 precision suppression + my $PGOPTIONS = $ENV{"PGOPTIONS"}; +-$PGOPTIONS .= " -c lc_messages=C"; + $PGOPTIONS .= " -c client_min_messages=NOTICE"; + $PGOPTIONS 
.= " -c extra_float_digits=0"; + $ENV{"PGOPTIONS"} = $PGOPTIONS; diff --git a/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch b/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch new file mode 100644 index 0000000000..f4a9d83478 --- /dev/null +++ b/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch @@ -0,0 +1,218 @@ +diff --git a/raster/test/regress/tests.mk b/raster/test/regress/tests.mk +index 00918e1..7e2b6cd 100644 +--- a/raster/test/regress/tests.mk ++++ b/raster/test/regress/tests.mk +@@ -17,9 +17,7 @@ override RUNTESTFLAGS_INTERNAL := \ + $(RUNTESTFLAGS_INTERNAL) \ + --after-upgrade-script $(top_srcdir)/raster/test/regress/hooks/hook-after-upgrade-raster.sql + +-RASTER_TEST_FIRST = \ +- $(top_srcdir)/raster/test/regress/check_gdal \ +- $(top_srcdir)/raster/test/regress/loader/load_outdb ++RASTER_TEST_FIRST = + + RASTER_TEST_LAST = \ + $(top_srcdir)/raster/test/regress/clean +@@ -33,9 +31,7 @@ RASTER_TEST_IO = \ + + RASTER_TEST_BASIC_FUNC = \ + $(top_srcdir)/raster/test/regress/rt_bytea \ +- $(top_srcdir)/raster/test/regress/rt_wkb \ + $(top_srcdir)/raster/test/regress/box3d \ +- $(top_srcdir)/raster/test/regress/rt_addband \ + $(top_srcdir)/raster/test/regress/rt_band \ + $(top_srcdir)/raster/test/regress/rt_tile + +@@ -73,16 +69,10 @@ RASTER_TEST_BANDPROPS = \ + $(top_srcdir)/raster/test/regress/rt_neighborhood \ + $(top_srcdir)/raster/test/regress/rt_nearestvalue \ + $(top_srcdir)/raster/test/regress/rt_pixelofvalue \ +- $(top_srcdir)/raster/test/regress/rt_polygon \ +- $(top_srcdir)/raster/test/regress/rt_setbandpath ++ $(top_srcdir)/raster/test/regress/rt_polygon + + RASTER_TEST_UTILITY = \ + $(top_srcdir)/raster/test/regress/rt_utility \ +- $(top_srcdir)/raster/test/regress/rt_fromgdalraster \ +- $(top_srcdir)/raster/test/regress/rt_asgdalraster \ +- $(top_srcdir)/raster/test/regress/rt_astiff \ +- $(top_srcdir)/raster/test/regress/rt_asjpeg \ +- $(top_srcdir)/raster/test/regress/rt_aspng \ + 
$(top_srcdir)/raster/test/regress/rt_reclass \ + $(top_srcdir)/raster/test/regress/rt_gdalwarp \ + $(top_srcdir)/raster/test/regress/rt_gdalcontour \ +@@ -120,21 +110,13 @@ RASTER_TEST_SREL = \ + + RASTER_TEST_BUGS = \ + $(top_srcdir)/raster/test/regress/bug_test_car5 \ +- $(top_srcdir)/raster/test/regress/permitted_gdal_drivers \ + $(top_srcdir)/raster/test/regress/tickets + + RASTER_TEST_LOADER = \ + $(top_srcdir)/raster/test/regress/loader/Basic \ + $(top_srcdir)/raster/test/regress/loader/Projected \ + $(top_srcdir)/raster/test/regress/loader/BasicCopy \ +- $(top_srcdir)/raster/test/regress/loader/BasicFilename \ +- $(top_srcdir)/raster/test/regress/loader/BasicOutDB \ +- $(top_srcdir)/raster/test/regress/loader/Tiled10x10 \ +- $(top_srcdir)/raster/test/regress/loader/Tiled10x10Copy \ +- $(top_srcdir)/raster/test/regress/loader/Tiled8x8 \ +- $(top_srcdir)/raster/test/regress/loader/TiledAuto \ +- $(top_srcdir)/raster/test/regress/loader/TiledAutoSkipNoData \ +- $(top_srcdir)/raster/test/regress/loader/TiledAutoCopyn ++ $(top_srcdir)/raster/test/regress/loader/BasicFilename + + RASTER_TESTS := $(RASTER_TEST_FIRST) \ + $(RASTER_TEST_METADATA) $(RASTER_TEST_IO) $(RASTER_TEST_BASIC_FUNC) \ +diff --git a/regress/core/binary.sql b/regress/core/binary.sql +index 7a36b65..ad78fc7 100644 +--- a/regress/core/binary.sql ++++ b/regress/core/binary.sql +@@ -1,4 +1,5 @@ + SET client_min_messages TO warning; ++ + CREATE SCHEMA tm; + + CREATE TABLE tm.geoms (id serial, g geometry); +@@ -31,24 +32,39 @@ SELECT st_force4d(g) FROM tm.geoms WHERE id < 15 ORDER BY id; + INSERT INTO tm.geoms(g) + SELECT st_setsrid(g,4326) FROM tm.geoms ORDER BY id; + +-COPY tm.geoms TO :tmpfile WITH BINARY; ++-- define temp file path ++\set tmpfile '/tmp/postgis_binary_test.dat' ++ ++-- export ++\set command '\\copy tm.geoms TO ':tmpfile' WITH (FORMAT BINARY)' ++:command ++ ++-- import + CREATE TABLE tm.geoms_in AS SELECT * FROM tm.geoms LIMIT 0; +-COPY tm.geoms_in FROM :tmpfile WITH BINARY; 
+-SELECT 'geometry', count(*) FROM tm.geoms_in i, tm.geoms o WHERE i.id = o.id +- AND ST_OrderingEquals(i.g, o.g); ++\set command '\\copy tm.geoms_in FROM ':tmpfile' WITH (FORMAT BINARY)' ++:command ++ ++SELECT 'geometry', count(*) FROM tm.geoms_in i, tm.geoms o ++WHERE i.id = o.id AND ST_OrderingEquals(i.g, o.g); + + CREATE TABLE tm.geogs AS SELECT id,g::geography FROM tm.geoms + WHERE geometrytype(g) NOT LIKE '%CURVE%' + AND geometrytype(g) NOT LIKE '%CIRCULAR%' + AND geometrytype(g) NOT LIKE '%SURFACE%' + AND geometrytype(g) NOT LIKE 'TRIANGLE%' +- AND geometrytype(g) NOT LIKE 'TIN%' +-; ++ AND geometrytype(g) NOT LIKE 'TIN%'; + +-COPY tm.geogs TO :tmpfile WITH BINARY; ++-- export ++\set command '\\copy tm.geogs TO ':tmpfile' WITH (FORMAT BINARY)' ++:command ++ ++-- import + CREATE TABLE tm.geogs_in AS SELECT * FROM tm.geogs LIMIT 0; +-COPY tm.geogs_in FROM :tmpfile WITH BINARY; +-SELECT 'geometry', count(*) FROM tm.geogs_in i, tm.geogs o WHERE i.id = o.id +- AND ST_OrderingEquals(i.g::geometry, o.g::geometry); ++\set command '\\copy tm.geogs_in FROM ':tmpfile' WITH (FORMAT BINARY)' ++:command ++ ++SELECT 'geometry', count(*) FROM tm.geogs_in i, tm.geogs o ++WHERE i.id = o.id AND ST_OrderingEquals(i.g::geometry, o.g::geometry); + + DROP SCHEMA tm CASCADE; ++ +diff --git a/regress/core/tests.mk b/regress/core/tests.mk +index 9e05244..a63a3e1 100644 +--- a/regress/core/tests.mk ++++ b/regress/core/tests.mk +@@ -16,14 +16,13 @@ POSTGIS_PGSQL_VERSION=170 + POSTGIS_GEOS_VERSION=31101 + HAVE_JSON=yes + HAVE_SPGIST=yes +-INTERRUPTTESTS=yes ++INTERRUPTTESTS=no + + current_dir := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) + + RUNTESTFLAGS_INTERNAL += \ + --before-upgrade-script $(top_srcdir)/regress/hooks/hook-before-upgrade.sql \ + --after-upgrade-script $(top_srcdir)/regress/hooks/hook-after-upgrade.sql \ +- --after-create-script $(top_srcdir)/regress/hooks/hook-after-create.sql \ + --before-uninstall-script $(top_srcdir)/regress/hooks/hook-before-uninstall.sql + + 
TESTS += \ +@@ -40,7 +39,6 @@ TESTS += \ + $(top_srcdir)/regress/core/dumppoints \ + $(top_srcdir)/regress/core/dumpsegments \ + $(top_srcdir)/regress/core/empty \ +- $(top_srcdir)/regress/core/estimatedextent \ + $(top_srcdir)/regress/core/forcecurve \ + $(top_srcdir)/regress/core/flatgeobuf \ + $(top_srcdir)/regress/core/frechet \ +@@ -60,7 +58,6 @@ TESTS += \ + $(top_srcdir)/regress/core/out_marc21 \ + $(top_srcdir)/regress/core/in_encodedpolyline \ + $(top_srcdir)/regress/core/iscollection \ +- $(top_srcdir)/regress/core/legacy \ + $(top_srcdir)/regress/core/letters \ + $(top_srcdir)/regress/core/lwgeom_regress \ + $(top_srcdir)/regress/core/measures \ +@@ -119,7 +116,6 @@ TESTS += \ + $(top_srcdir)/regress/core/temporal_knn \ + $(top_srcdir)/regress/core/tickets \ + $(top_srcdir)/regress/core/twkb \ +- $(top_srcdir)/regress/core/typmod \ + $(top_srcdir)/regress/core/wkb \ + $(top_srcdir)/regress/core/wkt \ + $(top_srcdir)/regress/core/wmsservers \ +@@ -143,8 +139,7 @@ TESTS += \ + $(top_srcdir)/regress/core/oriented_envelope \ + $(top_srcdir)/regress/core/point_coordinates \ + $(top_srcdir)/regress/core/out_geojson \ +- $(top_srcdir)/regress/core/wrapx \ +- $(top_srcdir)/regress/core/computed_columns ++ $(top_srcdir)/regress/core/wrapx + + # Slow slow tests + TESTS_SLOW = \ +diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk +index ac4f8ad..4bad4fc 100644 +--- a/regress/loader/tests.mk ++++ b/regress/loader/tests.mk +@@ -38,7 +38,5 @@ TESTS += \ + $(top_srcdir)/regress/loader/Latin1 \ + $(top_srcdir)/regress/loader/Latin1-implicit \ + $(top_srcdir)/regress/loader/mfile \ +- $(top_srcdir)/regress/loader/TestSkipANALYZE \ +- $(top_srcdir)/regress/loader/TestANALYZE \ + $(top_srcdir)/regress/loader/CharNoWidth \ + +diff --git a/regress/run_test.pl b/regress/run_test.pl +index cac4b2e..4c7c82b 100755 +--- a/regress/run_test.pl ++++ b/regress/run_test.pl +@@ -238,7 +238,6 @@ $ENV{"LANG"} = "C"; + # Add locale info to the psql options + # Add pg12 
precision suppression + my $PGOPTIONS = $ENV{"PGOPTIONS"}; +-$PGOPTIONS .= " -c lc_messages=C"; + $PGOPTIONS .= " -c client_min_messages=NOTICE"; + $PGOPTIONS .= " -c extra_float_digits=0"; + $ENV{"PGOPTIONS"} = $PGOPTIONS; +diff --git a/topology/test/tests.mk b/topology/test/tests.mk +index cbe2633..2c7c18f 100644 +--- a/topology/test/tests.mk ++++ b/topology/test/tests.mk +@@ -46,9 +46,7 @@ TESTS += \ + $(top_srcdir)/topology/test/regress/legacy_query.sql \ + $(top_srcdir)/topology/test/regress/legacy_validate.sql \ + $(top_srcdir)/topology/test/regress/polygonize.sql \ +- $(top_srcdir)/topology/test/regress/populate_topology_layer.sql \ + $(top_srcdir)/topology/test/regress/removeunusedprimitives.sql \ +- $(top_srcdir)/topology/test/regress/renametopogeometrycolumn.sql \ + $(top_srcdir)/topology/test/regress/renametopology.sql \ + $(top_srcdir)/topology/test/regress/share_sequences.sql \ + $(top_srcdir)/topology/test/regress/sqlmm.sql \ diff --git a/docker-compose/ext-src/postgis-src/raster_outdb_template.sql b/docker-compose/ext-src/postgis-src/raster_outdb_template.sql new file mode 100644 index 0000000000..16232f28dd --- /dev/null +++ b/docker-compose/ext-src/postgis-src/raster_outdb_template.sql @@ -0,0 +1,46 @@ +-- +-- PostgreSQL database dump +-- + +-- Dumped from database version 17.4 +-- Dumped by pg_dump version 17.4 + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET transaction_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; + +-- +-- Name: raster_outdb_template; Type: TABLE; Schema: public; Owner: cloud_admin +-- + +CREATE TABLE public.raster_outdb_template ( + rid integer, + rast public.raster +); + + +ALTER TABLE public.raster_outdb_template OWNER TO cloud_admin; + +-- +-- Data for Name: raster_outdb_template; 
Type: TABLE DATA; Schema: public; Owner: cloud_admin +-- + +COPY public.raster_outdb_template (rid, rast) FROM stdin; +1 0100000300000000000000F03F000000000000F0BF0000000000000000000000000000000000000000000000000000000000000000000000005A0032008400002F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966008400012F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966008400022F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E74696600 +2 0100000300000000000000F03F000000000000F0BF0000000000000000000000000000000000000000000000000000000000000000000000005A0032008400002F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966008400012F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E746966008400022F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E74696600 +3 
0100000200000000000000F03F000000000000F0BF0000000000000000000000000000000000000000000000000000000000000000000000005A003200440001010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101
01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101
01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101
01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101
01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101018400012F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E74696600 +4 
0100000200000000000000F03F000000000000F0BF0000000000000000000000000000000000000000000000000000000000000000000000005A003200C4FF012F6578742D7372632F706F73746769732D7372632F726567726573732F2E2E2F7261737465722F746573742F726567726573732F6C6F616465722F746573747261737465722E7469660044000101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101
01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101
01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101
01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101
01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101 +\. 
+ + +-- +-- PostgreSQL database dump complete +-- + diff --git a/docker-compose/ext-src/postgis-src/regular-test.sh b/docker-compose/ext-src/postgis-src/regular-test.sh new file mode 100755 index 0000000000..4b0b929946 --- /dev/null +++ b/docker-compose/ext-src/postgis-src/regular-test.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -ex +cd "$(dirname "${0}")" +dropdb --if-exist contrib_regression +createdb contrib_regression +psql -d contrib_regression -c "ALTER DATABASE contrib_regression SET TimeZone='UTC'" \ + -c "ALTER DATABASE contrib_regression SET DateStyle='ISO, MDY'" \ + -c "CREATE EXTENSION postgis SCHEMA public" \ + -c "CREATE EXTENSION postgis_topology" \ + -c "CREATE EXTENSION postgis_tiger_geocoder CASCADE" \ + -c "CREATE EXTENSION postgis_raster SCHEMA public" \ + -c "CREATE EXTENSION postgis_sfcgal SCHEMA public" +patch -p1 , - /// If true, download LFC state from endpoint_storage and pass it to Postgres on startup + /// Download LFC state from endpoint_storage and pass it to Postgres on startup #[serde(default)] - pub prewarm_lfc_on_startup: bool, + pub autoprewarm: bool, } /// Feature flag to signal `compute_ctl` to enable certain experimental functionality. @@ -192,6 +192,9 @@ pub enum ComputeFeature { /// track short-lived connections as user activity. ActivityMonitorExperimental, + /// Enable TLS functionality. + TlsExperimental, + /// This is a special feature flag that is used to represent unknown feature flags. /// Basically all unknown to enum flags are represented as this one. See unit test /// `parse_unknown_features()` for more details. @@ -250,34 +253,44 @@ impl RemoteExtSpec { } match self.extension_data.get(real_ext_name) { - Some(_ext_data) => { - // We have decided to use the Go naming convention due to Kubernetes. 
- - let arch = match std::env::consts::ARCH { - "x86_64" => "amd64", - "aarch64" => "arm64", - arch => arch, - }; - - // Construct the path to the extension archive - // BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst - // - // Keep it in sync with path generation in - // https://github.com/neondatabase/build-custom-extensions/tree/main - let archive_path_str = format!( - "{build_tag}/{arch}/{pg_major_version}/extensions/{real_ext_name}.tar.zst" - ); - Ok(( - real_ext_name.to_string(), - RemotePath::from_string(&archive_path_str)?, - )) - } + Some(_ext_data) => Ok(( + real_ext_name.to_string(), + Self::build_remote_path(build_tag, pg_major_version, real_ext_name)?, + )), None => Err(anyhow::anyhow!( "real_ext_name {} is not found", real_ext_name )), } } + + /// Get the architecture-specific portion of the remote extension path. We + /// use the Go naming convention due to Kubernetes. + fn get_arch() -> &'static str { + match std::env::consts::ARCH { + "x86_64" => "amd64", + "aarch64" => "arm64", + arch => arch, + } + } + + /// Build a [`RemotePath`] for an extension. 
+ fn build_remote_path( + build_tag: &str, + pg_major_version: &str, + ext_name: &str, + ) -> anyhow::Result { + let arch = Self::get_arch(); + + // Construct the path to the extension archive + // BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst + // + // Keep it in sync with path generation in + // https://github.com/neondatabase/build-custom-extensions/tree/main + RemotePath::from_string(&format!( + "{build_tag}/{arch}/{pg_major_version}/extensions/{ext_name}.tar.zst" + )) + } } #[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)] @@ -518,6 +531,37 @@ mod tests { .expect("Library should be found"); } + #[test] + fn remote_extension_path() { + let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({ + "public_extensions": ["ext"], + "custom_extensions": [], + "library_index": { + "extlib": "ext", + }, + "extension_data": { + "ext": { + "control_data": { + "ext.control": "" + }, + "archive_path": "" + } + }, + })) + .unwrap(); + + let (_ext_name, ext_path) = rspec + .get_ext("ext", false, "latest", "v17") + .expect("Extension should be found"); + // Starting with a forward slash would have consequences for the + // Url::join() that occurs when downloading a remote extension. 
+ assert!(!ext_path.to_string().starts_with("/")); + assert_eq!( + ext_path, + RemoteExtSpec::build_remote_path("latest", "v17", "ext").unwrap() + ); + } + #[test] fn parse_spec_file() { let file = File::open("tests/cluster_spec.json").unwrap(); diff --git a/libs/compute_api/tests/cluster_spec.json b/libs/compute_api/tests/cluster_spec.json index 30e788a601..2dd2aae015 100644 --- a/libs/compute_api/tests/cluster_spec.json +++ b/libs/compute_api/tests/cluster_spec.json @@ -85,7 +85,7 @@ "vartype": "bool" }, { - "name": "prewarm_lfc_on_startup", + "name": "autoprewarm", "value": "off", "vartype": "bool" }, diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 93f6a2b7cc..1a7d7a7e44 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -107,7 +107,7 @@ impl MetricType for HyperLogLogState { } impl HyperLogLogState { - pub fn measure(&self, item: &impl Hash) { + pub fn measure(&self, item: &(impl Hash + ?Sized)) { // changing the hasher will break compatibility with previous measurements. 
self.record(BuildHasherDefault::::default().hash_one(item)); } diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 4df8d7bc51..5d028ee041 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -27,6 +27,7 @@ pub use prometheus::{ pub mod launch_timestamp; mod wrappers; +pub use prometheus; pub use wrappers::{CountedReader, CountedWriter}; mod hll; pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec}; diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 012c020fb1..30b0612082 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -20,7 +20,6 @@ use postgres_backend::AuthType; use remote_storage::RemoteStorageConfig; use serde_with::serde_as; use utils::logging::LogFormat; -use utils::postgres_client::PostgresClientProtocol; use crate::models::{ImageCompressionAlgorithm, LsnLease}; @@ -181,6 +180,7 @@ pub struct ConfigToml { pub virtual_file_io_engine: Option, pub ingest_batch_size: u64, pub max_vectored_read_bytes: MaxVectoredReadBytes, + pub max_get_vectored_keys: MaxGetVectoredKeys, pub image_compression: ImageCompressionAlgorithm, pub timeline_offloading: bool, pub ephemeral_bytes_per_memory_kb: usize, @@ -188,7 +188,6 @@ pub struct ConfigToml { pub virtual_file_io_mode: Option, #[serde(skip_serializing_if = "Option::is_none")] pub no_sync: Option, - pub wal_receiver_protocol: PostgresClientProtocol, pub page_service_pipelining: PageServicePipeliningConfig, pub get_vectored_concurrent_io: GetVectoredConcurrentIo, pub enable_read_path_debugging: Option, @@ -229,7 +228,7 @@ pub enum PageServicePipeliningConfig { } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct PageServicePipeliningConfigPipelined { - /// Causes runtime errors if larger than max get_vectored batch size. + /// Failed config parsing and validation if larger than `max_get_vectored_keys`. 
pub max_batch_size: NonZeroUsize, pub execution: PageServiceProtocolPipelinedExecutionStrategy, // The default below is such that new versions of the software can start @@ -329,6 +328,8 @@ pub struct TimelineImportConfig { pub import_job_concurrency: NonZeroUsize, pub import_job_soft_size_limit: NonZeroUsize, pub import_job_checkpoint_threshold: NonZeroUsize, + /// Max size of the remote storage partial read done by any job + pub import_job_max_byte_range_size: NonZeroUsize, } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -403,6 +404,16 @@ impl Default for EvictionOrder { #[serde(transparent)] pub struct MaxVectoredReadBytes(pub NonZeroUsize); +#[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(transparent)] +pub struct MaxGetVectoredKeys(NonZeroUsize); + +impl MaxGetVectoredKeys { + pub fn get(&self) -> usize { + self.0.get() + } +} + /// Tenant-level configuration values, used for various purposes. #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[serde(default)] @@ -514,8 +525,6 @@ pub struct TenantConfigToml { /// (either this flag or the pageserver-global one need to be set) pub timeline_offloading: bool, - pub wal_receiver_protocol_override: Option, - /// Enable rel_size_v2 for this tenant. Once enabled, the tenant will persist this information into /// `index_part.json`, and it cannot be reversed. pub rel_size_v2_enabled: bool, @@ -587,6 +596,8 @@ pub mod defaults { /// That is, slightly above 128 kB. 
pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 130 * 1024; // 130 KiB + pub const DEFAULT_MAX_GET_VECTORED_KEYS: usize = 32; + pub const DEFAULT_IMAGE_COMPRESSION: ImageCompressionAlgorithm = ImageCompressionAlgorithm::Zstd { level: Some(1) }; @@ -594,9 +605,6 @@ pub mod defaults { pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512; - pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol = - utils::postgres_client::PostgresClientProtocol::Vanilla; - pub const DEFAULT_SSL_KEY_FILE: &str = "server.key"; pub const DEFAULT_SSL_CERT_FILE: &str = "server.crt"; } @@ -685,6 +693,9 @@ impl Default for ConfigToml { max_vectored_read_bytes: (MaxVectoredReadBytes( NonZeroUsize::new(DEFAULT_MAX_VECTORED_READ_BYTES).unwrap(), )), + max_get_vectored_keys: (MaxGetVectoredKeys( + NonZeroUsize::new(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap(), + )), image_compression: (DEFAULT_IMAGE_COMPRESSION), timeline_offloading: true, ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB), @@ -692,7 +703,6 @@ impl Default for ConfigToml { virtual_file_io_mode: None, tenant_config: TenantConfigToml::default(), no_sync: None, - wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, page_service_pipelining: PageServicePipeliningConfig::Pipelined( PageServicePipeliningConfigPipelined { max_batch_size: NonZeroUsize::new(32).unwrap(), @@ -713,9 +723,10 @@ impl Default for ConfigToml { enable_tls_page_service_api: false, dev_mode: false, timeline_import_config: TimelineImportConfig { - import_job_concurrency: NonZeroUsize::new(128).unwrap(), - import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(), - import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(), + import_job_concurrency: NonZeroUsize::new(32).unwrap(), + import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(), + import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(), + import_job_max_byte_range_size: NonZeroUsize::new(4 * 1024 * 
1024).unwrap(), }, basebackup_cache_config: None, posthog_config: None, @@ -836,7 +847,6 @@ impl Default for TenantConfigToml { lsn_lease_length: LsnLease::DEFAULT_LENGTH, lsn_lease_length_for_ts: LsnLease::DEFAULT_LENGTH_FOR_TS, timeline_offloading: true, - wal_receiver_protocol_override: None, rel_size_v2_enabled: false, gc_compaction_enabled: DEFAULT_GC_COMPACTION_ENABLED, gc_compaction_verification: DEFAULT_GC_COMPACTION_VERIFICATION, diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index e7d612bb7a..881f24b86c 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -20,7 +20,6 @@ use serde_with::serde_as; pub use utilization::PageserverUtilization; use utils::id::{NodeId, TenantId, TimelineId}; use utils::lsn::Lsn; -use utils::postgres_client::PostgresClientProtocol; use utils::{completion, serde_system_time}; use crate::config::Ratio; @@ -622,8 +621,6 @@ pub struct TenantConfigPatch { #[serde(skip_serializing_if = "FieldPatch::is_noop")] pub timeline_offloading: FieldPatch, #[serde(skip_serializing_if = "FieldPatch::is_noop")] - pub wal_receiver_protocol_override: FieldPatch, - #[serde(skip_serializing_if = "FieldPatch::is_noop")] pub rel_size_v2_enabled: FieldPatch, #[serde(skip_serializing_if = "FieldPatch::is_noop")] pub gc_compaction_enabled: FieldPatch, @@ -748,9 +745,6 @@ pub struct TenantConfig { #[serde(skip_serializing_if = "Option::is_none")] pub timeline_offloading: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub wal_receiver_protocol_override: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub rel_size_v2_enabled: Option, @@ -812,7 +806,6 @@ impl TenantConfig { mut lsn_lease_length, mut lsn_lease_length_for_ts, mut timeline_offloading, - mut wal_receiver_protocol_override, mut rel_size_v2_enabled, mut gc_compaction_enabled, mut gc_compaction_verification, @@ -905,9 +898,6 @@ impl TenantConfig { .map(|v| humantime::parse_duration(&v))? 
.apply(&mut lsn_lease_length_for_ts); patch.timeline_offloading.apply(&mut timeline_offloading); - patch - .wal_receiver_protocol_override - .apply(&mut wal_receiver_protocol_override); patch.rel_size_v2_enabled.apply(&mut rel_size_v2_enabled); patch .gc_compaction_enabled @@ -960,7 +950,6 @@ impl TenantConfig { lsn_lease_length, lsn_lease_length_for_ts, timeline_offloading, - wal_receiver_protocol_override, rel_size_v2_enabled, gc_compaction_enabled, gc_compaction_verification, @@ -1058,9 +1047,6 @@ impl TenantConfig { timeline_offloading: self .timeline_offloading .unwrap_or(global_conf.timeline_offloading), - wal_receiver_protocol_override: self - .wal_receiver_protocol_override - .or(global_conf.wal_receiver_protocol_override), rel_size_v2_enabled: self .rel_size_v2_enabled .unwrap_or(global_conf.rel_size_v2_enabled), @@ -1934,7 +1920,7 @@ pub enum PagestreamFeMessage { } // Wrapped in libpq CopyData -#[derive(strum_macros::EnumProperty)] +#[derive(Debug, strum_macros::EnumProperty)] pub enum PagestreamBeMessage { Exists(PagestreamExistsResponse), Nblocks(PagestreamNblocksResponse), @@ -2045,7 +2031,7 @@ pub enum PagestreamProtocolVersion { pub type RequestId = u64; -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct PagestreamRequest { pub reqid: RequestId, pub request_lsn: Lsn, @@ -2064,7 +2050,7 @@ pub struct PagestreamNblocksRequest { pub rel: RelTag, } -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct PagestreamGetPageRequest { pub hdr: PagestreamRequest, pub rel: RelTag, diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index 473a44dbf9..e0dd4fdfe8 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize}; // FIXME: should move 'forknum' as last field to keep this consistent with Postgres. 
// Then we could replace the custom Ord and PartialOrd implementations below with // deriving them. This will require changes in walredoproc.c. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)] pub struct RelTag { pub forknum: u8, pub spcnode: Oid, @@ -184,12 +184,12 @@ pub enum SlruKind { MultiXactOffsets, } -impl SlruKind { - pub fn to_str(&self) -> &'static str { +impl fmt::Display for SlruKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Clog => "pg_xact", - Self::MultiXactMembers => "pg_multixact/members", - Self::MultiXactOffsets => "pg_multixact/offsets", + Self::Clog => write!(f, "pg_xact"), + Self::MultiXactMembers => write!(f, "pg_multixact/members"), + Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"), } } } diff --git a/libs/posthog_client_lite/src/background_loop.rs b/libs/posthog_client_lite/src/background_loop.rs index a05f6096b1..a404c76da9 100644 --- a/libs/posthog_client_lite/src/background_loop.rs +++ b/libs/posthog_client_lite/src/background_loop.rs @@ -6,7 +6,7 @@ use arc_swap::ArcSwap; use tokio_util::sync::CancellationToken; use tracing::{Instrument, info_span}; -use crate::{FeatureStore, PostHogClient, PostHogClientConfig}; +use crate::{CaptureEvent, FeatureStore, PostHogClient, PostHogClientConfig}; /// A background loop that fetches feature flags from PostHog and updates the feature store. pub struct FeatureResolverBackgroundLoop { @@ -24,9 +24,16 @@ impl FeatureResolverBackgroundLoop { } } - pub fn spawn(self: Arc, handle: &tokio::runtime::Handle, refresh_period: Duration) { + pub fn spawn( + self: Arc, + handle: &tokio::runtime::Handle, + refresh_period: Duration, + fake_tenants: Vec, + ) { let this = self.clone(); let cancel = self.cancel.clone(); + + // Main loop of updating the feature flags. 
handle.spawn( async move { tracing::info!("Starting PostHog feature resolver"); @@ -56,6 +63,22 @@ impl FeatureResolverBackgroundLoop { } .instrument(info_span!("posthog_feature_resolver")), ); + + // Report fake tenants to PostHog so that we have the combination of all the properties in the UI. + // Do one report per pageserver restart. + let this = self.clone(); + handle.spawn( + async move { + tracing::info!("Starting PostHog feature reporter"); + for tenant in &fake_tenants { + tracing::info!("Reporting fake tenant: {:?}", tenant); + } + if let Err(e) = this.posthog_client.capture_event_batch(&fake_tenants).await { + tracing::warn!("Cannot report fake tenants: {}", e); + } + } + .instrument(info_span!("posthog_feature_reporter")), + ); } pub fn feature_store(&self) -> Arc { diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index ff12051196..f607b1be0a 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -22,6 +22,16 @@ pub enum PostHogEvaluationError { Internal(String), } +impl PostHogEvaluationError { + pub fn as_variant_str(&self) -> &'static str { + match self { + PostHogEvaluationError::NotAvailable(_) => "not_available", + PostHogEvaluationError::NoConditionGroupMatched => "no_condition_group_matched", + PostHogEvaluationError::Internal(_) => "internal", + } + } +} + #[derive(Deserialize)] pub struct LocalEvaluationResponse { pub flags: Vec, @@ -54,7 +64,7 @@ pub struct LocalEvaluationFlagFilterProperty { operator: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] #[serde(untagged)] pub enum PostHogFlagFilterPropertyValue { String(String), @@ -497,6 +507,13 @@ pub struct PostHogClient { client: reqwest::Client, } +#[derive(Serialize, Debug)] +pub struct CaptureEvent { + pub event: String, + pub distinct_id: String, + pub properties: serde_json::Value, +} + impl PostHogClient { pub fn new(config: PostHogClientConfig) -> Self { let 
client = reqwest::Client::new(); @@ -560,12 +577,12 @@ impl PostHogClient { &self, event: &str, distinct_id: &str, - properties: &HashMap, + properties: &serde_json::Value, ) -> anyhow::Result<()> { // PUBLIC_URL/capture/ - // with bearer token of self.client_api_key let url = format!("{}/capture/", self.config.public_api_url); - self.client + let response = self + .client .post(url) .body(serde_json::to_string(&json!({ "api_key": self.config.client_api_key, @@ -575,6 +592,39 @@ impl PostHogClient { }))?) .send() .await?; + let status = response.status(); + let body = response.text().await?; + if !status.is_success() { + return Err(anyhow::anyhow!( + "Failed to capture events: {}, {}", + status, + body + )); + } + Ok(()) + } + + pub async fn capture_event_batch(&self, events: &[CaptureEvent]) -> anyhow::Result<()> { + // PUBLIC_URL/batch/ + let url = format!("{}/batch/", self.config.public_api_url); + let response = self + .client + .post(url) + .body(serde_json::to_string(&json!({ + "api_key": self.config.client_api_key, + "batch": events, + }))?) + .send() + .await?; + let status = response.status(); + let body = response.text().await?; + if !status.is_success() { + return Err(anyhow::anyhow!( + "Failed to capture events: {}, {}", + status, + body + )); + } Ok(()) } } diff --git a/libs/utils/src/leaky_bucket.rs b/libs/utils/src/leaky_bucket.rs index 2398f92766..17e96bd0a9 100644 --- a/libs/utils/src/leaky_bucket.rs +++ b/libs/utils/src/leaky_bucket.rs @@ -28,6 +28,7 @@ use std::time::Duration; use tokio::sync::Notify; use tokio::time::Instant; +#[derive(Clone, Copy)] pub struct LeakyBucketConfig { /// This is the "time cost" of a single request unit. /// Should loosely represent how long it takes to handle a request unit in active resource time. 
diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index 206b8bbd8f..11f787562c 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -73,6 +73,7 @@ pub mod error; /// async timeout helper pub mod timeout; +pub mod span; pub mod sync; pub mod failpoint_support; diff --git a/libs/utils/src/span.rs b/libs/utils/src/span.rs new file mode 100644 index 0000000000..4dbc99044b --- /dev/null +++ b/libs/utils/src/span.rs @@ -0,0 +1,19 @@ +//! Tracing span helpers. + +/// Records the given fields in the current span, as a single call. The fields must already have +/// been declared for the span (typically with empty values). +#[macro_export] +macro_rules! span_record { + ($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)}; +} + +/// Records the given fields in the given span, as a single call. The fields must already have been +/// declared for the span (typically with empty values). +#[macro_export] +macro_rules! span_record_in { + ($span:expr, $($tokens:tt)*) => { + if let Some(meta) = $span.metadata() { + $span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*)); + } + }; +} diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index d660602149..4d6cbae9a9 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -439,6 +439,7 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState { currentClusterSize: crate::bindings::pg_atomic_uint64 { value: 0 }, shard_ps_feedback: [empty_feedback; 128], num_shards: 0, + replica_promote: false, min_ps_feedback: empty_feedback, } } diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index c4d6d58945..9591c729e8 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -34,6 +34,7 @@ fail.workspace = true futures.workspace = true hashlink.workspace = true hex.workspace = true +http.workspace = true http-utils.workspace = true humantime-serde.workspace = true humantime.workspace 
= true @@ -93,6 +94,7 @@ tokio-util.workspace = true toml_edit = { workspace = true, features = [ "serde" ] } tonic.workspace = true tonic-reflection.workspace = true +tower.workspace = true tracing.workspace = true tracing-utils.workspace = true url.workspace = true diff --git a/pageserver/benches/bench_metrics.rs b/pageserver/benches/bench_metrics.rs index 38025124e1..e0428f6372 100644 --- a/pageserver/benches/bench_metrics.rs +++ b/pageserver/benches/bench_metrics.rs @@ -264,10 +264,56 @@ mod propagation_of_cached_label_value { } } +criterion_group!(histograms, histograms::bench_bucket_scalability); +mod histograms { + use std::time::Instant; + + use criterion::{BenchmarkId, Criterion}; + use metrics::core::Collector; + + pub fn bench_bucket_scalability(c: &mut Criterion) { + let mut g = c.benchmark_group("bucket_scalability"); + + for n in [1, 4, 8, 16, 32, 64, 128, 256] { + g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| { + b.iter_custom(|iters| { + let buckets: Vec = (0..*n).map(|i| i as f64 * 100.0).collect(); + let histo = metrics::Histogram::with_opts( + metrics::prometheus::HistogramOpts::new("name", "help") + .buckets(buckets.clone()), + ) + .unwrap(); + let start = Instant::now(); + for i in 0..usize::try_from(iters).unwrap() { + histo.observe(buckets[i % buckets.len()]); + } + let elapsed = start.elapsed(); + // self-test + let mfs = histo.collect(); + assert_eq!(mfs.len(), 1); + let metrics = mfs[0].get_metric(); + assert_eq!(metrics.len(), 1); + let histo = metrics[0].get_histogram(); + let buckets = histo.get_bucket(); + assert!( + buckets + .iter() + .enumerate() + .all(|(i, b)| b.get_cumulative_count() + >= i as u64 * (iters / buckets.len() as u64)) + ); + elapsed + }) + }); + } + } +} + criterion_main!( label_values, single_metric_multicore_scalability, - propagation_of_cached_label_value + propagation_of_cached_label_value, + histograms, ); /* @@ -290,6 +336,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 
ns 214.44 ns propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns] +bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns] +bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns] +bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns] +bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns] +bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns] +bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns] +bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns] +bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns] Results on an i3en.3xlarge instance @@ -344,6 +398,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4 propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns] +bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns] +bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns] +bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns] +bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns] +bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns] +bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns] +bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns] +bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns] Results on a Hetzner AX102 AMD Ryzen 9 
7950X3D 16-Core Processor @@ -362,5 +424,13 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1 propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns] +bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns] +bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns] +bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns] +bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns] +bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns] +bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns] +bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns] +bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns] */ diff --git a/pageserver/page_api/Cargo.toml b/pageserver/page_api/Cargo.toml index 4f62c77eb2..e643b5749b 100644 --- a/pageserver/page_api/Cargo.toml +++ b/pageserver/page_api/Cargo.toml @@ -9,7 +9,6 @@ bytes.workspace = true pageserver_api.workspace = true postgres_ffi.workspace = true prost.workspace = true -smallvec.workspace = true thiserror.workspace = true tonic.workspace = true utils.workspace = true diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index 7ab97a994e..1a08d04cc1 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -9,10 +9,16 @@ //! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits. //! //! - Validate protocol invariants, via try_from() and try_into(). +//! +//! Validation only happens on the receiver side, i.e. when converting from Protobuf to domain +//! types. 
This is where it matters -- the Protobuf types are less strict than the domain types, and +//! receivers should expect all sorts of junk from senders. This also allows the sender to use e.g. +//! stream combinators without dealing with errors, and avoids validating the same message twice. + +use std::fmt::Display; use bytes::Bytes; use postgres_ffi::Oid; -use smallvec::SmallVec; // TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid // pulling in all of their other crate dependencies when building the client. use utils::lsn::Lsn; @@ -48,7 +54,8 @@ pub struct ReadLsn { pub request_lsn: Lsn, /// If given, the caller guarantees that the page has not been modified since this LSN. Must be /// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page - /// without waiting for the request LSN to arrive. Valid for all request types. + /// without waiting for the request LSN to arrive. If not given, the request will read at the + /// request_lsn and wait for it to arrive if necessary. Valid for all request types. /// /// It is undefined behaviour to make a request such that the page was, in fact, modified /// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an @@ -58,19 +65,14 @@ pub struct ReadLsn { pub not_modified_since_lsn: Option, } -impl ReadLsn { - /// Validates the ReadLsn. 
- pub fn validate(&self) -> Result<(), ProtocolError> { - if self.request_lsn == Lsn::INVALID { - return Err(ProtocolError::invalid("request_lsn", self.request_lsn)); +impl Display for ReadLsn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let req_lsn = self.request_lsn; + if let Some(mod_lsn) = self.not_modified_since_lsn { + write!(f, "{req_lsn}>={mod_lsn}") + } else { + req_lsn.fmt(f) } - if self.not_modified_since_lsn > Some(self.request_lsn) { - return Err(ProtocolError::invalid( - "not_modified_since_lsn", - self.not_modified_since_lsn, - )); - } - Ok(()) } } @@ -78,27 +80,31 @@ impl TryFrom for ReadLsn { type Error = ProtocolError; fn try_from(pb: proto::ReadLsn) -> Result { - let read_lsn = Self { + if pb.request_lsn == 0 { + return Err(ProtocolError::invalid("request_lsn", pb.request_lsn)); + } + if pb.not_modified_since_lsn > pb.request_lsn { + return Err(ProtocolError::invalid( + "not_modified_since_lsn", + pb.not_modified_since_lsn, + )); + } + Ok(Self { request_lsn: Lsn(pb.request_lsn), not_modified_since_lsn: match pb.not_modified_since_lsn { 0 => None, lsn => Some(Lsn(lsn)), }, - }; - read_lsn.validate()?; - Ok(read_lsn) + }) } } -impl TryFrom for proto::ReadLsn { - type Error = ProtocolError; - - fn try_from(read_lsn: ReadLsn) -> Result { - read_lsn.validate()?; - Ok(Self { +impl From for proto::ReadLsn { + fn from(read_lsn: ReadLsn) -> Self { + Self { request_lsn: read_lsn.request_lsn.0, not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0, - }) + } } } @@ -153,6 +159,15 @@ impl TryFrom for CheckRelExistsRequest { } } +impl From for proto::CheckRelExistsRequest { + fn from(request: CheckRelExistsRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), + rel: Some(request.rel.into()), + } + } +} + pub type CheckRelExistsResponse = bool; impl From for CheckRelExistsResponse { @@ -190,14 +205,12 @@ impl TryFrom for GetBaseBackupRequest { } } -impl TryFrom for 
proto::GetBaseBackupRequest { - type Error = ProtocolError; - - fn try_from(request: GetBaseBackupRequest) -> Result { - Ok(Self { - read_lsn: Some(request.read_lsn.try_into()?), +impl From for proto::GetBaseBackupRequest { + fn from(request: GetBaseBackupRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), replica: request.replica, - }) + } } } @@ -214,14 +227,9 @@ impl TryFrom for GetBaseBackupResponseChunk { } } -impl TryFrom for proto::GetBaseBackupResponseChunk { - type Error = ProtocolError; - - fn try_from(chunk: GetBaseBackupResponseChunk) -> Result { - if chunk.is_empty() { - return Err(ProtocolError::Missing("chunk")); - } - Ok(Self { chunk }) +impl From for proto::GetBaseBackupResponseChunk { + fn from(chunk: GetBaseBackupResponseChunk) -> Self { + Self { chunk } } } @@ -246,14 +254,12 @@ impl TryFrom for GetDbSizeRequest { } } -impl TryFrom for proto::GetDbSizeRequest { - type Error = ProtocolError; - - fn try_from(request: GetDbSizeRequest) -> Result { - Ok(Self { - read_lsn: Some(request.read_lsn.try_into()?), +impl From for proto::GetDbSizeRequest { + fn from(request: GetDbSizeRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), db_oid: request.db_oid, - }) + } } } @@ -288,7 +294,7 @@ pub struct GetPageRequest { /// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access /// costs and parallelizing them. This may increase the latency of any individual request, but /// improves the overall latency and throughput of the batch as a whole. - pub block_numbers: SmallVec<[u32; 1]>, + pub block_numbers: Vec, } impl TryFrom for GetPageRequest { @@ -306,25 +312,20 @@ impl TryFrom for GetPageRequest { .ok_or(ProtocolError::Missing("read_lsn"))? 
.try_into()?, rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, - block_numbers: pb.block_number.into(), + block_numbers: pb.block_number, }) } } -impl TryFrom for proto::GetPageRequest { - type Error = ProtocolError; - - fn try_from(request: GetPageRequest) -> Result { - if request.block_numbers.is_empty() { - return Err(ProtocolError::Missing("block_number")); - } - Ok(Self { +impl From for proto::GetPageRequest { + fn from(request: GetPageRequest) -> Self { + Self { request_id: request.request_id, request_class: request.request_class.into(), - read_lsn: Some(request.read_lsn.try_into()?), + read_lsn: Some(request.read_lsn.into()), rel: Some(request.rel.into()), - block_number: request.block_numbers.into_vec(), - }) + block_number: request.block_numbers, + } } } @@ -396,7 +397,7 @@ pub struct GetPageResponse { /// A string describing the status, if any. pub reason: Option, /// The 8KB page images, in the same order as the request. Empty if status != OK. - pub page_images: SmallVec<[Bytes; 1]>, + pub page_images: Vec, } impl From for GetPageResponse { @@ -405,7 +406,7 @@ impl From for GetPageResponse { request_id: pb.request_id, status_code: pb.status_code.into(), reason: Some(pb.reason).filter(|r| !r.is_empty()), - page_images: pb.page_image.into(), + page_images: pb.page_image, } } } @@ -416,7 +417,7 @@ impl From for proto::GetPageResponse { request_id: response.request_id, status_code: response.status_code.into(), reason: response.reason.unwrap_or_default(), - page_image: response.page_images.into_vec(), + page_image: response.page_images, } } } @@ -505,14 +506,12 @@ impl TryFrom for GetRelSizeRequest { } } -impl TryFrom for proto::GetRelSizeRequest { - type Error = ProtocolError; - - fn try_from(request: GetRelSizeRequest) -> Result { - Ok(Self { - read_lsn: Some(request.read_lsn.try_into()?), +impl From for proto::GetRelSizeRequest { + fn from(request: GetRelSizeRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), rel: 
Some(request.rel.into()), - }) + } } } @@ -555,15 +554,13 @@ impl TryFrom for GetSlruSegmentRequest { } } -impl TryFrom for proto::GetSlruSegmentRequest { - type Error = ProtocolError; - - fn try_from(request: GetSlruSegmentRequest) -> Result { - Ok(Self { - read_lsn: Some(request.read_lsn.try_into()?), +impl From for proto::GetSlruSegmentRequest { + fn from(request: GetSlruSegmentRequest) -> Self { + Self { + read_lsn: Some(request.read_lsn.into()), kind: request.kind as u32, segno: request.segno, - }) + } } } @@ -580,14 +577,9 @@ impl TryFrom for GetSlruSegmentResponse { } } -impl TryFrom for proto::GetSlruSegmentResponse { - type Error = ProtocolError; - - fn try_from(segment: GetSlruSegmentResponse) -> Result { - if segment.is_empty() { - return Err(ProtocolError::Missing("segment")); - } - Ok(Self { segment }) +impl From for proto::GetSlruSegmentResponse { + fn from(segment: GetSlruSegmentResponse) -> Self { + Self { segment } } } diff --git a/pageserver/pagebench/Cargo.toml b/pageserver/pagebench/Cargo.toml index 5b5ed09a2b..5e4af88e69 100644 --- a/pageserver/pagebench/Cargo.toml +++ b/pageserver/pagebench/Cargo.toml @@ -8,6 +8,8 @@ license.workspace = true [dependencies] anyhow.workspace = true +async-trait.workspace = true +bytes.workspace = true camino.workspace = true clap.workspace = true futures.workspace = true @@ -15,14 +17,17 @@ hdrhistogram.workspace = true humantime.workspace = true humantime-serde.workspace = true rand.workspace = true -reqwest.workspace=true +reqwest.workspace = true serde.workspace = true serde_json.workspace = true tracing.workspace = true tokio.workspace = true +tokio-stream.workspace = true tokio-util.workspace = true +tonic.workspace = true pageserver_client.workspace = true pageserver_api.workspace = true +pageserver_page_api.workspace = true utils = { path = "../../libs/utils/" } workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs 
b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index 50419ec338..3f3b6e396e 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -1,4 +1,4 @@ -use std::collections::{HashSet, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::future::Future; use std::num::NonZeroUsize; use std::pin::Pin; @@ -7,11 +7,15 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use anyhow::Context; +use async_trait::async_trait; +use bytes::Bytes; use camino::Utf8PathBuf; use pageserver_api::key::Key; use pageserver_api::keyspace::KeySpaceAccum; use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest}; +use pageserver_api::reltag::RelTag; use pageserver_api::shard::TenantShardId; +use pageserver_page_api::proto; use rand::prelude::*; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -22,6 +26,12 @@ use utils::lsn::Lsn; use crate::util::tokio_thread_local_stats::AllThreadLocalStats; use crate::util::{request_stats, tokio_thread_local_stats}; +#[derive(clap::ValueEnum, Clone, Debug)] +enum Protocol { + Libpq, + Grpc, +} + /// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace. #[derive(clap::Parser)] pub(crate) struct Args { @@ -35,6 +45,8 @@ pub(crate) struct Args { num_clients: NonZeroUsize, #[clap(long)] runtime: Option, + #[clap(long, value_enum, default_value = "libpq")] + protocol: Protocol, /// Each client sends requests at the given rate. /// /// If a request takes too long and we should be issuing a new request already, @@ -65,6 +77,16 @@ pub(crate) struct Args { #[clap(long, default_value = "1")] queue_depth: NonZeroUsize, + /// Batch size of contiguous pages generated by each client. This is equivalent to how Postgres + /// will request page batches (e.g. prefetches or vectored reads). A batch counts as 1 RPS and + /// 1 queue depth. 
+ /// + /// The libpq protocol does not support client-side batching, and will submit batches as many + /// individual requests, in the hope that the server will batch them. Each batch still counts as + /// 1 RPS and 1 queue depth. + #[clap(long, default_value = "1")] + batch_size: NonZeroUsize, + #[clap(long)] only_relnode: Option, @@ -303,7 +325,20 @@ async fn main_impl( .unwrap(); Box::pin(async move { - client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await + let client: Box = match args.protocol { + Protocol::Libpq => Box::new( + LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline) + .await + .unwrap(), + ), + + Protocol::Grpc => Box::new( + GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline) + .await + .unwrap(), + ), + }; + run_worker(args, client, ss, cancel, rps_period, ranges, weights).await }) }; @@ -355,27 +390,28 @@ async fn main_impl( anyhow::Ok(()) } -async fn client_libpq( +async fn run_worker( args: &Args, - worker_id: WorkerId, + mut client: Box, shared_state: Arc, cancel: CancellationToken, rps_period: Option, ranges: Vec, weights: rand::distributions::weighted::WeightedIndex, ) { - let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone()) - .await - .unwrap(); - let mut client = client - .pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id) - .await - .unwrap(); - shared_state.start_work_barrier.wait().await; let client_start = Instant::now(); let mut ticks_processed = 0; - let mut inflight = VecDeque::new(); + let mut req_id = 0; + let batch_size: usize = args.batch_size.into(); + + // Track inflight requests by request ID and start time. This times the request duration, and + // ensures responses match requests. We don't expect responses back in any particular order. 
+ // + // NB: this does not check that all requests received a response, because we don't wait for the + // inflight requests to complete when the duration elapses. + let mut inflight: HashMap = HashMap::new(); + while !cancel.is_cancelled() { // Detect if a request took longer than the RPS rate if let Some(period) = &rps_period { @@ -391,36 +427,72 @@ async fn client_libpq( } while inflight.len() < args.queue_depth.get() { + req_id += 1; let start = Instant::now(); - let req = { + let (req_lsn, mod_lsn, rel, blks) = { + /// Converts a compact i128 key to a relation tag and block number. + fn key_to_block(key: i128) -> (RelTag, u32) { + let key = Key::from_i128(key); + assert!(key.is_rel_block_key()); + key.to_rel_block() + .expect("we filter non-rel-block keys out above") + } + + // Pick a random page from a random relation. let mut rng = rand::thread_rng(); let r = &ranges[weights.sample(&mut rng)]; let key: i128 = rng.gen_range(r.start..r.end); - let key = Key::from_i128(key); - assert!(key.is_rel_block_key()); - let (rel_tag, block_no) = key - .to_rel_block() - .expect("we filter non-rel-block keys out above"); - PagestreamGetPageRequest { - hdr: PagestreamRequest { - reqid: 0, - request_lsn: if rng.gen_bool(args.req_latest_probability) { - Lsn::MAX - } else { - r.timeline_lsn - }, - not_modified_since: r.timeline_lsn, - }, - rel: rel_tag, - blkno: block_no, + let (rel_tag, block_no) = key_to_block(key); + + let mut blks = VecDeque::with_capacity(batch_size); + blks.push_back(block_no); + + // If requested, populate a batch of sequential pages. This is how Postgres will + // request page batches (e.g. prefetches). If we hit the end of the relation, we + // grow the batch towards the start too. + for i in 1..batch_size { + let (r, b) = key_to_block(key + i as i128); + if r != rel_tag { + break; // went outside relation + } + blks.push_back(b) } + + if blks.len() < batch_size { + // Grow batch backwards if needed. 
+ for i in 1..batch_size { + let (r, b) = key_to_block(key - i as i128); + if r != rel_tag { + break; // went outside relation + } + blks.push_front(b) + } + } + + // We assume that the entire batch can fit within the relation. + assert_eq!(blks.len(), batch_size, "incomplete batch"); + + let req_lsn = if rng.gen_bool(args.req_latest_probability) { + Lsn::MAX + } else { + r.timeline_lsn + }; + (req_lsn, r.timeline_lsn, rel_tag, blks.into()) }; - client.getpage_send(req).await.unwrap(); - inflight.push_back(start); + client + .send_get_page(req_id, req_lsn, mod_lsn, rel, blks) + .await + .unwrap(); + let old = inflight.insert(req_id, start); + assert!(old.is_none(), "duplicate request ID {req_id}"); } - let start = inflight.pop_front().unwrap(); - client.getpage_recv().await.unwrap(); + let (req_id, pages) = client.recv_get_page().await.unwrap(); + assert_eq!(pages.len(), batch_size, "unexpected page count"); + assert!(pages.iter().all(|p| !p.is_empty()), "empty page"); + let start = inflight + .remove(&req_id) + .expect("response for unknown request ID"); let end = Instant::now(); shared_state.live_stats.request_done(); ticks_processed += 1; @@ -442,3 +514,154 @@ async fn client_libpq( } } } + +/// A benchmark client, to allow switching out the transport protocol. +/// +/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could +/// return a future that resolves when the response is received, but we don't really need it. +#[async_trait] +trait Client: Send { + /// Sends an asynchronous GetPage request to the pageserver. + async fn send_get_page( + &mut self, + req_id: u64, + req_lsn: Lsn, + mod_lsn: Lsn, + rel: RelTag, + blks: Vec, + ) -> anyhow::Result<()>; + + /// Receives the next GetPage response from the pageserver. + async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec)>; +} + +/// A libpq-based Pageserver client. 
+struct LibpqClient { + inner: pageserver_client::page_service::PagestreamClient, + // Track sent batches, so we know how many responses to expect. + batch_sizes: VecDeque, +} + +impl LibpqClient { + async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result { + let inner = pageserver_client::page_service::Client::new(connstring) + .await? + .pagestream(ttid.tenant_id, ttid.timeline_id) + .await?; + Ok(Self { + inner, + batch_sizes: VecDeque::new(), + }) + } +} + +#[async_trait] +impl Client for LibpqClient { + async fn send_get_page( + &mut self, + req_id: u64, + req_lsn: Lsn, + mod_lsn: Lsn, + rel: RelTag, + blks: Vec, + ) -> anyhow::Result<()> { + // libpq doesn't support client-side batches, so we send a bunch of individual requests + // instead in the hope that the server will batch them for us. We use the same request ID + // for all, because we'll return a single batch response. + self.batch_sizes.push_back(blks.len()); + for blkno in blks { + let req = PagestreamGetPageRequest { + hdr: PagestreamRequest { + reqid: req_id, + request_lsn: req_lsn, + not_modified_since: mod_lsn, + }, + rel, + blkno, + }; + self.inner.getpage_send(req).await?; + } + Ok(()) + } + + async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec)> { + let batch_size = self.batch_sizes.pop_front().unwrap(); + let mut batch = Vec::with_capacity(batch_size); + let mut req_id = None; + for _ in 0..batch_size { + let resp = self.inner.getpage_recv().await?; + if req_id.is_none() { + req_id = Some(resp.req.hdr.reqid); + } + assert_eq!(req_id, Some(resp.req.hdr.reqid), "request ID mismatch"); + batch.push(resp.page); + } + Ok((req_id.unwrap(), batch)) + } +} + +/// A gRPC client using the raw, no-frills gRPC client. 
+struct GrpcClient { + req_tx: tokio::sync::mpsc::Sender, + resp_rx: tonic::Streaming, +} + +impl GrpcClient { + async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result { + let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?; + + // The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the + // benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are + // buffered by Tonic and the OS too. + let (req_tx, req_rx) = tokio::sync::mpsc::channel(1); + let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx); + let mut req = tonic::Request::new(req_stream); + let metadata = req.metadata_mut(); + metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?); + metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?); + metadata.insert("neon-shard-id", "0000".try_into()?); + + let resp = client.get_pages(req).await?; + let resp_stream = resp.into_inner(); + + Ok(Self { + req_tx, + resp_rx: resp_stream, + }) + } +} + +#[async_trait] +impl Client for GrpcClient { + async fn send_get_page( + &mut self, + req_id: u64, + req_lsn: Lsn, + mod_lsn: Lsn, + rel: RelTag, + blks: Vec, + ) -> anyhow::Result<()> { + let req = proto::GetPageRequest { + request_id: req_id, + request_class: proto::GetPageClass::Normal as i32, + read_lsn: Some(proto::ReadLsn { + request_lsn: req_lsn.0, + not_modified_since_lsn: mod_lsn.0, + }), + rel: Some(rel.into()), + block_number: blks, + }; + self.req_tx.send(req).await?; + Ok(()) + } + + async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec)> { + let resp = self.resp_rx.message().await?.unwrap(); + anyhow::ensure!( + resp.status_code == proto::GetPageStatusCode::Ok as i32, + "unexpected status code: {}", + resp.status_code + ); + Ok((resp.request_id, resp.page_image)) + } +} diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index e89baa0bce..2a0548b811 
100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -65,6 +65,30 @@ impl From for BasebackupError { } } +impl From for postgres_backend::QueryError { + fn from(err: BasebackupError) -> Self { + use postgres_backend::QueryError; + use pq_proto::framed::ConnectionError; + match err { + BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)), + BasebackupError::Server(err) => QueryError::Other(err), + BasebackupError::Shutdown => QueryError::Shutdown, + } + } +} + +impl From for tonic::Status { + fn from(err: BasebackupError) -> Self { + use tonic::Code; + let code = match &err { + BasebackupError::Client(_, _) => Code::Cancelled, + BasebackupError::Server(_) => Code::Internal, + BasebackupError::Shutdown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + /// Create basebackup with non-rel data in it. /// Only include relational data if 'full_backup' is true. /// @@ -248,7 +272,7 @@ where async fn flush(&mut self) -> Result<(), BasebackupError> { let nblocks = self.buf.len() / BLCKSZ as usize; let (kind, segno) = self.current_segment.take().unwrap(); - let segname = format!("{}/{:>04X}", kind.to_str(), segno); + let segname = format!("{kind}/{segno:>04X}"); let header = new_tar_header(&segname, self.buf.len() as u64)?; self.ar .append(&header, self.buf.as_slice()) @@ -347,7 +371,7 @@ where .await? .partition( self.timeline.get_shard_identity(), - Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64, + self.timeline.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, ); let mut slru_builder = SlruSegmentsBuilder::new(&mut self.ar); diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index df3c045145..a1a95ad2d1 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -158,7 +158,6 @@ fn main() -> anyhow::Result<()> { // (maybe we should automate this with a visitor?). 
info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine"); info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode"); - info!(?conf.wal_receiver_protocol, "starting with WAL receiver protocol"); info!(?conf.validate_wal_contiguity, "starting with WAL contiguity validation"); info!(?conf.page_service_pipelining, "starting with page service pipelining config"); info!(?conf.get_vectored_concurrent_io, "starting with get_vectored IO concurrency config"); @@ -804,7 +803,7 @@ fn start_pageserver( } else { None }, - basebackup_cache.clone(), + basebackup_cache, ); // Spawn a Pageserver gRPC server task. It will spawn separate tasks for @@ -816,12 +815,11 @@ fn start_pageserver( let mut page_service_grpc = None; if let Some(grpc_listener) = grpc_listener { page_service_grpc = Some(page_service::spawn_grpc( - conf, tenant_manager.clone(), grpc_auth, otel_guard.as_ref().map(|g| g.dispatch.clone()), + conf.get_vectored_concurrent_io, grpc_listener, - basebackup_cache, )?); } diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 89f7539722..3492a8d966 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -14,7 +14,10 @@ use std::time::Duration; use anyhow::{Context, bail, ensure}; use camino::{Utf8Path, Utf8PathBuf}; use once_cell::sync::OnceCell; -use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig}; +use pageserver_api::config::{ + DiskUsageEvictionTaskConfig, MaxGetVectoredKeys, MaxVectoredReadBytes, + PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, PostHogConfig, +}; use pageserver_api::models::ImageCompressionAlgorithm; use pageserver_api::shard::TenantShardId; use pem::Pem; @@ -24,7 +27,6 @@ use reqwest::Url; use storage_broker::Uri; use utils::id::{NodeId, TimelineId}; use utils::logging::{LogFormat, SecretString}; -use utils::postgres_client::PostgresClientProtocol; use crate::tenant::storage_layer::inmemory_layer::IndexEntry; use 
crate::tenant::{TENANTS_SEGMENT_NAME, TIMELINES_SEGMENT_NAME}; @@ -185,6 +187,9 @@ pub struct PageServerConf { pub max_vectored_read_bytes: MaxVectoredReadBytes, + /// Maximum number of keys to be read in a single get_vectored call. + pub max_get_vectored_keys: MaxGetVectoredKeys, + pub image_compression: ImageCompressionAlgorithm, /// Whether to offload archived timelines automatically @@ -205,8 +210,6 @@ pub struct PageServerConf { /// Optionally disable disk syncs (unsafe!) pub no_sync: bool, - pub wal_receiver_protocol: PostgresClientProtocol, - pub page_service_pipelining: pageserver_api::config::PageServicePipeliningConfig, pub get_vectored_concurrent_io: pageserver_api::config::GetVectoredConcurrentIo, @@ -404,6 +407,7 @@ impl PageServerConf { secondary_download_concurrency, ingest_batch_size, max_vectored_read_bytes, + max_get_vectored_keys, image_compression, timeline_offloading, ephemeral_bytes_per_memory_kb, @@ -414,7 +418,6 @@ impl PageServerConf { virtual_file_io_engine, tenant_config, no_sync, - wal_receiver_protocol, page_service_pipelining, get_vectored_concurrent_io, enable_read_path_debugging, @@ -470,13 +473,13 @@ impl PageServerConf { secondary_download_concurrency, ingest_batch_size, max_vectored_read_bytes, + max_get_vectored_keys, image_compression, timeline_offloading, ephemeral_bytes_per_memory_kb, import_pgdata_upcall_api, import_pgdata_upcall_api_token: import_pgdata_upcall_api_token.map(SecretString::from), import_pgdata_aws_endpoint_url, - wal_receiver_protocol, page_service_pipelining, get_vectored_concurrent_io, tracing, @@ -598,6 +601,19 @@ impl PageServerConf { ) })?; + if let PageServicePipeliningConfig::Pipelined(PageServicePipeliningConfigPipelined { + max_batch_size, + .. 
+ }) = conf.page_service_pipelining + { + if max_batch_size.get() > conf.max_get_vectored_keys.get() { + return Err(anyhow::anyhow!( + "`max_batch_size` ({max_batch_size}) must be less than or equal to `max_get_vectored_keys` ({})", + conf.max_get_vectored_keys.get() + )); + } + }; + Ok(conf) } @@ -685,6 +701,7 @@ impl ConfigurableSemaphore { mod tests { use camino::Utf8PathBuf; + use rstest::rstest; use utils::id::NodeId; use super::PageServerConf; @@ -724,4 +741,28 @@ mod tests { PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir) .expect_err("parse_and_validate should fail for endpoint without scheme"); } + + #[rstest] + #[case(32, 32, true)] + #[case(64, 32, false)] + #[case(64, 64, true)] + #[case(128, 128, true)] + fn test_config_max_batch_size_is_valid( + #[case] max_batch_size: usize, + #[case] max_get_vectored_keys: usize, + #[case] is_valid: bool, + ) { + let input = format!( + r#" + control_plane_api = "http://localhost:6666" + max_get_vectored_keys = {max_get_vectored_keys} + page_service_pipelining = {{ mode="pipelined", execution="concurrent-futures", max_batch_size={max_batch_size}, batching="uniform-lsn" }} + "#, + ); + let config_toml = toml_edit::de::from_str::(&input) + .expect("config has valid fields"); + let workdir = Utf8PathBuf::from("/nonexistent"); + let result = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir); + assert_eq!(result.is_ok(), is_valid); + } } diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index 7e31b930d0..50de3b691c 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -1,21 +1,28 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use posthog_client_lite::{ - FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError, + CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError, + PostHogFlagFilterPropertyValue, }; +use remote_storage::RemoteStorageKind; 
+use serde_json::json; use tokio_util::sync::CancellationToken; use utils::id::TenantId; -use crate::config::PageServerConf; +use crate::{config::PageServerConf, metrics::FEATURE_FLAG_EVALUATION}; #[derive(Clone)] pub struct FeatureResolver { inner: Option>, + internal_properties: Option>>, } impl FeatureResolver { pub fn new_disabled() -> Self { - Self { inner: None } + Self { + inner: None, + internal_properties: None, + } } pub fn spawn( @@ -36,14 +43,114 @@ impl FeatureResolver { shutdown_pageserver, ); let inner = Arc::new(inner); - // TODO: make this configurable - inner.clone().spawn(handle, Duration::from_secs(60)); - Ok(FeatureResolver { inner: Some(inner) }) + + // The properties shared by all tenants on this pageserver. + let internal_properties = { + let mut properties = HashMap::new(); + properties.insert( + "pageserver_id".to_string(), + PostHogFlagFilterPropertyValue::String(conf.id.to_string()), + ); + if let Some(availability_zone) = &conf.availability_zone { + properties.insert( + "availability_zone".to_string(), + PostHogFlagFilterPropertyValue::String(availability_zone.clone()), + ); + } + // Infer region based on the remote storage config. + if let Some(remote_storage) = &conf.remote_storage_config { + match &remote_storage.storage { + RemoteStorageKind::AwsS3(config) => { + properties.insert( + "region".to_string(), + PostHogFlagFilterPropertyValue::String(format!( + "aws-{}", + config.bucket_region + )), + ); + } + RemoteStorageKind::AzureContainer(config) => { + properties.insert( + "region".to_string(), + PostHogFlagFilterPropertyValue::String(format!( + "azure-{}", + config.container_region + )), + ); + } + RemoteStorageKind::LocalFs { .. } => { + properties.insert( + "region".to_string(), + PostHogFlagFilterPropertyValue::String("local".to_string()), + ); + } + } + } + // TODO: add pageserver URL. 
+ Arc::new(properties) + }; + let fake_tenants = { + let mut tenants = Vec::new(); + for i in 0..10 { + let distinct_id = format!( + "fake_tenant_{}_{}_{}", + conf.availability_zone.as_deref().unwrap_or_default(), + conf.id, + i + ); + let properties = Self::collect_properties_inner( + distinct_id.clone(), + Some(&internal_properties), + ); + tenants.push(CaptureEvent { + event: "initial_tenant_report".to_string(), + distinct_id, + properties: json!({ "$set": properties }), // use `$set` to set the person properties instead of the event properties + }); + } + tenants + }; + // TODO: make refresh period configurable + inner + .clone() + .spawn(handle, Duration::from_secs(60), fake_tenants); + Ok(FeatureResolver { + inner: Some(inner), + internal_properties: Some(internal_properties), + }) } else { - Ok(FeatureResolver { inner: None }) + Ok(FeatureResolver { + inner: None, + internal_properties: None, + }) } } + fn collect_properties_inner( + tenant_id: String, + internal_properties: Option<&HashMap>, + ) -> HashMap { + let mut properties = HashMap::new(); + if let Some(internal_properties) = internal_properties { + for (key, value) in internal_properties.iter() { + properties.insert(key.clone(), value.clone()); + } + } + properties.insert( + "tenant_id".to_string(), + PostHogFlagFilterPropertyValue::String(tenant_id), + ); + properties + } + + /// Collect all properties available for the feature flag evaluation. + pub(crate) fn collect_properties( + &self, + tenant_id: TenantId, + ) -> HashMap { + Self::collect_properties_inner(tenant_id.to_string(), self.internal_properties.as_deref()) + } + /// Evaluate a multivariate feature flag. Currently, we do not support any properties. 
/// /// Error handling: the caller should inspect the error and decide the behavior when a feature flag @@ -55,11 +162,24 @@ impl FeatureResolver { tenant_id: TenantId, ) -> Result { if let Some(inner) = &self.inner { - inner.feature_store().evaluate_multivariate( + let res = inner.feature_store().evaluate_multivariate( flag_key, &tenant_id.to_string(), - &HashMap::new(), - ) + &self.collect_properties(tenant_id), + ); + match &res { + Ok(value) => { + FEATURE_FLAG_EVALUATION + .with_label_values(&[flag_key, "ok", value]) + .inc(); + } + Err(e) => { + FEATURE_FLAG_EVALUATION + .with_label_values(&[flag_key, "error", e.as_variant_str()]) + .inc(); + } + } + res } else { Err(PostHogEvaluationError::NotAvailable( "PostHog integration is not enabled".to_string(), @@ -80,11 +200,24 @@ impl FeatureResolver { tenant_id: TenantId, ) -> Result<(), PostHogEvaluationError> { if let Some(inner) = &self.inner { - inner.feature_store().evaluate_boolean( + let res = inner.feature_store().evaluate_boolean( flag_key, &tenant_id.to_string(), - &HashMap::new(), - ) + &self.collect_properties(tenant_id), + ); + match &res { + Ok(()) => { + FEATURE_FLAG_EVALUATION + .with_label_values(&[flag_key, "ok", "true"]) + .inc(); + } + Err(e) => { + FEATURE_FLAG_EVALUATION + .with_label_values(&[flag_key, "error", e.as_variant_str()]) + .inc(); + } + } + res } else { Err(PostHogEvaluationError::NotAvailable( "PostHog integration is not enabled".to_string(), diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 1effa10404..c8a2a0209f 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -43,6 +43,7 @@ use pageserver_api::models::{ use pageserver_api::shard::{ShardCount, TenantShardId}; use remote_storage::{DownloadError, GenericRemoteStorage, TimeTravelError}; use scopeguard::defer; +use serde_json::json; use tenant_size_model::svg::SvgBranchKind; use tenant_size_model::{SizeResult, StorageModel}; use tokio::time::Instant; @@ -3679,23 
+3680,24 @@ async fn tenant_evaluate_feature_flag( let tenant = state .tenant_manager .get_attached_tenant_shard(tenant_shard_id)?; + let properties = tenant.feature_resolver.collect_properties(tenant_shard_id.tenant_id); if as_type == "boolean" { let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); let result = result.map(|_| true).map_err(|e| e.to_string()); - json_response(StatusCode::OK, result) + json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) } else if as_type == "multivariate" { let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); - json_response(StatusCode::OK, result) + json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) } else { // Auto infer the type of the feature flag. let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?; if is_boolean { let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); let result = result.map(|_| true).map_err(|e| e.to_string()); - json_response(StatusCode::OK, result) + json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) } else { let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); - json_response(StatusCode::OK, result) + json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) } } } diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 0ff31dcb8a..3eb70ffac2 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -15,6 +15,7 @@ use metrics::{ register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec, }; use once_cell::sync::Lazy; +use pageserver_api::config::defaults::DEFAULT_MAX_GET_VECTORED_KEYS; use pageserver_api::config::{ 
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, PageServiceProtocolPipelinedBatchingStrategy, PageServiceProtocolPipelinedExecutionStrategy, @@ -32,7 +33,6 @@ use crate::config::PageServerConf; use crate::context::{PageContentKind, RequestContext}; use crate::pgdatadir_mapping::DatadirModificationStats; use crate::task_mgr::TaskKind; -use crate::tenant::Timeline; use crate::tenant::layer_map::LayerMap; use crate::tenant::mgr::TenantSlot; use crate::tenant::storage_layer::{InMemoryLayer, PersistentLayerDesc}; @@ -446,6 +446,15 @@ static PAGE_CACHE_ERRORS: Lazy = Lazy::new(|| { .expect("failed to define a metric") }); +pub(crate) static FEATURE_FLAG_EVALUATION: Lazy = Lazy::new(|| { + register_counter_vec!( + "pageserver_feature_flag_evaluation", + "Number of times a feature flag is evaluated", + &["flag_key", "status", "value"], + ) + .unwrap() +}); + #[derive(IntoStaticStr)] #[strum(serialize_all = "kebab_case")] pub(crate) enum PageCacheErrorKind { @@ -1312,11 +1321,44 @@ impl EvictionsWithLowResidenceDuration { // // Roughly logarithmic scale. 
const STORAGE_IO_TIME_BUCKETS: &[f64] = &[ - 0.000030, // 30 usec - 0.001000, // 1000 usec - 0.030, // 30 ms - 1.000, // 1000 ms - 30.000, // 30000 ms + 0.00005, // 50us + 0.00006, // 60us + 0.00007, // 70us + 0.00008, // 80us + 0.00009, // 90us + 0.0001, // 100us + 0.000110, // 110us + 0.000120, // 120us + 0.000130, // 130us + 0.000140, // 140us + 0.000150, // 150us + 0.000160, // 160us + 0.000170, // 170us + 0.000180, // 180us + 0.000190, // 190us + 0.000200, // 200us + 0.000210, // 210us + 0.000220, // 220us + 0.000230, // 230us + 0.000240, // 240us + 0.000250, // 250us + 0.000300, // 300us + 0.000350, // 350us + 0.000400, // 400us + 0.000450, // 450us + 0.000500, // 500us + 0.000600, // 600us + 0.000700, // 700us + 0.000800, // 800us + 0.000900, // 900us + 0.001000, // 1ms + 0.002000, // 2ms + 0.003000, // 3ms + 0.004000, // 4ms + 0.005000, // 5ms + 0.01000, // 10ms + 0.02000, // 20ms + 0.05000, // 50ms ]; /// VirtualFile fs operation variants. @@ -1906,7 +1948,7 @@ static SMGR_QUERY_TIME_GLOBAL: Lazy = Lazy::new(|| { }); static PAGE_SERVICE_BATCH_SIZE_BUCKETS_GLOBAL: Lazy> = Lazy::new(|| { - (1..=u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap()) + (1..=u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap()) .map(|v| v.into()) .collect() }); @@ -1924,7 +1966,7 @@ static PAGE_SERVICE_BATCH_SIZE_BUCKETS_PER_TIMELINE: Lazy> = Lazy::new( let mut buckets = Vec::new(); for i in 0.. 
{ let bucket = 1 << i; - if bucket > u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap() { + if bucket > u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap() { break; } buckets.push(bucket.into()); @@ -2813,7 +2855,6 @@ pub(crate) struct WalIngestMetrics { pub(crate) records_received: IntCounter, pub(crate) records_observed: IntCounter, pub(crate) records_committed: IntCounter, - pub(crate) records_filtered: IntCounter, pub(crate) values_committed_metadata_images: IntCounter, pub(crate) values_committed_metadata_deltas: IntCounter, pub(crate) values_committed_data_images: IntCounter, @@ -2869,11 +2910,6 @@ pub(crate) static WAL_INGEST: Lazy = Lazy::new(|| { "Number of WAL records which resulted in writes to pageserver storage" ) .expect("failed to define a metric"), - records_filtered: register_int_counter!( - "pageserver_wal_ingest_records_filtered", - "Number of WAL records filtered out due to sharding" - ) - .expect("failed to define a metric"), values_committed_metadata_images: values_committed.with_label_values(&["metadata", "image"]), values_committed_metadata_deltas: values_committed.with_label_values(&["metadata", "delta"]), values_committed_data_images: values_committed.with_label_values(&["data", "image"]), diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index e96787e027..4a1ddf09b5 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -1,18 +1,21 @@ //! The Page Service listens for client connections and serves their GetPage@LSN //! requests. 
+use std::any::Any; use std::borrow::Cow; use std::num::NonZeroUsize; use std::os::fd::AsRawFd; use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; use std::{io, str}; -use anyhow::{Context, bail}; +use anyhow::{Context as _, anyhow, bail}; use async_compression::tokio::write::GzipEncoder; -use bytes::Buf; +use bytes::{Buf, BytesMut}; +use futures::future::BoxFuture; use futures::{FutureExt, Stream}; use itertools::Itertools; use jsonwebtoken::TokenData; @@ -31,6 +34,7 @@ use pageserver_api::models::{ }; use pageserver_api::reltag::SlruKind; use pageserver_api::shard::TenantShardId; +use pageserver_page_api as page_api; use pageserver_page_api::proto; use postgres_backend::{ AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error, @@ -39,14 +43,14 @@ use postgres_ffi::BLCKSZ; use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID; use pq_proto::framed::ConnectionError; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, RowDescriptor}; +use smallvec::{SmallVec, smallvec}; use strum_macros::IntoStaticStr; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; +use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, BufWriter}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tonic::service::Interceptor as _; use tracing::*; use utils::auth::{Claims, Scope, SwappableJwtAuth}; -use utils::failpoint_support; use utils::id::{TenantId, TenantTimelineId, TimelineId}; use utils::logging::log_slow; use utils::lsn::Lsn; @@ -54,6 +58,7 @@ use utils::shard::ShardIndex; use utils::simple_rcu::RcuReadGuard; use utils::sync::gate::{Gate, GateGuard}; use utils::sync::spsc_fold; +use utils::{failpoint_support, span_record}; use crate::auth::check_permission; use crate::basebackup::{self, BasebackupError}; @@ -76,7 +81,8 @@ use crate::tenant::mgr::{ GetActiveTenantError, GetTenantError, ShardResolveResult, 
ShardSelector, TenantManager, }; use crate::tenant::storage_layer::IoConcurrency; -use crate::tenant::timeline::{self, WaitLsnError}; +use crate::tenant::timeline::handle::{Handle, HandleUpgradeError, WeakHandle}; +use crate::tenant::timeline::{self, WaitLsnError, WaitLsnTimeout, WaitLsnWaiter}; use crate::tenant::{GetTimelineError, PageReconstructError, Timeline}; use crate::{CancellableTask, PERF_TRACE_TARGET, timed_after_cancellation}; @@ -165,15 +171,15 @@ pub fn spawn( /// Spawns a gRPC server for the page service. /// +/// TODO: move this onto GrpcPageServiceHandler::spawn(). /// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we /// need to reimplement the TCP+TLS accept loop ourselves. pub fn spawn_grpc( - conf: &'static PageServerConf, tenant_manager: Arc, auth: Option>, perf_trace_dispatch: Option, + get_vectored_concurrent_io: GetVectoredConcurrentIo, listener: std::net::TcpListener, - basebackup_cache: Arc, ) -> anyhow::Result { let cancel = CancellationToken::new(); let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) @@ -195,35 +201,40 @@ pub fn spawn_grpc( // Set up the gRPC server. // // TODO: consider tuning window sizes. - // TODO: wire up tracing. let mut server = tonic::transport::Server::builder() .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); - // Main page service. - let page_service_handler = PageServerHandler::new( + // Main page service stack. Uses a mix of Tonic interceptors and Tower layers: + // + // * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service. + // + // * Layers: allow async code, can run code after the service response. However, only has access + // to the raw HTTP request/response, not the gRPC types. 
+ let page_service_handler = GrpcPageServiceHandler { tenant_manager, - auth.clone(), - PageServicePipeliningConfig::Serial, // TODO: unused with gRPC - conf.get_vectored_concurrent_io, - ConnectionPerfSpanFields::default(), - basebackup_cache, ctx, - cancel.clone(), - gate.enter().expect("just created"), - ); - - let mut tenant_interceptor = TenantMetadataInterceptor; - let mut auth_interceptor = TenantAuthInterceptor::new(auth); - let interceptors = move |mut req: tonic::Request<()>| { - req = tenant_interceptor.call(req)?; - req = auth_interceptor.call(req)?; - Ok(req) + gate_guard: gate.enter().expect("gate was just created"), + get_vectored_concurrent_io, }; - let page_service = - proto::PageServiceServer::with_interceptor(page_service_handler, interceptors); + let observability_layer = ObservabilityLayer; + let mut tenant_interceptor = TenantMetadataInterceptor; + let mut auth_interceptor = TenantAuthInterceptor::new(auth); + + let page_service = tower::ServiceBuilder::new() + // Create tracing span and record request start time. + .layer(observability_layer) + // Intercept gRPC requests. + .layer(tonic::service::InterceptorLayer::new(move |mut req| { + // Extract tenant metadata. + req = tenant_interceptor.call(req)?; + // Authenticate tenant JWT token. + req = auth_interceptor.call(req)?; + Ok(req) + })) + .service(proto::PageServiceServer::new(page_service_handler)); let server = server.add_service(page_service); // Reflection service for use with e.g. grpcurl. @@ -489,10 +500,6 @@ async fn page_service_conn_main( } /// Page service connection handler. -/// -/// TODO: for gRPC, this will be shared by all requests from all connections. -/// Decompose it into global state and per-connection/request state, and make -/// libpq-specific options (e.g. pipelining) separate. 
struct PageServerHandler { auth: Option>, claims: Option, @@ -542,7 +549,7 @@ impl TimelineHandles { tenant_id: TenantId, timeline_id: TimelineId, shard_selector: ShardSelector, - ) -> Result, GetActiveTimelineError> { + ) -> Result, GetActiveTimelineError> { if *self.wrapper.tenant_id.get_or_init(|| tenant_id) != tenant_id { return Err(GetActiveTimelineError::Tenant( GetActiveTenantError::SwitchedTenant, @@ -709,6 +716,82 @@ enum PageStreamError { BadRequest(Cow<'static, str>), } +impl PageStreamError { + /// Converts a PageStreamError into a proto::GetPageResponse with the appropriate status + /// code, or a gRPC status if it should terminate the stream (e.g. shutdown). This is a + /// convenience method for use from a get_pages gRPC stream. + #[allow(clippy::result_large_err)] + fn into_get_page_response( + self, + request_id: page_api::RequestID, + ) -> Result { + use page_api::GetPageStatusCode; + use tonic::Code; + + // We dispatch to Into first, and then map it to a GetPageResponse. + let status: tonic::Status = self.into(); + let status_code = match status.code() { + // We shouldn't see an OK status here, because we're emitting an error. + Code::Ok => { + debug_assert_ne!(status.code(), Code::Ok); + return Err(tonic::Status::internal(format!( + "unexpected OK status: {status:?}", + ))); + } + + // These are per-request errors, returned as GetPageResponses. + Code::AlreadyExists => GetPageStatusCode::InvalidRequest, + Code::DataLoss => GetPageStatusCode::InternalError, + Code::FailedPrecondition => GetPageStatusCode::InvalidRequest, + Code::InvalidArgument => GetPageStatusCode::InvalidRequest, + Code::Internal => GetPageStatusCode::InternalError, + Code::NotFound => GetPageStatusCode::NotFound, + Code::OutOfRange => GetPageStatusCode::InvalidRequest, + Code::ResourceExhausted => GetPageStatusCode::SlowDown, + + // These should terminate the stream. 
+ Code::Aborted => return Err(status), + Code::Cancelled => return Err(status), + Code::DeadlineExceeded => return Err(status), + Code::PermissionDenied => return Err(status), + Code::Unauthenticated => return Err(status), + Code::Unavailable => return Err(status), + Code::Unimplemented => return Err(status), + Code::Unknown => return Err(status), + }; + + Ok(page_api::GetPageResponse { + request_id, + status_code, + reason: Some(status.message().to_string()), + page_images: Vec::new(), + } + .into()) + } +} + +impl From for tonic::Status { + fn from(err: PageStreamError) -> Self { + use tonic::Code; + let message = err.to_string(); + let code = match err { + PageStreamError::Reconnect(_) => Code::Unavailable, + PageStreamError::Shutdown => Code::Unavailable, + PageStreamError::Read(err) => match err { + PageReconstructError::Cancelled => Code::Unavailable, + PageReconstructError::MissingKey(_) => Code::NotFound, + PageReconstructError::AncestorLsnTimeout(err) => tonic::Status::from(err).code(), + PageReconstructError::Other(_) => Code::Internal, + PageReconstructError::WalRedo(_) => Code::Internal, + }, + PageStreamError::LsnTimeout(err) => tonic::Status::from(err).code(), + PageStreamError::NotFound(_) => Code::NotFound, + PageStreamError::BadRequest(_) => Code::InvalidArgument, + }; + tonic::Status::new(code, message) + } +} + impl From for PageStreamError { fn from(value: PageReconstructError) -> Self { match value { @@ -789,37 +872,37 @@ enum BatchedFeMessage { Exists { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamExistsRequest, }, Nblocks { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamNblocksRequest, }, GetPage { span: Span, - shard: timeline::handle::WeakHandle, - pages: smallvec::SmallVec<[BatchedGetPageRequest; 1]>, + shard: WeakHandle, + pages: SmallVec<[BatchedGetPageRequest; 1]>, batch_break_reason: 
GetPageBatchBreakReason, }, DbSize { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamDbSizeRequest, }, GetSlruSegment { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamGetSlruSegmentRequest, }, #[cfg(feature = "testing")] Test { span: Span, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, requests: Vec, }, RespondError { @@ -1068,26 +1151,6 @@ impl PageServerHandler { let neon_fe_msg = PagestreamFeMessage::parse(&mut copy_data_bytes.reader(), protocol_version)?; - // TODO: turn in to async closure once available to avoid repeating received_at - async fn record_op_start_and_throttle( - shard: &timeline::handle::Handle, - op: metrics::SmgrQueryType, - received_at: Instant, - ) -> Result { - // It's important to start the smgr op metric recorder as early as possible - // so that the _started counters are incremented before we do - // any serious waiting, e.g., for throttle, batching, or actual request handling. - let mut timer = shard.query_metrics.start_smgr_op(op, received_at); - let now = Instant::now(); - timer.observe_throttle_start(now); - let throttled = tokio::select! 
{ - res = shard.pagestream_throttle.throttle(1, now) => res, - _ = shard.cancel.cancelled() => return Err(QueryError::Shutdown), - }; - timer.observe_throttle_done(throttled); - Ok(timer) - } - let batched_msg = match neon_fe_msg { PagestreamFeMessage::Exists(req) => { let shard = timeline_handles @@ -1095,7 +1158,7 @@ impl PageServerHandler { .await?; debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id(); let span = tracing::info_span!(parent: &parent_span, "handle_get_rel_exists_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetRelExists, received_at, @@ -1113,7 +1176,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_get_nblocks_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetRelSize, received_at, @@ -1131,7 +1194,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_db_size_request", dbnode = %req.dbnode, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetDbSize, received_at, @@ -1149,7 +1212,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_get_slru_segment_request", kind = %req.kind, segno = %req.segno, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let 
timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetSlruSegment, received_at, @@ -1274,7 +1337,7 @@ impl PageServerHandler { // request handler log messages contain the request-specific fields. let span = mkspan!(shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetPageAtLsn, received_at, @@ -1321,7 +1384,7 @@ impl PageServerHandler { BatchedFeMessage::GetPage { span, shard: shard.downgrade(), - pages: smallvec::smallvec![BatchedGetPageRequest { + pages: smallvec![BatchedGetPageRequest { req, timer, lsn_range: LsnRange { @@ -1343,9 +1406,12 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_test_request", shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = - record_op_start_and_throttle(&shard, metrics::SmgrQueryType::Test, received_at) - .await?; + let timer = Self::record_op_start_and_throttle( + &shard, + metrics::SmgrQueryType::Test, + received_at, + ) + .await?; BatchedFeMessage::Test { span, shard: shard.downgrade(), @@ -1356,6 +1422,26 @@ impl PageServerHandler { Ok(Some(batched_msg)) } + /// Starts a SmgrOpTimer at received_at and throttles the request. + async fn record_op_start_and_throttle( + shard: &Handle, + op: metrics::SmgrQueryType, + received_at: Instant, + ) -> Result { + // It's important to start the smgr op metric recorder as early as possible + // so that the _started counters are incremented before we do + // any serious waiting, e.g., for throttle, batching, or actual request handling. + let mut timer = shard.query_metrics.start_smgr_op(op, received_at); + let now = Instant::now(); + timer.observe_throttle_start(now); + let throttled = tokio::select! 
{ + res = shard.pagestream_throttle.throttle(1, now) => res, + _ = shard.cancel.cancelled() => return Err(QueryError::Shutdown), + }; + timer.observe_throttle_done(throttled); + Ok(timer) + } + /// Post-condition: `batch` is Some() #[instrument(skip_all, level = tracing::Level::TRACE)] #[allow(clippy::boxed_local)] @@ -1453,8 +1539,11 @@ impl PageServerHandler { let (mut handler_results, span) = { // TODO: we unfortunately have to pin the future on the heap, since GetPage futures are huge and // won't fit on the stack. - let mut boxpinned = - Box::pin(self.pagestream_dispatch_batched_message(batch, io_concurrency, ctx)); + let mut boxpinned = Box::pin(Self::pagestream_dispatch_batched_message( + batch, + io_concurrency, + ctx, + )); log_slow( log_slow_name, LOG_SLOW_GETPAGE_THRESHOLD, @@ -1610,7 +1699,6 @@ impl PageServerHandler { /// Helper which dispatches a batched message to the appropriate handler. /// Returns a vec of results, along with the extracted trace span. async fn pagestream_dispatch_batched_message( - &mut self, batch: BatchedFeMessage, io_concurrency: IoConcurrency, ctx: &RequestContext, @@ -1640,10 +1728,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_rel_exists_request(&shard, &req, &ctx) + Self::handle_get_rel_exists_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::Exists(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1659,10 +1747,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_nblocks_request(&shard, &req, &ctx) + Self::handle_get_nblocks_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1680,16 +1768,15 @@ 
impl PageServerHandler { { let npages = pages.len(); trace!(npages, "handling getpage request"); - let res = self - .handle_get_page_at_lsn_request_batched( - &shard, - pages, - io_concurrency, - batch_break_reason, - &ctx, - ) - .instrument(span.clone()) - .await; + let res = Self::handle_get_page_at_lsn_request_batched( + &shard, + pages, + io_concurrency, + batch_break_reason, + &ctx, + ) + .instrument(span.clone()) + .await; assert_eq!(res.len(), npages); res }, @@ -1706,10 +1793,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_db_size_request(&shard, &req, &ctx) + Self::handle_db_size_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::DbSize(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1725,10 +1812,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_slru_segment_request(&shard, &req, &ctx) + Self::handle_get_slru_segment_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer, ctx)) + .map(|msg| (PagestreamBeMessage::GetSlruSegment(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1746,8 +1833,7 @@ impl PageServerHandler { { let npages = requests.len(); trace!(npages, "handling getpage request"); - let res = self - .handle_test_request_batch(&shard, requests, &ctx) + let res = Self::handle_test_request_batch(&shard, requests, &ctx) .instrument(span.clone()) .await; assert_eq!(res.len(), npages); @@ -2301,11 +2387,10 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_rel_exists_request( - &mut self, timeline: &Timeline, req: &PagestreamExistsRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = 
Self::wait_or_get_last_lsn( timeline, @@ -2327,19 +2412,15 @@ impl PageServerHandler { ) .await?; - Ok(PagestreamBeMessage::Exists(PagestreamExistsResponse { - req: *req, - exists, - })) + Ok(PagestreamExistsResponse { req: *req, exists }) } #[instrument(skip_all, fields(shard_id))] async fn handle_get_nblocks_request( - &mut self, timeline: &Timeline, req: &PagestreamNblocksRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2361,19 +2442,18 @@ impl PageServerHandler { ) .await?; - Ok(PagestreamBeMessage::Nblocks(PagestreamNblocksResponse { + Ok(PagestreamNblocksResponse { req: *req, n_blocks, - })) + }) } #[instrument(skip_all, fields(shard_id))] async fn handle_db_size_request( - &mut self, timeline: &Timeline, req: &PagestreamDbSizeRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2397,17 +2477,13 @@ impl PageServerHandler { .await?; let db_size = total_blocks as i64 * BLCKSZ as i64; - Ok(PagestreamBeMessage::DbSize(PagestreamDbSizeResponse { - req: *req, - db_size, - })) + Ok(PagestreamDbSizeResponse { req: *req, db_size }) } #[instrument(skip_all)] async fn handle_get_page_at_lsn_request_batched( - &mut self, timeline: &Timeline, - requests: smallvec::SmallVec<[BatchedGetPageRequest; 1]>, + requests: SmallVec<[BatchedGetPageRequest; 1]>, io_concurrency: IoConcurrency, batch_break_reason: GetPageBatchBreakReason, ctx: &RequestContext, @@ -2532,11 +2608,10 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_slru_segment_request( - &mut self, timeline: &Timeline, req: &PagestreamGetSlruSegmentRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ 
-2551,16 +2626,13 @@ impl PageServerHandler { .ok_or(PageStreamError::BadRequest("invalid SLRU kind".into()))?; let segment = timeline.get_slru_segment(kind, req.segno, lsn, ctx).await?; - Ok(PagestreamBeMessage::GetSlruSegment( - PagestreamGetSlruSegmentResponse { req: *req, segment }, - )) + Ok(PagestreamGetSlruSegmentResponse { req: *req, segment }) } // NB: this impl mimics what we do for batched getpage requests. #[cfg(feature = "testing")] #[instrument(skip_all, fields(shard_id))] async fn handle_test_request_batch( - &mut self, timeline: &Timeline, requests: Vec, _ctx: &RequestContext, @@ -2636,15 +2708,6 @@ impl PageServerHandler { where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin, { - fn map_basebackup_error(err: BasebackupError) -> QueryError { - match err { - // TODO: passthrough the error site to the final error message? - BasebackupError::Client(e, _) => QueryError::Disconnected(ConnectionError::Io(e)), - BasebackupError::Server(e) => QueryError::Other(e), - BasebackupError::Shutdown => QueryError::Shutdown, - } - } - let started = std::time::Instant::now(); let timeline = self @@ -2702,8 +2765,7 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; } else { let mut writer = BufWriter::new(pgb.copyout_writer()); @@ -2726,11 +2788,8 @@ impl PageServerHandler { from_cache = true; tokio::io::copy(&mut cached, &mut writer) .await - .map_err(|e| { - map_basebackup_error(BasebackupError::Client( - e, - "handle_basebackup_request,cached,copy", - )) + .map_err(|err| { + BasebackupError::Client(err, "handle_basebackup_request,cached,copy") })?; } else if gzip { let mut encoder = GzipEncoder::with_quality( @@ -2751,8 +2810,7 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; // shutdown the encoder to ensure the gzip footer is written encoder .shutdown() @@ -2768,15 +2826,12 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + 
.await?; } - writer.flush().await.map_err(|e| { - map_basebackup_error(BasebackupError::Client( - e, - "handle_basebackup_request,flush", - )) - })?; + writer + .flush() + .await + .map_err(|err| BasebackupError::Client(err, "handle_basebackup_request,flush"))?; } pgb.write_message_noflush(&BeMessage::CopyDone) @@ -3300,80 +3355,543 @@ where } } -/// Implements the page service over gRPC. +/// Serves the page service over gRPC. Dispatches to PageServerHandler for request processing. /// -/// TODO: not yet implemented, all methods return unimplemented. -#[tonic::async_trait] -impl proto::PageService for PageServerHandler { - type GetBaseBackupStream = Pin< - Box> + Send>, - >; - type GetPagesStream = - Pin> + Send>>; +/// TODO: rename to PageServiceHandler when libpq impl is removed. +pub struct GrpcPageServiceHandler { + tenant_manager: Arc, + ctx: RequestContext, + gate_guard: GateGuard, + get_vectored_concurrent_io: GetVectoredConcurrentIo, +} - async fn check_rel_exists( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) +impl GrpcPageServiceHandler { + /// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of + /// relations and their sizes, as well as SLRU segments and similar data. + #[allow(clippy::result_large_err)] + fn ensure_shard_zero(timeline: &Handle) -> Result<(), tonic::Status> { + match timeline.get_shard_index().shard_number.0 { + 0 => Ok(()), + shard => Err(tonic::Status::invalid_argument(format!( + "request must execute on shard zero (is shard {shard})", + ))), + } } - async fn get_base_backup( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + /// Generates a PagestreamRequest header from a ReadLsn and request ID. 
+ fn make_hdr(read_lsn: page_api::ReadLsn, req_id: u64) -> PagestreamRequest { + PagestreamRequest { + reqid: req_id, + request_lsn: read_lsn.request_lsn, + not_modified_since: read_lsn + .not_modified_since_lsn + .unwrap_or(read_lsn.request_lsn), + } } - async fn get_db_size( + /// Acquires a timeline handle for the given request. + /// + /// TODO: during shard splits, the compute may still be sending requests to the parent shard + /// until the entire split is committed and the compute is notified. Consider installing a + /// temporary shard router from the parent to the children while the split is in progress. + /// + /// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage + /// the TimelineHandles lifecycle. + /// + /// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid + /// the unnecessary overhead. + async fn get_request_timeline( &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + req: &tonic::Request, + ) -> Result, GetActiveTimelineError> { + let ttid = *extract::(req); + let shard_index = *extract::(req); + let shard_selector = ShardSelector::Known(shard_index); + + TimelineHandles::new(self.tenant_manager.clone()) + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await } - async fn get_pages( - &self, - _: tonic::Request>, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + /// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start. + /// Only errors if the timeline is shutting down. + /// + /// TODO: move timer construction to ObservabilityLayer (see TODO there). + /// TODO: decouple rate limiting (middleware?), and return SlowDown errors instead. 
+ async fn record_op_start_and_throttle( + timeline: &Handle, + op: metrics::SmgrQueryType, + received_at: Instant, + ) -> Result { + let mut timer = PageServerHandler::record_op_start_and_throttle(timeline, op, received_at) + .await + .map_err(|err| match err { + // record_op_start_and_throttle() only returns Shutdown. + QueryError::Shutdown => tonic::Status::unavailable(format!("{err}")), + err => tonic::Status::internal(format!("unexpected error: {err}")), + })?; + timer.observe_execution_start(Instant::now()); + Ok(timer) } - async fn get_rel_size( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) - } + /// Processes a GetPage batch request, via the GetPages bidirectional streaming RPC. + /// + /// NB: errors will terminate the stream. Per-request errors should return a GetPageResponse + /// with an appropriate status code instead. + /// + /// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send + /// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or + /// split them up in the client or server. + #[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))] + async fn get_page( + ctx: &RequestContext, + timeline: &WeakHandle, + req: proto::GetPageRequest, + io_concurrency: IoConcurrency, + ) -> Result { + let received_at = Instant::now(); + let timeline = timeline.upgrade()?; + let ctx = ctx.with_scope_page_service_pagestream(&timeline); - async fn get_slru_segment( - &self, - _: tonic::Request, - ) -> Result, tonic::Status> { - Err(tonic::Status::unimplemented("not implemented")) + // Validate the request, decorate the span, and convert it to a Pagestream request. 
+ let req: page_api::GetPageRequest = req.try_into()?; + + span_record!( + req_id = %req.request_id, + rel = %req.rel, + blkno = %req.block_numbers[0], + blks = %req.block_numbers.len(), + lsn = %req.read_lsn, + ); + + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard + let effective_lsn = match PageServerHandler::effective_request_lsn( + &timeline, + timeline.get_last_record_lsn(), + req.read_lsn.request_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), + &latest_gc_cutoff_lsn, + ) { + Ok(lsn) => lsn, + Err(err) => return err.into_get_page_response(req.request_id), + }; + + let mut batch = SmallVec::with_capacity(req.block_numbers.len()); + for blkno in req.block_numbers { + // TODO: this creates one timer per page and throttles it. We should have a timer for + // the entire batch, and throttle only the batch, but this is equivalent to what + // PageServerHandler does already so we keep it for now. + let timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetPageAtLsn, + received_at, + ) + .await?; + + batch.push(BatchedGetPageRequest { + req: PagestreamGetPageRequest { + hdr: Self::make_hdr(req.read_lsn, req.request_id), + rel: req.rel, + blkno, + }, + lsn_range: LsnRange { + effective_lsn, + request_lsn: req.read_lsn.request_lsn, + }, + timer, + ctx: ctx.attached_child(), + batch_wait_ctx: None, // TODO: add tracing + }); + } + + // TODO: this does a relation size query for every page in the batch. Since this batch is + // all for one relation, we could do this only once. However, this is not the case for the + // libpq implementation. 
+ let results = PageServerHandler::handle_get_page_at_lsn_request_batched( + &timeline, + batch, + io_concurrency, + GetPageBatchBreakReason::BatchFull, // TODO: not relevant for gRPC batches + &ctx, + ) + .await; + + let mut resp = page_api::GetPageResponse { + request_id: req.request_id, + status_code: page_api::GetPageStatusCode::Ok, + reason: None, + page_images: Vec::with_capacity(results.len()), + }; + + for result in results { + match result { + Ok((PagestreamBeMessage::GetPage(r), _, _)) => resp.page_images.push(r.page), + Ok((resp, _, _)) => { + return Err(tonic::Status::internal(format!( + "unexpected response: {resp:?}" + ))); + } + Err(err) => return err.err.into_get_page_response(req.request_id), + }; + } + + Ok(resp.into()) } } -impl From for QueryError { - fn from(e: GetActiveTenantError) -> Self { - match e { - GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( - ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), - ), - GetActiveTenantError::Cancelled - | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => { - QueryError::Shutdown - } - e @ GetActiveTenantError::NotFound(_) => QueryError::NotFound(format!("{e}").into()), - e => QueryError::Other(anyhow::anyhow!(e)), +/// Implements the gRPC page service. +/// +/// TODO: cancellation. +/// TODO: when the libpq impl is removed, remove the Pagestream types and inline the handler code. 
+#[tonic::async_trait] +impl proto::PageService for GrpcPageServiceHandler { + type GetBaseBackupStream = Pin< + Box> + Send>, + >; + + type GetPagesStream = + Pin> + Send>>; + + #[instrument(skip_all, fields(rel, lsn))] + async fn check_rel_exists( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?; + + span_record!(rel=%req.rel, lsn=%req.read_lsn); + + let req = PagestreamExistsRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + rel: req.rel, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetRelExists, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?; + let resp: page_api::CheckRelExistsResponse = resp.exists; + Ok(tonic::Response::new(resp.into())) + } + + // TODO: ensure clients use gzip compression for the stream. + #[instrument(skip_all, fields(lsn))] + async fn get_base_backup( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + // Send 64 KB chunks to avoid large memory allocations. + const CHUNK_SIZE: usize = 64 * 1024; + + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_timeline(&timeline); + + // Validate the request, decorate the span, and wait for the LSN to arrive. + // + // TODO: this requires a read LSN, is that ok? 
+ Self::ensure_shard_zero(&timeline)?; + if timeline.is_archived() == Some(true) { + return Err(tonic::Status::failed_precondition("timeline is archived")); } + let req: page_api::GetBaseBackupRequest = req.into_inner().try_into()?; + + span_record!(lsn=%req.read_lsn); + + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); + timeline + .wait_lsn( + req.read_lsn.request_lsn, + WaitLsnWaiter::PageService, + WaitLsnTimeout::Default, + &ctx, + ) + .await?; + timeline + .check_lsn_is_in_scope(req.read_lsn.request_lsn, &latest_gc_cutoff_lsn) + .map_err(|err| { + tonic::Status::invalid_argument(format!("invalid basebackup LSN: {err}")) + })?; + + // Spawn a task to run the basebackup. + // + // TODO: do we need to support full base backups, for debugging? + let span = Span::current(); + let (mut simplex_read, mut simplex_write) = tokio::io::simplex(CHUNK_SIZE); + let jh = tokio::spawn(async move { + let result = basebackup::send_basebackup_tarball( + &mut simplex_write, + &timeline, + Some(req.read_lsn.request_lsn), + None, + false, + req.replica, + &ctx, + ) + .instrument(span) // propagate request span + .await; + simplex_write.shutdown().await.map_err(|err| { + BasebackupError::Server(anyhow!("simplex shutdown failed: {err}")) + })?; + result + }); + + // Emit chunks of size CHUNK_SIZE. + let chunks = async_stream::try_stream! { + let mut chunk = BytesMut::with_capacity(CHUNK_SIZE); + loop { + let n = simplex_read.read_buf(&mut chunk).await.map_err(|err| { + tonic::Status::internal(format!("failed to read basebackup chunk: {err}")) + })?; + + // If we read 0 bytes, either the chunk is full or the stream is closed. + if n == 0 { + if chunk.is_empty() { + break; + } + yield proto::GetBaseBackupResponseChunk::from(chunk.clone().freeze()); + chunk.clear(); + } + } + // Wait for the basebackup task to exit and check for errors. 
+ jh.await.map_err(|err| { + tonic::Status::internal(format!("basebackup failed: {err}")) + })??; + }; + + Ok(tonic::Response::new(Box::pin(chunks))) + } + + #[instrument(skip_all, fields(db_oid, lsn))] + async fn get_db_size( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?; + + span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn); + + let req = PagestreamDbSizeRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + dbnode: req.db_oid, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetDbSize, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_db_size_request(&timeline, &req, &ctx).await?; + let resp = resp.db_size as page_api::GetDbSizeResponse; + Ok(tonic::Response::new(resp.into())) + } + + // NB: don't instrument this, instrument each streamed request. + async fn get_pages( + &self, + req: tonic::Request>, + ) -> Result, tonic::Status> { + // Extract the timeline from the request and check that it exists. + let ttid = *extract::(&req); + let shard_index = *extract::(&req); + let shard_selector = ShardSelector::Known(shard_index); + + let mut handles = TimelineHandles::new(self.tenant_manager.clone()); + handles + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await?; + + // Spawn an IoConcurrency sidecar, if enabled. 
+ let Ok(gate_guard) = self.gate_guard.try_clone() else { + return Err(tonic::Status::unavailable("shutting down")); + }; + let io_concurrency = + IoConcurrency::spawn_from_conf(self.get_vectored_concurrent_io, gate_guard); + + // Spawn a task to handle the GetPageRequest stream. + let span = Span::current(); + let ctx = self.ctx.attached_child(); + let mut reqs = req.into_inner(); + + let resps = async_stream::try_stream! { + let timeline = handles + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await? + .downgrade(); + while let Some(req) = reqs.message().await? { + yield Self::get_page(&ctx, &timeline, req, io_concurrency.clone()) + .instrument(span.clone()) // propagate request span + .await? + } + }; + + Ok(tonic::Response::new(Box::pin(resps))) + } + + #[instrument(skip_all, fields(rel, lsn))] + async fn get_rel_size( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?; + + span_record!(rel=%req.rel, lsn=%req.read_lsn); + + let req = PagestreamNblocksRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + rel: req.rel, + }; + + // Execute the request and convert the response. 
+ let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetRelSize, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?; + let resp: page_api::GetRelSizeResponse = resp.n_blocks; + Ok(tonic::Response::new(resp.into())) + } + + #[instrument(skip_all, fields(kind, segno, lsn))] + async fn get_slru_segment( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?; + + span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn); + + let req = PagestreamGetSlruSegmentRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + kind: req.kind as u8, + segno: req.segno, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetSlruSegment, + received_at, + ) + .await?; + + let resp = + PageServerHandler::handle_get_slru_segment_request(&timeline, &req, &ctx).await?; + let resp: page_api::GetSlruSegmentResponse = resp.segment; + Ok(tonic::Response::new(resp.into())) + } +} + +/// gRPC middleware layer that handles observability concerns: +/// +/// * Creates and enters a tracing span. +/// * Records the request start time as a ReceivedAt request extension. +/// +/// TODO: add perf tracing. +/// TODO: add timing and metrics. +/// TODO: add logging. 
+#[derive(Clone)] +struct ObservabilityLayer; + +impl tower::Layer for ObservabilityLayer { + type Service = ObservabilityLayerService; + + fn layer(&self, inner: S) -> Self::Service { + Self::Service { inner } + } +} + +#[derive(Clone)] +struct ObservabilityLayerService { + inner: S, +} + +#[derive(Clone, Copy)] +struct ReceivedAt(Instant); + +impl tonic::server::NamedService for ObservabilityLayerService { + const NAME: &'static str = S::NAME; // propagate inner service name +} + +impl tower::Service> for ObservabilityLayerService +where + S: tower::Service>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn call(&mut self, mut req: http::Request) -> Self::Future { + // Record the request start time as a request extension. + // + // TODO: we should start a timer here instead, but it currently requires a timeline handle + // and SmgrQueryType, which we don't have yet. Refactor it to provide it later. + req.extensions_mut().insert(ReceivedAt(Instant::now())); + + // Create a basic tracing span. Enter the span for the current thread (to use it for inner + // sync code like interceptors), and instrument the future (to use it for inner async code + // like the page service itself). + // + // The instrument() call below is not sufficient. It only affects the returned future, and + // only takes effect when the caller polls it. Any sync code executed when we call + // self.inner.call() below (such as interceptors) runs outside of the returned future, and + // is not affected by it. We therefore have to enter the span on the current thread too. + let span = info_span!( + "grpc:pageservice", + // Set by TenantMetadataInterceptor. 
+ tenant_id = field::Empty, + timeline_id = field::Empty, + shard_id = field::Empty, + ); + let _guard = span.enter(); + + Box::pin(self.inner.call(req).instrument(span.clone())) + } + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) } } /// gRPC interceptor that decodes tenant metadata and stores it as request extensions of type /// TenantTimelineId and ShardIndex. -/// -/// TODO: consider looking up the timeline handle here and storing it. #[derive(Clone)] struct TenantMetadataInterceptor; @@ -3400,25 +3918,28 @@ impl tonic::service::Interceptor for TenantMetadataInterceptor { .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; // Decode the shard ID. - let shard_index = req + let shard_id = req .metadata() .get("neon-shard-id") .ok_or_else(|| tonic::Status::invalid_argument("missing neon-shard-id"))? .to_str() .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; - let shard_index = ShardIndex::from_str(shard_index) + let shard_id = ShardIndex::from_str(shard_id) .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; // Stash them in the request. let extensions = req.extensions_mut(); extensions.insert(TenantTimelineId::new(tenant_id, timeline_id)); - extensions.insert(shard_index); + extensions.insert(shard_id); + + // Decorate the tracing span. + span_record!(%tenant_id, %timeline_id, %shard_id); Ok(req) } } -/// Authenticates gRPC page service requests. Must run after TenantMetadataInterceptor. +/// Authenticates gRPC page service requests. #[derive(Clone)] struct TenantAuthInterceptor { auth: Option>, @@ -3437,11 +3958,8 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { return Ok(req); }; - // Fetch the tenant ID that's been set by TenantMetadataInterceptor. 
- let ttid = req - .extensions() - .get::() - .expect("TenantMetadataInterceptor must run before TenantAuthInterceptor"); + // Fetch the tenant ID from the request extensions (set by TenantMetadataInterceptor). + let TenantTimelineId { tenant_id, .. } = *extract::(&req); // Fetch and decode the JWT token. let jwt = req @@ -3459,7 +3977,7 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { let claims = jwtdata.claims; // Check if the token is valid for this tenant. - check_permission(&claims, Some(ttid.tenant_id)) + check_permission(&claims, Some(tenant_id)) .map_err(|err| tonic::Status::permission_denied(err.to_string()))?; // TODO: consider stashing the claims in the request extensions, if needed. @@ -3468,6 +3986,21 @@ impl tonic::service::Interceptor for TenantAuthInterceptor { } } +/// Extracts the given type from the request extensions, or panics if it is missing. +fn extract(req: &tonic::Request) -> &T { + extract_from(req.extensions()) +} + +/// Extract the given type from the request extensions, or panics if it is missing. This variant +/// can extract both from a tonic::Request and http::Request. 
+fn extract_from(ext: &http::Extensions) -> &T { + let Some(value) = ext.get::() else { + let name = std::any::type_name::(); + panic!("extension {name} should be set by middleware"); + }; + value +} + #[derive(Debug, thiserror::Error)] pub(crate) enum GetActiveTimelineError { #[error(transparent)] @@ -3486,10 +4019,72 @@ impl From for QueryError { } } -impl From for QueryError { - fn from(e: crate::tenant::timeline::handle::HandleUpgradeError) -> Self { +impl From for tonic::Status { + fn from(err: GetActiveTimelineError) -> Self { + let message = err.to_string(); + let code = match err { + GetActiveTimelineError::Tenant(err) => tonic::Status::from(err).code(), + GetActiveTimelineError::Timeline(err) => tonic::Status::from(err).code(), + }; + tonic::Status::new(code, message) + } +} + +impl From for tonic::Status { + fn from(err: GetTimelineError) -> Self { + use tonic::Code; + let code = match &err { + GetTimelineError::NotFound { .. } => Code::NotFound, + GetTimelineError::NotActive { .. } => Code::Unavailable, + GetTimelineError::ShuttingDown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + +impl From for QueryError { + fn from(e: GetActiveTenantError) -> Self { match e { - crate::tenant::timeline::handle::HandleUpgradeError::ShutDown => QueryError::Shutdown, + GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( + ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), + ), + GetActiveTenantError::Cancelled + | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. 
}) => { + QueryError::Shutdown + } + e @ GetActiveTenantError::NotFound(_) => QueryError::NotFound(format!("{e}").into()), + e => QueryError::Other(anyhow::anyhow!(e)), + } + } +} + +impl From for tonic::Status { + fn from(err: GetActiveTenantError) -> Self { + use tonic::Code; + let code = match &err { + GetActiveTenantError::Broken(_) => Code::Internal, + GetActiveTenantError::Cancelled => Code::Unavailable, + GetActiveTenantError::NotFound(_) => Code::NotFound, + GetActiveTenantError::SwitchedTenant => Code::Unavailable, + GetActiveTenantError::WaitForActiveTimeout { .. } => Code::Unavailable, + GetActiveTenantError::WillNotBecomeActive(_) => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + +impl From for QueryError { + fn from(e: HandleUpgradeError) -> Self { + match e { + HandleUpgradeError::ShutDown => QueryError::Shutdown, + } + } +} + +impl From for tonic::Status { + fn from(err: HandleUpgradeError) -> Self { + match err { + HandleUpgradeError::ShutDown => tonic::Status::unavailable("timeline is shutting down"), } } } diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index c6f3929257..633d62210d 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -431,10 +431,10 @@ impl Timeline { GetVectoredError::InvalidLsn(e) => { Err(anyhow::anyhow!("invalid LSN: {e:?}").into()) } - // NB: this should never happen in practice because we limit MAX_GET_VECTORED_KEYS + // NB: this should never happen in practice because we limit batch size to be smaller than max_get_vectored_keys // TODO: we can prevent this error class by moving this check into the type system - GetVectoredError::Oversized(err) => { - Err(anyhow::anyhow!("batching oversized: {err:?}").into()) + GetVectoredError::Oversized(err, max) => { + Err(anyhow::anyhow!("batching oversized: {err} > {max}").into()) } }; @@ -471,8 +471,19 @@ impl Timeline { let rels = self.list_rels(spcnode, dbnode, version, 
ctx).await?; + if rels.is_empty() { + return Ok(0); + } + + // Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`. + let reldir_key = rel_dir_to_key(spcnode, dbnode); + let buf = version.get(self, reldir_key, ctx).await?; + let reldir = RelDirectory::des(&buf)?; + for rel in rels { - let n_blocks = self.get_rel_size(rel, version, ctx).await?; + let n_blocks = self + .get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx) + .await?; total_blocks += n_blocks as usize; } Ok(total_blocks) @@ -487,6 +498,19 @@ impl Timeline { tag: RelTag, version: Version<'_>, ctx: &RequestContext, + ) -> Result { + self.get_rel_size_in_reldir(tag, version, None, ctx).await + } + + /// Get size of a relation file. The relation must exist, otherwise an error is returned. + /// + /// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`. + pub(crate) async fn get_rel_size_in_reldir( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, ) -> Result { if tag.relnode == 0 { return Err(PageReconstructError::Other( @@ -499,7 +523,9 @@ impl Timeline { } if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM) - && !self.get_rel_exists(tag, version, ctx).await? + && !self + .get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx) + .await? { // FIXME: Postgres sometimes calls smgrcreate() to create // FSM, and smgrnblocks() on it immediately afterwards, @@ -521,11 +547,28 @@ impl Timeline { /// /// Only shard 0 has a full view of the relations. Other shards only know about relations that /// the shard stores pages for. + /// pub(crate) async fn get_rel_exists( &self, tag: RelTag, version: Version<'_>, ctx: &RequestContext, + ) -> Result { + self.get_rel_exists_in_reldir(tag, version, None, ctx).await + } + + /// Does the relation exist? With a cached deserialized `RelDirectory`. 
+ /// + /// There are some cases where the caller loops across all relations. In that specific case, + /// the caller should obtain the deserialized `RelDirectory` first and then call this function + /// to avoid duplicated work of deserliazation. This is a hack and should be removed by introducing + /// a new API (e.g., `get_rel_exists_batched`). + pub(crate) async fn get_rel_exists_in_reldir( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, ) -> Result { if tag.relnode == 0 { return Err(PageReconstructError::Other( @@ -568,6 +611,17 @@ impl Timeline { // fetch directory listing (old) let key = rel_dir_to_key(tag.spcnode, tag.dbnode); + + if let Some((cached_key, dir)) = deserialized_reldir_v1 { + if cached_key == key { + return Ok(dir.rels.contains(&(tag.relnode, tag.forknum))); + } else if cfg!(test) || cfg!(feature = "testing") { + panic!("cached reldir key mismatch: {cached_key} != {key}"); + } else { + warn!("cached reldir key mismatch: {cached_key} != {key}"); + } + // Fallback to reading the directory from the datadir. 
+ } let buf = version.get(self, key, ctx).await?; let dir = RelDirectory::des(&buf)?; @@ -665,7 +719,7 @@ impl Timeline { let batches = keyspace.partition( self.get_shard_identity(), - Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64, + self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, ); let io_concurrency = IoConcurrency::spawn_from_conf( @@ -905,7 +959,7 @@ impl Timeline { let batches = keyspace.partition( self.get_shard_identity(), - Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64, + self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, ); let io_concurrency = IoConcurrency::spawn_from_conf( diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 308ada3fa1..f9fdc143b4 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -7197,7 +7197,7 @@ mod tests { let end = desc .key_range .start - .add(Timeline::MAX_GET_VECTORED_KEYS.try_into().unwrap()); + .add(tenant.conf.max_get_vectored_keys.get() as u32); reads.push(KeySpace { ranges: vec![start..end], }); @@ -11260,11 +11260,11 @@ mod tests { let mut keyspaces_at_lsn: HashMap = HashMap::default(); let mut used_keys: HashSet = HashSet::default(); - while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize { + while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty"); let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE)); - while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize { + while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { if used_keys.contains(&selected_key) || selected_key >= start_key.add(KEY_DIMENSION_SIZE) { diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 23c40a7629..3522af2de0 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -817,8 +817,8 @@ pub(crate) enum GetVectoredError { #[error("timeline shutting down")] Cancelled, - 
#[error("requested too many keys: {0} > {}", Timeline::MAX_GET_VECTORED_KEYS)] - Oversized(u64), + #[error("requested too many keys: {0} > {1}")] + Oversized(u64, u64), #[error("requested at invalid LSN: {0}")] InvalidLsn(Lsn), @@ -950,6 +950,18 @@ pub(crate) enum WaitLsnError { Timeout(String), } +impl From for tonic::Status { + fn from(err: WaitLsnError) -> Self { + use tonic::Code; + let code = match &err { + WaitLsnError::Timeout(_) => Code::Internal, + WaitLsnError::BadState(_) => Code::Internal, + WaitLsnError::Shutdown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + // The impls below achieve cancellation mapping for errors. // Perhaps there's a way of achieving this with less cruft. @@ -1007,7 +1019,7 @@ impl From for PageReconstructError { match e { GetVectoredError::Cancelled => PageReconstructError::Cancelled, GetVectoredError::InvalidLsn(_) => PageReconstructError::Other(anyhow!("Invalid LSN")), - err @ GetVectoredError::Oversized(_) => PageReconstructError::Other(err.into()), + err @ GetVectoredError::Oversized(_, _) => PageReconstructError::Other(err.into()), GetVectoredError::MissingKey(err) => PageReconstructError::MissingKey(err), GetVectoredError::GetReadyAncestorError(err) => PageReconstructError::from(err), GetVectoredError::Other(err) => PageReconstructError::Other(err), @@ -1187,7 +1199,6 @@ impl Timeline { } } - pub(crate) const MAX_GET_VECTORED_KEYS: u64 = 32; pub(crate) const LAYERS_VISITED_WARN_THRESHOLD: u32 = 100; /// Look up multiple page versions at a given LSN @@ -1202,9 +1213,12 @@ impl Timeline { ) -> Result>, GetVectoredError> { let total_keyspace = query.total_keyspace(); - let key_count = total_keyspace.total_raw_size().try_into().unwrap(); - if key_count > Timeline::MAX_GET_VECTORED_KEYS { - return Err(GetVectoredError::Oversized(key_count)); + let key_count = total_keyspace.total_raw_size(); + if key_count > self.conf.max_get_vectored_keys.get() { + return Err(GetVectoredError::Oversized( + 
key_count as u64, + self.conf.max_get_vectored_keys.get() as u64, + )); } for range in &total_keyspace.ranges { @@ -2492,6 +2506,13 @@ impl Timeline { // Preparing basebackup doesn't make sense for shards other than shard zero. return; } + if !self.is_active() { + // May happen during initial timeline creation. + // Such timeline is not in the global timeline map yet, + // so basebackup cache will not be able to find it. + // TODO(diko): We can prepare such timelines in finish_creation(). + return; + } let res = self .basebackup_prepare_sender @@ -2831,21 +2852,6 @@ impl Timeline { ) } - /// Resolve the effective WAL receiver protocol to use for this tenant. - /// - /// Priority order is: - /// 1. Tenant config override - /// 2. Default value for tenant config override - /// 3. Pageserver config override - /// 4. Pageserver config default - pub fn resolve_wal_receiver_protocol(&self) -> PostgresClientProtocol { - let tenant_conf = self.tenant_conf.load().tenant_conf.clone(); - tenant_conf - .wal_receiver_protocol_override - .or(self.conf.default_tenant_conf.wal_receiver_protocol_override) - .unwrap_or(self.conf.wal_receiver_protocol) - } - pub(super) fn tenant_conf_updated(&self, new_conf: &AttachedTenantConf) { // NB: Most tenant conf options are read by background loops, so, // changes will automatically be picked up. 
@@ -3201,10 +3207,16 @@ impl Timeline { guard.is_none(), "multiple launches / re-launches of WAL receiver are not supported" ); + + let protocol = PostgresClientProtocol::Interpreted { + format: utils::postgres_client::InterpretedFormat::Protobuf, + compression: Some(utils::postgres_client::Compression::Zstd { level: 1 }), + }; + *guard = Some(WalReceiver::start( Arc::clone(self), WalReceiverConf { - protocol: self.resolve_wal_receiver_protocol(), + protocol, wal_connect_timeout, lagging_wal_timeout, max_lsn_wal_lag, @@ -5258,7 +5270,7 @@ impl Timeline { key = key.next(); // Maybe flush `key_rest_accum` - if key_request_accum.raw_size() >= Timeline::MAX_GET_VECTORED_KEYS + if key_request_accum.raw_size() >= self.conf.max_get_vectored_keys.get() as u64 || (last_key_in_range && key_request_accum.raw_size() > 0) { let query = diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index f19a4b3e9c..606ad09ef1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -106,6 +106,8 @@ pub async fn doit( ); } + tracing::info!("Import plan executed. Flushing remote changes and notifying storcon"); + timeline .remote_client .schedule_index_upload_for_file_changes()?; @@ -199,8 +201,8 @@ async fn prepare_import( .await; match res { Ok(_) => break, - Err(err) => { - info!(?err, "indefinitely waiting for pgdata to finish"); + Err(_err) => { + info!("indefinitely waiting for pgdata to finish"); if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled()) .await .is_ok() diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 9743aa3f26..e003bb6810 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -11,19 +11,7 @@ //! - => S3 as the source for the PGDATA instead of local filesystem //! //! 
TODOs before productionization: -//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding. -//! => produced image layers likely too small. //! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size. -//! - asserts / unwraps need to be replaced with errors -//! - don't trust remote objects will be small (=prevent OOMs in those cases) -//! - limit all in-memory buffers in size, or download to disk and read from there -//! - limit task concurrency -//! - generally play nice with other tenants in the system -//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits -//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc -//! - integrate with layer eviction system -//! - audit for Tenant::cancel nor Timeline::cancel responsivity -//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!) //! //! An incomplete set of TODOs from the Hackathon: //! 
- version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest) @@ -44,7 +32,7 @@ use pageserver_api::key::{ rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key, slru_segment_size_to_key, }; -use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range}; +use pageserver_api::keyspace::{ShardedRange, singleton_range}; use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus}; use pageserver_api::reltag::{RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; @@ -112,6 +100,7 @@ async fn run_v1( .unwrap(), import_job_concurrency: base.import_job_concurrency, import_job_checkpoint_threshold: base.import_job_checkpoint_threshold, + import_job_max_byte_range_size: base.import_job_max_byte_range_size, } } None => timeline.conf.timeline_import_config.clone(), @@ -142,7 +131,15 @@ async fn run_v1( pausable_failpoint!("import-timeline-pre-execute-pausable"); + let jobs_count = import_progress.as_ref().map(|p| p.jobs); let start_from_job_idx = import_progress.map(|progress| progress.completed); + + tracing::info!( + start_from_job_idx=?start_from_job_idx, + jobs=?jobs_count, + "Executing import plan" + ); + plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx) .await } @@ -167,6 +164,7 @@ impl Planner { /// This function is and must remain pure: given the same input, it will generate the same import plan. async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result { let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align(); + anyhow::ensure!(pgdata_lsn.is_valid()); let datadir = PgDataDir::new(&self.storage).await?; @@ -249,14 +247,22 @@ impl Planner { }); // Assigns parts of key space to later parallel jobs + // Note: The image layers produced here may have gaps, meaning, + // there is not an image for each key in the layer's key range. 
+ // The read path stops traversal at the first image layer, regardless + // of whether a base image has been found for a key or not. + // (Concept of sparse image layers doesn't exist.) + // This behavior is exactly right for the base image layers we're producing here. + // But, since no other place in the code currently produces image layers with gaps, + // it seems noteworthy. let mut last_end_key = Key::MIN; let mut current_chunk = Vec::new(); let mut current_chunk_size: usize = 0; let mut jobs = Vec::new(); for task in std::mem::take(&mut self.tasks).into_iter() { - if current_chunk_size + task.total_size() - > import_config.import_job_soft_size_limit.into() - { + let task_size = task.total_size(&self.shard); + let projected_chunk_size = current_chunk_size.saturating_add(task_size); + if projected_chunk_size > import_config.import_job_soft_size_limit.into() { let key_range = last_end_key..task.key_range().start; jobs.push(ChunkProcessingJob::new( key_range.clone(), @@ -266,7 +272,7 @@ impl Planner { last_end_key = key_range.end; current_chunk_size = 0; } - current_chunk_size += task.total_size(); + current_chunk_size = current_chunk_size.saturating_add(task_size); current_chunk.push(task); } jobs.push(ChunkProcessingJob::new( @@ -436,6 +442,7 @@ impl Plan { let mut last_completed_job_idx = start_after_job_idx.unwrap_or(0); let checkpoint_every: usize = import_config.import_job_checkpoint_threshold.into(); + let max_byte_range_size: usize = import_config.import_job_max_byte_range_size.into(); // Run import jobs concurrently up to the limit specified by the pageserver configuration. // Note that we process completed futures in the oreder of insertion. 
This will be the @@ -451,7 +458,7 @@ impl Plan { work.push_back(tokio::task::spawn(async move { let _permit = permit; - let res = job.run(job_timeline, &ctx).await; + let res = job.run(job_timeline, max_byte_range_size, &ctx).await; (job_idx, res) })); }, @@ -466,6 +473,8 @@ impl Plan { last_completed_job_idx = job_idx; if last_completed_job_idx % checkpoint_every == 0 { + tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); + let progress = ShardImportProgressV1 { jobs: jobs_in_plan, completed: last_completed_job_idx, @@ -604,18 +613,18 @@ impl PgDataDirDb { }; let path = datadir_path.join(rel_tag.to_segfile_name(segno)); - assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error + anyhow::ensure!(filesize % BLCKSZ as usize == 0); let nblocks = filesize / BLCKSZ as usize; - PgDataDirDbFile { + Ok(PgDataDirDbFile { path, filesize, rel_tag, segno, nblocks: Some(nblocks), // first non-cummulative sizes - } + }) }) - .collect(); + .collect::>()?; // Set cummulative sizes. Do all of that math here, so that later we could easier // parallelize over segments and know with which segments we need to write relsize @@ -650,18 +659,29 @@ impl PgDataDirDb { trait ImportTask { fn key_range(&self) -> Range; - fn total_size(&self) -> usize { - // TODO: revisit this - if is_contiguous_range(&self.key_range()) { - contiguous_range_len(&self.key_range()) as usize * 8192 + fn total_size(&self, shard_identity: &ShardIdentity) -> usize { + let range = ShardedRange::new(self.key_range(), shard_identity); + let page_count = range.page_count(); + if page_count == u32::MAX { + tracing::warn!( + "Import task has non contiguous key range: {}..{}", + self.key_range().start, + self.key_range().end + ); + + // Tasks should operate on contiguous ranges. It is unexpected for + // ranges to violate this assumption. Calling code handles this by mapping + // any task on a non contiguous range to its own image layer. 
+ usize::MAX } else { - u32::MAX as usize + page_count as usize * 8192 } } async fn doit( self, layer_writer: &mut ImageLayerWriter, + max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result; } @@ -698,6 +718,7 @@ impl ImportTask for ImportSingleKeyTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + _max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result { layer_writer.put_image(self.key, self.buf, ctx).await?; @@ -751,6 +772,7 @@ impl ImportTask for ImportRelBlocksTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result { debug!("Importing relation file"); @@ -777,7 +799,7 @@ impl ImportTask for ImportRelBlocksTask { assert_eq!(key.len(), 1); assert!(!acc.is_empty()); assert!(acc_end > acc_start); - if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ { + if acc_end == start && end - acc_start <= max_byte_range_size { acc.push(key.pop().unwrap()); Ok((acc, acc_start, end)) } else { @@ -792,8 +814,8 @@ impl ImportTask for ImportRelBlocksTask { .get_range(&self.path, range_start.into_u64(), range_end.into_u64()) .await?; let mut buf = Bytes::from(range_buf); - // TODO: batched writes for key in keys { + // The writer buffers writes internally let image = buf.split_to(8192); layer_writer.put_image(key, image, ctx).await?; nimages += 1; @@ -841,11 +863,15 @@ impl ImportTask for ImportSlruBlocksTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + _max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result { debug!("Importing SLRU segment file {}", self.path); let buf = self.storage.get(&self.path).await?; + // TODO(vlad): Does timestamp to LSN work for imported timelines? + // Probably not since we don't append the `xact_time` to it as in + // [`WalIngest::ingest_xact_record`]. 
let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?; let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?; let mut blknum = start_blk; @@ -884,12 +910,13 @@ impl ImportTask for AnyImportTask { async fn doit( self, layer_writer: &mut ImageLayerWriter, + max_byte_range_size: usize, ctx: &RequestContext, ) -> anyhow::Result { match self { - Self::SingleKey(t) => t.doit(layer_writer, ctx).await, - Self::RelBlocks(t) => t.doit(layer_writer, ctx).await, - Self::SlruBlocks(t) => t.doit(layer_writer, ctx).await, + Self::SingleKey(t) => t.doit(layer_writer, max_byte_range_size, ctx).await, + Self::RelBlocks(t) => t.doit(layer_writer, max_byte_range_size, ctx).await, + Self::SlruBlocks(t) => t.doit(layer_writer, max_byte_range_size, ctx).await, } } } @@ -930,7 +957,12 @@ impl ChunkProcessingJob { } } - async fn run(self, timeline: Arc, ctx: &RequestContext) -> anyhow::Result<()> { + async fn run( + self, + timeline: Arc, + max_byte_range_size: usize, + ctx: &RequestContext, + ) -> anyhow::Result<()> { let mut writer = ImageLayerWriter::new( timeline.conf, timeline.timeline_id, @@ -945,7 +977,7 @@ impl ChunkProcessingJob { let mut nimages = 0; for task in self.tasks { - nimages += task.doit(&mut writer, ctx).await?; + nimages += task.doit(&mut writer, max_byte_range_size, ctx).await?; } let resident_layer = if nimages > 0 { diff --git a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs index 34313748b7..bf2d9875c1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs @@ -6,7 +6,7 @@ use bytes::Bytes; use postgres_ffi::ControlFileData; use remote_storage::{ Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing, - ListingObject, RemotePath, + ListingObject, RemotePath, RemoteStorageConfig, }; use serde::de::DeserializeOwned; use 
tokio_util::sync::CancellationToken; @@ -22,11 +22,9 @@ pub async fn new( location: &index_part_format::Location, cancel: CancellationToken, ) -> Result { - // FIXME: we probably want some timeout, and we might be able to assume the max file - // size on S3 is 1GiB (postgres segment size). But the problem is that the individual - // downloaders don't know enough about concurrent downloads to make a guess on the - // expected bandwidth and resulting best timeout. - let timeout = std::time::Duration::from_secs(24 * 60 * 60); + // Downloads should be reasonably sized. We do ranged reads for relblock raw data + // and full reads for SLRU segments which are bounded by Postgres. + let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT; let location_storage = match location { #[cfg(feature = "testing")] index_part_format::Location::LocalFs { path } => { @@ -50,9 +48,12 @@ pub async fn new( .import_pgdata_aws_endpoint_url .clone() .map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env - concurrency_limit: 100.try_into().unwrap(), // TODO: think about this - max_keys_per_list_response: Some(1000), // TODO: think about this - upload_storage_class: None, // irrelevant + // This matches the default import job concurrency. This is managed + // separately from the usual S3 client, but the concern here is bandwidth + // usage. 
+ concurrency_limit: 128.try_into().unwrap(), + max_keys_per_list_response: Some(1000), + upload_storage_class: None, // irrelevant }, timeout, ) diff --git a/pageserver/src/tenant/timeline/walreceiver.rs b/pageserver/src/tenant/timeline/walreceiver.rs index 0f73eb839b..633c94a010 100644 --- a/pageserver/src/tenant/timeline/walreceiver.rs +++ b/pageserver/src/tenant/timeline/walreceiver.rs @@ -113,7 +113,7 @@ impl WalReceiver { } connection_manager_state.shutdown().await; *loop_status.write().unwrap() = None; - debug!("task exits"); + info!("task exits"); } .instrument(info_span!(parent: None, "wal_connection_manager", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), timeline_id = %timeline_id)) }); diff --git a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs index 3c3608d1bd..7e0b0e9b25 100644 --- a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs +++ b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs @@ -32,9 +32,7 @@ use utils::backoff::{ }; use utils::id::{NodeId, TenantTimelineId}; use utils::lsn::Lsn; -use utils::postgres_client::{ - ConnectionConfigArgs, PostgresClientProtocol, wal_stream_connection_config, -}; +use utils::postgres_client::{ConnectionConfigArgs, wal_stream_connection_config}; use super::walreceiver_connection::{WalConnectionStatus, WalReceiverError}; use super::{TaskEvent, TaskHandle, TaskStateUpdate, WalReceiverConf}; @@ -991,19 +989,12 @@ impl ConnectionManagerState { return None; // no connection string, ignore sk } - let (shard_number, shard_count, shard_stripe_size) = match self.conf.protocol { - PostgresClientProtocol::Vanilla => { - (None, None, None) - }, - PostgresClientProtocol::Interpreted { .. 
} => { - let shard_identity = self.timeline.get_shard_identity(); - ( - Some(shard_identity.number.0), - Some(shard_identity.count.0), - Some(shard_identity.stripe_size.0), - ) - } - }; + let shard_identity = self.timeline.get_shard_identity(); + let (shard_number, shard_count, shard_stripe_size) = ( + Some(shard_identity.number.0), + Some(shard_identity.count.0), + Some(shard_identity.stripe_size.0), + ); let connection_conf_args = ConnectionConfigArgs { protocol: self.conf.protocol, @@ -1120,8 +1111,8 @@ impl ReconnectReason { #[cfg(test)] mod tests { - use pageserver_api::config::defaults::DEFAULT_WAL_RECEIVER_PROTOCOL; use url::Host; + use utils::postgres_client::PostgresClientProtocol; use super::*; use crate::tenant::harness::{TIMELINE_ID, TenantHarness}; @@ -1552,6 +1543,11 @@ mod tests { .await .expect("Failed to create an empty timeline for dummy wal connection manager"); + let protocol = PostgresClientProtocol::Interpreted { + format: utils::postgres_client::InterpretedFormat::Protobuf, + compression: Some(utils::postgres_client::Compression::Zstd { level: 1 }), + }; + ConnectionManagerState { id: TenantTimelineId { tenant_id: harness.tenant_shard_id.tenant_id, @@ -1560,7 +1556,7 @@ mod tests { timeline, cancel: CancellationToken::new(), conf: WalReceiverConf { - protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, + protocol, wal_connect_timeout: Duration::from_secs(1), lagging_wal_timeout: Duration::from_secs(1), max_lsn_wal_lag: NonZeroU64::new(1024 * 1024).unwrap(), diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 52259f205b..343e04f5f0 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -15,7 +15,7 @@ use postgres_backend::is_expected_io_error; use postgres_connection::PgConnectionConfig; use postgres_ffi::WAL_SEGMENT_SIZE; use 
postgres_ffi::v14::xlog_utils::normalize_lsn; -use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder}; +use postgres_ffi::waldecoder::WalDecodeError; use postgres_protocol::message::backend::ReplicationMessage; use postgres_types::PgLsn; use tokio::sync::watch; @@ -31,7 +31,7 @@ use utils::lsn::Lsn; use utils::pageserver_feedback::PageserverFeedback; use utils::postgres_client::PostgresClientProtocol; use utils::sync::gate::GateError; -use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords}; +use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecords}; use wal_decoder::wire_format::FromWireFormat; use super::TaskStateUpdate; @@ -275,8 +275,6 @@ pub(super) async fn handle_walreceiver_connection( let copy_stream = replication_client.copy_both_simple(&query).await?; let mut physical_stream = pin!(ReplicationStream::new(copy_stream)); - let mut waldecoder = WalStreamDecoder::new(startpoint, timeline.pg_version); - let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx) .await .map_err(|e| match e.kind { @@ -284,19 +282,22 @@ pub(super) async fn handle_walreceiver_connection( _ => WalReceiverError::Other(e.into()), })?; - let shard = vec![*timeline.get_shard_identity()]; - - let interpreted_proto_config = match protocol { - PostgresClientProtocol::Vanilla => None, + let (format, compression) = match protocol { PostgresClientProtocol::Interpreted { format, compression, - } => Some((format, compression)), + } => (format, compression), + PostgresClientProtocol::Vanilla => { + return Err(WalReceiverError::Other(anyhow!( + "Vanilla WAL receiver protocol is no longer supported for ingest" + ))); + } }; let mut expected_wal_start = startpoint; while let Some(replication_message) = { select! 
{ + biased; _ = cancellation.cancelled() => { debug!("walreceiver interrupted"); None @@ -312,16 +313,6 @@ pub(super) async fn handle_walreceiver_connection( // Update the connection status before processing the message. If the message processing // fails (e.g. in walingest), we still want to know latests LSNs from the safekeeper. match &replication_message { - ReplicationMessage::XLogData(xlog_data) => { - connection_status.latest_connection_update = now; - connection_status.commit_lsn = Some(Lsn::from(xlog_data.wal_end())); - connection_status.streaming_lsn = Some(Lsn::from( - xlog_data.wal_start() + xlog_data.data().len() as u64, - )); - if !xlog_data.data().is_empty() { - connection_status.latest_wal_update = now; - } - } ReplicationMessage::PrimaryKeepAlive(keepalive) => { connection_status.latest_connection_update = now; connection_status.commit_lsn = Some(Lsn::from(keepalive.wal_end())); @@ -352,7 +343,6 @@ pub(super) async fn handle_walreceiver_connection( // were interpreted. let streaming_lsn = Lsn::from(raw.streaming_lsn()); - let (format, compression) = interpreted_proto_config.unwrap(); let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression) .await .with_context(|| { @@ -508,138 +498,6 @@ pub(super) async fn handle_walreceiver_connection( Some(streaming_lsn) } - ReplicationMessage::XLogData(xlog_data) => { - async fn commit( - modification: &mut DatadirModification<'_>, - uncommitted: &mut u64, - filtered: &mut u64, - ctx: &RequestContext, - ) -> anyhow::Result<()> { - let stats = modification.stats(); - modification.commit(ctx).await?; - WAL_INGEST - .records_committed - .inc_by(*uncommitted - *filtered); - WAL_INGEST.inc_values_committed(&stats); - *uncommitted = 0; - *filtered = 0; - Ok(()) - } - - // Pass the WAL data to the decoder, and see if we can decode - // more records as a result. 
- let data = xlog_data.data(); - let startlsn = Lsn::from(xlog_data.wal_start()); - let endlsn = startlsn + data.len() as u64; - - trace!("received XLogData between {startlsn} and {endlsn}"); - - WAL_INGEST.bytes_received.inc_by(data.len() as u64); - waldecoder.feed_bytes(data); - - { - let mut modification = timeline.begin_modification(startlsn); - let mut uncommitted_records = 0; - let mut filtered_records = 0; - - while let Some((next_record_lsn, recdata)) = waldecoder.poll_decode()? { - // It is important to deal with the aligned records as lsn in getPage@LSN is - // aligned and can be several bytes bigger. Without this alignment we are - // at risk of hitting a deadlock. - if !next_record_lsn.is_aligned() { - return Err(WalReceiverError::Other(anyhow!("LSN not aligned"))); - } - - // Deserialize and interpret WAL record - let interpreted = InterpretedWalRecord::from_bytes_filtered( - recdata, - &shard, - next_record_lsn, - modification.tline.pg_version, - )? - .remove(timeline.get_shard_identity()) - .unwrap(); - - if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes) - && uncommitted_records > 0 - { - // Special case: legacy PG database creations operate by reading pages from a 'template' database: - // these are the only kinds of WAL record that require reading data blocks while ingesting. Ensure - // all earlier writes of data blocks are visible by committing any modification in flight. - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - - // Ingest the records without immediately committing them. - timeline.metrics.wal_records_received.inc(); - let ingested = walingest - .ingest_record(interpreted, &mut modification, &ctx) - .await - .with_context(|| { - format!("could not ingest record at {next_record_lsn}") - }) - .inspect_err(|err| { - // TODO: we can't differentiate cancellation errors with - // anyhow::Error, so just ignore it if we're cancelled. 
- if !cancellation.is_cancelled() && !timeline.is_stopping() { - critical!("{err:?}") - } - })?; - if !ingested { - tracing::debug!("ingest: filtered out record @ LSN {next_record_lsn}"); - WAL_INGEST.records_filtered.inc(); - filtered_records += 1; - } - - // FIXME: this cannot be made pausable_failpoint without fixing the - // failpoint library; in tests, the added amount of debugging will cause us - // to timeout the tests. - fail_point!("walreceiver-after-ingest"); - - last_rec_lsn = next_record_lsn; - - // Commit every ingest_batch_size records. Even if we filtered out - // all records, we still need to call commit to advance the LSN. - uncommitted_records += 1; - if uncommitted_records >= ingest_batch_size - || modification.approx_pending_bytes() - > DatadirModification::MAX_PENDING_BYTES - { - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - } - - // Commit the remaining records. - if uncommitted_records > 0 { - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - } - - if !caught_up && endlsn >= end_of_wal { - info!("caught up at LSN {endlsn}"); - caught_up = true; - } - - Some(endlsn) - } - ReplicationMessage::PrimaryKeepAlive(keepalive) => { let wal_end = keepalive.wal_end(); let timestamp = keepalive.timestamp(); diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index a6a7021756..5b4ced7cf0 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -16,6 +16,7 @@ #if PG_MAJORVERSION_NUM >= 15 #include "access/xlogrecovery.h" #endif +#include "executor/instrument.h" #include "replication/logical.h" #include "replication/logicallauncher.h" #include "replication/slot.h" @@ -33,6 +34,7 @@ #include "file_cache.h" #include "neon.h" #include "neon_lwlsncache.h" +#include "neon_perf_counters.h" #include "control_plane_connector.h" #include "logical_replication_monitor.h" #include "unstable_extensions.h" @@ -46,6 +48,13 @@ void _PG_init(void); 
static int running_xacts_overflow_policy; +static bool monitor_query_exec_time = false; + +static ExecutorStart_hook_type prev_ExecutorStart = NULL; +static ExecutorEnd_hook_type prev_ExecutorEnd = NULL; + +static void neon_ExecutorStart(QueryDesc *queryDesc, int eflags); +static void neon_ExecutorEnd(QueryDesc *queryDesc); #if PG_MAJORVERSION_NUM >= 16 static shmem_startup_hook_type prev_shmem_startup_hook; @@ -470,6 +479,16 @@ _PG_init(void) 0, NULL, NULL, NULL); + DefineCustomBoolVariable( + "neon.monitor_query_exec_time", + "Collect infortmation about query execution time", + NULL, + &monitor_query_exec_time, + false, + PGC_USERSET, + 0, + NULL, NULL, NULL); + DefineCustomBoolVariable( "neon.allow_replica_misconfig", "Allow replica startup when some critical GUCs have smaller value than on primary node", @@ -508,6 +527,11 @@ _PG_init(void) EmitWarningsOnPlaceholders("neon"); ReportSearchPath(); + + prev_ExecutorStart = ExecutorStart_hook; + ExecutorStart_hook = neon_ExecutorStart; + prev_ExecutorEnd = ExecutorEnd_hook; + ExecutorEnd_hook = neon_ExecutorEnd; } PG_FUNCTION_INFO_V1(pg_cluster_size); @@ -581,3 +605,55 @@ neon_shmem_startup_hook(void) #endif } #endif + +/* + * ExecutorStart hook: start up tracking if needed + */ +static void +neon_ExecutorStart(QueryDesc *queryDesc, int eflags) +{ + if (prev_ExecutorStart) + prev_ExecutorStart(queryDesc, eflags); + else + standard_ExecutorStart(queryDesc, eflags); + + if (monitor_query_exec_time) + { + /* + * Set up to track total elapsed time in ExecutorRun. Make sure the + * space is allocated in the per-query context so it will go away at + * ExecutorEnd. 
+ */ + if (queryDesc->totaltime == NULL) + { + MemoryContext oldcxt; + + oldcxt = MemoryContextSwitchTo(queryDesc->estate->es_query_cxt); + queryDesc->totaltime = InstrAlloc(1, INSTRUMENT_TIMER, false); + MemoryContextSwitchTo(oldcxt); + } + } +} + +/* + * ExecutorEnd hook: store results if needed + */ +static void +neon_ExecutorEnd(QueryDesc *queryDesc) +{ + if (monitor_query_exec_time && queryDesc->totaltime) + { + /* + * Make sure stats accumulation is done. (Note: it's okay if several + * levels of hook all do this.) + */ + InstrEndLoop(queryDesc->totaltime); + + inc_query_time(queryDesc->totaltime->total*1000000); /* convert to usec */ + } + + if (prev_ExecutorEnd) + prev_ExecutorEnd(queryDesc); + else + standard_ExecutorEnd(queryDesc); +} diff --git a/pgxn/neon/neon_perf_counters.c b/pgxn/neon/neon_perf_counters.c index c77d99d636..d0a3d15108 100644 --- a/pgxn/neon/neon_perf_counters.c +++ b/pgxn/neon/neon_perf_counters.c @@ -71,6 +71,27 @@ inc_iohist(IOHistogram hist, uint64 latency_us) hist->wait_us_count++; } +static inline void +inc_qthist(QTHistogram hist, uint64 elapsed_us) +{ + int lo = 0; + int hi = NUM_QT_BUCKETS - 1; + + /* Find the right bucket with binary search */ + while (lo < hi) + { + int mid = (lo + hi) / 2; + + if (elapsed_us < qt_bucket_thresholds[mid]) + hi = mid; + else + lo = mid + 1; + } + hist->elapsed_us_bucket[lo]++; + hist->elapsed_us_sum += elapsed_us; + hist->elapsed_us_count++; +} + /* * Count a GetPage wait operation. */ @@ -98,6 +119,13 @@ inc_page_cache_write_wait(uint64 latency) inc_iohist(&MyNeonCounters->file_cache_write_hist, latency); } + +void +inc_query_time(uint64 elapsed) +{ + inc_qthist(&MyNeonCounters->query_time_hist, elapsed); +} + /* * Support functions for the views, neon_backend_perf_counters and * neon_perf_counters. 
@@ -112,11 +140,11 @@ typedef struct } metric_t; static int -histogram_to_metrics(IOHistogram histogram, - metric_t *metrics, - const char *count, - const char *sum, - const char *bucket) +io_histogram_to_metrics(IOHistogram histogram, + metric_t *metrics, + const char *count, + const char *sum, + const char *bucket) { int i = 0; uint64 bucket_accum = 0; @@ -145,10 +173,44 @@ histogram_to_metrics(IOHistogram histogram, return i; } +static int +qt_histogram_to_metrics(QTHistogram histogram, + metric_t *metrics, + const char *count, + const char *sum, + const char *bucket) +{ + int i = 0; + uint64 bucket_accum = 0; + + metrics[i].name = count; + metrics[i].is_bucket = false; + metrics[i].value = (double) histogram->elapsed_us_count; + i++; + metrics[i].name = sum; + metrics[i].is_bucket = false; + metrics[i].value = (double) histogram->elapsed_us_sum / 1000000.0; + i++; + for (int bucketno = 0; bucketno < NUM_QT_BUCKETS; bucketno++) + { + uint64 threshold = qt_bucket_thresholds[bucketno]; + + bucket_accum += histogram->elapsed_us_bucket[bucketno]; + + metrics[i].name = bucket; + metrics[i].is_bucket = true; + metrics[i].bucket_le = (threshold == UINT64_MAX) ? 
INFINITY : ((double) threshold) / 1000000.0; + metrics[i].value = (double) bucket_accum; + i++; + } + + return i; +} + static metric_t * neon_perf_counters_to_metrics(neon_per_backend_counters *counters) { -#define NUM_METRICS ((2 + NUM_IO_WAIT_BUCKETS) * 3 + 12) +#define NUM_METRICS ((2 + NUM_IO_WAIT_BUCKETS) * 3 + (2 + NUM_QT_BUCKETS) + 12) metric_t *metrics = palloc((NUM_METRICS + 1) * sizeof(metric_t)); int i = 0; @@ -159,10 +221,10 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters) i++; \ } while (false) - i += histogram_to_metrics(&counters->getpage_hist, &metrics[i], - "getpage_wait_seconds_count", - "getpage_wait_seconds_sum", - "getpage_wait_seconds_bucket"); + i += io_histogram_to_metrics(&counters->getpage_hist, &metrics[i], + "getpage_wait_seconds_count", + "getpage_wait_seconds_sum", + "getpage_wait_seconds_bucket"); APPEND_METRIC(getpage_prefetch_requests_total); APPEND_METRIC(getpage_sync_requests_total); @@ -178,14 +240,19 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters) APPEND_METRIC(file_cache_hits_total); - i += histogram_to_metrics(&counters->file_cache_read_hist, &metrics[i], - "file_cache_read_wait_seconds_count", - "file_cache_read_wait_seconds_sum", - "file_cache_read_wait_seconds_bucket"); - i += histogram_to_metrics(&counters->file_cache_write_hist, &metrics[i], - "file_cache_write_wait_seconds_count", - "file_cache_write_wait_seconds_sum", - "file_cache_write_wait_seconds_bucket"); + i += io_histogram_to_metrics(&counters->file_cache_read_hist, &metrics[i], + "file_cache_read_wait_seconds_count", + "file_cache_read_wait_seconds_sum", + "file_cache_read_wait_seconds_bucket"); + i += io_histogram_to_metrics(&counters->file_cache_write_hist, &metrics[i], + "file_cache_write_wait_seconds_count", + "file_cache_write_wait_seconds_sum", + "file_cache_write_wait_seconds_bucket"); + + i += qt_histogram_to_metrics(&counters->query_time_hist, &metrics[i], + "query_time_seconds_count", + 
"query_time_seconds_sum", + "query_time_seconds_bucket"); Assert(i == NUM_METRICS); @@ -257,7 +324,7 @@ neon_get_backend_perf_counters(PG_FUNCTION_ARGS) } static inline void -histogram_merge_into(IOHistogram into, IOHistogram from) +io_histogram_merge_into(IOHistogram into, IOHistogram from) { into->wait_us_count += from->wait_us_count; into->wait_us_sum += from->wait_us_sum; @@ -265,6 +332,15 @@ histogram_merge_into(IOHistogram into, IOHistogram from) into->wait_us_bucket[bucketno] += from->wait_us_bucket[bucketno]; } +static inline void +qt_histogram_merge_into(QTHistogram into, QTHistogram from) +{ + into->elapsed_us_count += from->elapsed_us_count; + into->elapsed_us_sum += from->elapsed_us_sum; + for (int bucketno = 0; bucketno < NUM_QT_BUCKETS; bucketno++) + into->elapsed_us_bucket[bucketno] += from->elapsed_us_bucket[bucketno]; +} + PG_FUNCTION_INFO_V1(neon_get_perf_counters); Datum neon_get_perf_counters(PG_FUNCTION_ARGS) @@ -283,7 +359,7 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) { neon_per_backend_counters *counters = &neon_per_backend_counters_shared[procno]; - histogram_merge_into(&totals.getpage_hist, &counters->getpage_hist); + io_histogram_merge_into(&totals.getpage_hist, &counters->getpage_hist); totals.getpage_prefetch_requests_total += counters->getpage_prefetch_requests_total; totals.getpage_sync_requests_total += counters->getpage_sync_requests_total; totals.getpage_prefetch_misses_total += counters->getpage_prefetch_misses_total; @@ -294,13 +370,13 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) totals.pageserver_open_requests += counters->pageserver_open_requests; totals.getpage_prefetches_buffered += counters->getpage_prefetches_buffered; totals.file_cache_hits_total += counters->file_cache_hits_total; - histogram_merge_into(&totals.file_cache_read_hist, &counters->file_cache_read_hist); - histogram_merge_into(&totals.file_cache_write_hist, &counters->file_cache_write_hist); - totals.compute_getpage_stuck_requests_total += 
counters->compute_getpage_stuck_requests_total; totals.compute_getpage_max_inflight_stuck_time_ms = Max( totals.compute_getpage_max_inflight_stuck_time_ms, counters->compute_getpage_max_inflight_stuck_time_ms); + io_histogram_merge_into(&totals.file_cache_read_hist, &counters->file_cache_read_hist); + io_histogram_merge_into(&totals.file_cache_write_hist, &counters->file_cache_write_hist); + qt_histogram_merge_into(&totals.query_time_hist, &counters->query_time_hist); } metrics = neon_perf_counters_to_metrics(&totals); diff --git a/pgxn/neon/neon_perf_counters.h b/pgxn/neon/neon_perf_counters.h index 10cf094d4a..4b611b0636 100644 --- a/pgxn/neon/neon_perf_counters.h +++ b/pgxn/neon/neon_perf_counters.h @@ -36,6 +36,28 @@ typedef struct IOHistogramData typedef IOHistogramData *IOHistogram; +static const uint64 qt_bucket_thresholds[] = { + 2, 3, 6, 10, /* 0 us - 10 us */ + 20, 30, 60, 100, /* 10 us - 100 us */ + 200, 300, 600, 1000, /* 100 us - 1 ms */ + 2000, 3000, 6000, 10000, /* 1 ms - 10 ms */ + 20000, 30000, 60000, 100000, /* 10 ms - 100 ms */ + 200000, 300000, 600000, 1000000, /* 100 ms - 1 s */ + 2000000, 3000000, 6000000, 10000000, /* 1 s - 10 s */ + 20000000, 30000000, 60000000, 100000000, /* 10 s - 100 s */ + UINT64_MAX, +}; +#define NUM_QT_BUCKETS (lengthof(qt_bucket_thresholds)) + +typedef struct QTHistogramData +{ + uint64 elapsed_us_count; + uint64 elapsed_us_sum; + uint64 elapsed_us_bucket[NUM_QT_BUCKETS]; +} QTHistogramData; + +typedef QTHistogramData *QTHistogram; + typedef struct { /* @@ -127,6 +149,11 @@ typedef struct /* LFC I/O time buckets */ IOHistogramData file_cache_read_hist; IOHistogramData file_cache_write_hist; + + /* + * Histogram of query execution time. 
+ */ + QTHistogramData query_time_hist; } neon_per_backend_counters; /* Pointer to the shared memory array of neon_per_backend_counters structs */ @@ -149,6 +176,7 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared; extern void inc_getpage_wait(uint64 latency); extern void inc_page_cache_read_wait(uint64 latency); extern void inc_page_cache_write_wait(uint64 latency); +extern void inc_query_time(uint64 elapsed); extern Size NeonPerfCountersShmemSize(void); extern void NeonPerfCountersShmemInit(void); diff --git a/pgxn/neon/neon_pgversioncompat.c b/pgxn/neon/neon_pgversioncompat.c index 7c404fb5a9..6f57b618da 100644 --- a/pgxn/neon/neon_pgversioncompat.c +++ b/pgxn/neon/neon_pgversioncompat.c @@ -5,6 +5,7 @@ #include "funcapi.h" #include "miscadmin.h" +#include "access/xlog.h" #include "utils/tuplestore.h" #include "neon_pgversioncompat.h" @@ -41,5 +42,12 @@ InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags) rsinfo->setDesc = stored_tupdesc; MemoryContextSwitchTo(old_context); } + +TimeLineID GetWALInsertionTimeLine(void) +{ + return ThisTimeLineID + 1; +} + + #endif diff --git a/pgxn/neon/neon_pgversioncompat.h b/pgxn/neon/neon_pgversioncompat.h index bf91a02b45..787bd552f8 100644 --- a/pgxn/neon/neon_pgversioncompat.h +++ b/pgxn/neon/neon_pgversioncompat.h @@ -162,6 +162,7 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode, #if PG_MAJORVERSION_NUM < 15 extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags); +extern TimeLineID GetWALInsertionTimeLine(void); #endif #endif /* NEON_PGVERSIONCOMPAT_H */ diff --git a/pgxn/neon/neon_walreader.c b/pgxn/neon/neon_walreader.c index d5e3a38dbb..0a1f6d9c72 100644 --- a/pgxn/neon/neon_walreader.c +++ b/pgxn/neon/neon_walreader.c @@ -69,6 +69,7 @@ struct NeonWALReader WALSegmentContext segcxt; WALOpenSegment seg; int wre_errno; + TimeLineID local_active_tlid; /* Explains failure to read, static for simplicity. 
*/ char err_msg[NEON_WALREADER_ERR_MSG_LEN]; @@ -106,7 +107,7 @@ struct NeonWALReader /* palloc and initialize NeonWALReader */ NeonWALReader * -NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix) +NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix, TimeLineID tlid) { NeonWALReader *reader; @@ -118,6 +119,7 @@ NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_ MemoryContextAllocZero(TopMemoryContext, sizeof(NeonWALReader)); reader->available_lsn = available_lsn; + reader->local_active_tlid = tlid; reader->seg.ws_file = -1; reader->seg.ws_segno = 0; reader->seg.ws_tli = 0; @@ -577,6 +579,17 @@ NeonWALReaderIsRemConnEstablished(NeonWALReader *state) return state->rem_state == RS_ESTABLISHED; } +/* + * Returns the local timeline ID this reader was allocated with, i.e. the + * tlid argument that NeonWALReaderAllocate() stored in the reader's + * local_active_tlid field. + */ +TimeLineID +NeonWALReaderLocalActiveTimeLineID(NeonWALReader *state) +{ + return state->local_active_tlid; +} + /* * Returns events user should wait on connection socket or 0 if remote * connection is not active. 
diff --git a/pgxn/neon/neon_walreader.h b/pgxn/neon/neon_walreader.h index 3e41825069..722bc10537 100644 --- a/pgxn/neon/neon_walreader.h +++ b/pgxn/neon/neon_walreader.h @@ -19,9 +19,10 @@ typedef enum NEON_WALREAD_ERROR, } NeonWALReadResult; -extern NeonWALReader *NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix); +extern NeonWALReader *NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix, TimeLineID tlid); extern void NeonWALReaderFree(NeonWALReader *state); extern void NeonWALReaderResetRemote(NeonWALReader *state); +extern TimeLineID NeonWALReaderLocalActiveTimeLineID(NeonWALReader *state); extern NeonWALReadResult NeonWALRead(NeonWALReader *state, char *buf, XLogRecPtr startptr, Size count, TimeLineID tli); extern pgsocket NeonWALReaderSocket(NeonWALReader *state); extern uint32 NeonWALReaderEvents(NeonWALReader *state); diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index f42103c7cd..91d39345e2 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -98,6 +98,7 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) wp = palloc0(sizeof(WalProposer)); wp->config = config; wp->api = api; + wp->localTimeLineID = config->pgTimeline; wp->state = WPS_COLLECTING_TERMS; wp->mconf.generation = INVALID_GENERATION; wp->mconf.members.len = 0; @@ -119,6 +120,10 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) { wp_log(FATAL, "failed to parse neon.safekeepers generation number: %m"); } + if (*endptr != ':') + { + wp_log(FATAL, "failed to parse neon.safekeepers: no colon after generation"); + } /* Skip past : to the first hostname. */ host = endptr + 1; } @@ -1380,7 +1385,7 @@ ProcessPropStartPos(WalProposer *wp) * we must bail out, as clog and other non rel data is inconsistent. 
*/ walprop_shared = wp->api.get_shmem_state(wp); - if (!wp->config->syncSafekeepers) + if (!wp->config->syncSafekeepers && !walprop_shared->replica_promote) { /* * Basebackup LSN always points to the beginning of the record (not diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index cca20e746b..08087e5a55 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -391,6 +391,7 @@ typedef struct WalproposerShmemState /* last feedback from each shard */ PageserverFeedback shard_ps_feedback[MAX_SHARDS]; int num_shards; + bool replica_promote; /* aggregated feedback with min LSNs across shards */ PageserverFeedback min_ps_feedback; @@ -806,6 +807,9 @@ typedef struct WalProposer /* Safekeepers walproposer is connecting to. */ Safekeeper safekeeper[MAX_SAFEKEEPERS]; + /* Current local TimeLineId in use */ + TimeLineID localTimeLineID; + /* WAL has been generated up to this point */ XLogRecPtr availableLsn; diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index d15bf91d24..3d6a92ad79 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -35,6 +35,7 @@ #include "storage/proc.h" #include "storage/ipc.h" #include "storage/lwlock.h" +#include "storage/pg_shmem.h" #include "storage/shmem.h" #include "storage/spin.h" #include "tcop/tcopprot.h" @@ -159,12 +160,19 @@ WalProposerMain(Datum main_arg) { WalProposer *wp; + if (*wal_acceptors_list == '\0') + { + wpg_log(WARNING, "Safekeepers list is empty"); + return; + } + init_walprop_config(false); walprop_pg_init_bgworker(); am_walproposer = true; walprop_pg_load_libpqwalreceiver(); wp = WalProposerCreate(&walprop_config, walprop_pg); + wp->localTimeLineID = GetWALInsertionTimeLine(); wp->last_reconnect_attempt = walprop_pg_get_current_timestamp(wp); walprop_pg_init_walsender(); @@ -272,6 +280,30 @@ split_safekeepers_list(char *safekeepers_list, char *safekeepers[]) return n_safekeepers; } +static char *split_off_safekeepers_generation(char *safekeepers_list, 
uint32 *generation) +{ + char *endptr; + + if (strncmp(safekeepers_list, "g#", 2) != 0) + { + return safekeepers_list; + } + else + { + errno = 0; + *generation = strtoul(safekeepers_list + 2, &endptr, 10); + if (errno != 0) + { + wp_log(FATAL, "failed to parse neon.safekeepers generation number: %m"); + } + if (*endptr != ':') + { + wp_log(FATAL, "failed to parse neon.safekeepers: no colon after generation"); + } + return endptr + 1; + } +} + /* * Accept two coma-separated strings with list of safekeeper host:port addresses. * Split them into arrays and return false if two sets do not match, ignoring the order. @@ -283,6 +315,16 @@ safekeepers_cmp(char *old, char *new) char *safekeepers_new[MAX_SAFEKEEPERS]; int len_old = 0; int len_new = 0; + uint32 gen_old = INVALID_GENERATION; + uint32 gen_new = INVALID_GENERATION; + + old = split_off_safekeepers_generation(old, &gen_old); + new = split_off_safekeepers_generation(new, &gen_new); + + if (gen_old != gen_new) + { + return false; + } len_old = split_safekeepers_list(old, safekeepers_old); len_new = split_safekeepers_list(new, safekeepers_new); @@ -316,6 +358,9 @@ assign_neon_safekeepers(const char *newval, void *extra) char *newval_copy; char *oldval; + if (newval && *newval != '\0' && UsedShmemSegAddr && walprop_shared && RecoveryInProgress()) + walprop_shared->replica_promote = true; + if (!am_walproposer) return; @@ -506,16 +551,15 @@ BackpressureThrottlingTime(void) /* * Register a background worker proposing WAL to wal acceptors. + * We start walproposer bgworker even for replicas in order to support possible replica promotion. + * When pg_promote() function is called, then walproposer bgworker registered with BgWorkerStart_RecoveryFinished + * is automatically launched when promotion is completed. */ static void walprop_register_bgworker(void) { BackgroundWorker bgw; - /* If no wal acceptors are specified, don't start the background worker. 
*/ - if (*wal_acceptors_list == '\0') - return; - memset(&bgw, 0, sizeof(bgw)); bgw.bgw_flags = BGWORKER_SHMEM_ACCESS; bgw.bgw_start_time = BgWorkerStart_RecoveryFinished; @@ -1292,9 +1336,7 @@ StartProposerReplication(WalProposer *wp, StartReplicationCmd *cmd) #if PG_VERSION_NUM < 150000 if (ThisTimeLineID == 0) - ereport(ERROR, - (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), - errmsg("IDENTIFY_SYSTEM has not been run before START_REPLICATION"))); + ThisTimeLineID = 1; #endif /* @@ -1508,7 +1550,7 @@ walprop_pg_wal_reader_allocate(Safekeeper *sk) snprintf(log_prefix, sizeof(log_prefix), WP_LOG_PREFIX "sk %s:%s nwr: ", sk->host, sk->port); Assert(!sk->xlogreader); - sk->xlogreader = NeonWALReaderAllocate(wal_segment_size, sk->wp->propTermStartLsn, log_prefix); + sk->xlogreader = NeonWALReaderAllocate(wal_segment_size, sk->wp->propTermStartLsn, log_prefix, sk->wp->localTimeLineID); if (sk->xlogreader == NULL) wpg_log(FATAL, "failed to allocate xlog reader"); } @@ -1522,7 +1564,7 @@ walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count, buf, startptr, count, - walprop_pg_get_timeline_id()); + sk->wp->localTimeLineID); if (res == NEON_WALREAD_SUCCESS) { diff --git a/pgxn/neon/walsender_hooks.c b/pgxn/neon/walsender_hooks.c index 81198d6c8d..534bf1c19b 100644 --- a/pgxn/neon/walsender_hooks.c +++ b/pgxn/neon/walsender_hooks.c @@ -111,7 +111,7 @@ NeonWALPageRead( readBuf, targetPagePtr, count, - walprop_pg_get_timeline_id()); + NeonWALReaderLocalActiveTimeLineID(wal_reader)); if (res == NEON_WALREAD_SUCCESS) { @@ -202,7 +202,7 @@ NeonOnDemandXLogReaderRoutines(XLogReaderRoutine *xlr) { elog(ERROR, "unable to start walsender when basebackupLsn is 0"); } - wal_reader = NeonWALReaderAllocate(wal_segment_size, basebackupLsn, "[walsender] "); + wal_reader = NeonWALReaderAllocate(wal_segment_size, basebackupLsn, "[walsender] ", 1); } xlr->page_read = NeonWALPageRead; xlr->segment_open = NeonWALReadSegmentOpen; diff --git 
a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 5e494dfdd6..8445368740 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,35 +17,23 @@ pub(super) async fn authenticate( config: &'static AuthenticationConfig, secret: AuthSecret, ) -> auth::Result { - let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::bad_auth_method("MD5")); + return Err(auth::AuthError::MalformedPassword("MD5 not supported")); } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret, ctx); let auth_outcome = tokio::time::timeout( config.scram_protocol_timeout, - async { - - flow.begin(scram).await.map_err(|error| { - warn!(?error, "error sending scram acknowledgement"); - error - })?.authenticate().await.map_err(|error| { - warn!(?error, "error processing scram messages"); - error - }) - } + AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(), ) .await - .map_err(|e| { - warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()); - auth::AuthError::user_timeout(e) - })??; + .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) + .map_err(auth::AuthError::user_timeout)? 
+ .inspect_err(|error| warn!(?error, "error processing scram messages"))?; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index dd48384c03..c388848926 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -2,7 +2,6 @@ use std::fmt; use async_trait::async_trait; use postgres_client::config::SslMode; -use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; @@ -16,8 +15,9 @@ use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; +use crate::pglb::connect_compute::ComputeConnectBackend; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; -use crate::proxy::connect_compute::ComputeConnectBackend; use crate::stream::PqStream; use crate::types::RoleName; use crate::{auth, compute, waiters}; @@ -154,11 +154,13 @@ async fn authenticate( // Give user a URL to spawn a new database. info!(parent: &span, "sending the auth URL to the user"); - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::NoticeResponse(&greeting)) - .await?; + client.write_message(BeMessage::AuthenticationOk); + client.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + client.write_message(BeMessage::NoticeResponse(&greeting)); + client.flush().await?; // Wait for console response via control plane (see `mgmt`). 
info!(parent: &span, "waiting for console's reply..."); @@ -188,7 +190,7 @@ async fn authenticate( } } - client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; + client.write_message(BeMessage::NoticeResponse("Connecting to database.")); // This config should be self-contained, because we won't // take username or dbname from client's startup message. diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 3316543022..1e5c076fb9 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext( debug!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); - // pause the timer while we communicate with the client - let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let ep = EndpointIdInt::from(&info.endpoint); - let auth_flow = AuthFlow::new(client) - .begin(auth::CleartextPassword { + let auth_flow = AuthFlow::new( + client, + auth::CleartextPassword { secret, endpoint: ep, pool: config.thread_pool.clone(), - }) - .await?; - drop(paused); - // cleartext auth is only allowed to the ws/http protocol. - // If we're here, we already received the password in the first message. - // Scram protocol will be executed on the proxy side. - let auth_outcome = auth_flow.authenticate().await?; + }, + ); + let auth_outcome = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // cleartext auth is only allowed to the ws/http protocol. + // If we're here, we already received the password in the first message. + // Scram protocol will be executed on the proxy side. + auth_flow.authenticate().await? 
+ }; let keys = match auth_outcome { sasl::Outcome::Success(key) => key, @@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication( // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let payload = AuthFlow::new(client) - .begin(auth::PasswordHack) - .await? + let payload = AuthFlow::new(client, auth::PasswordHack) .get_password() .await?; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 6e5c0a3954..f978f655c4 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -4,37 +4,31 @@ mod hacks; pub mod jwt; pub mod local; -use std::net::IpAddr; use std::sync::Arc; pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::ConsoleRedirectError; -use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; use postgres_client::config::AuthKeys; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; -use crate::auth::credentials::check_peer_addr_is_in_list; -use crate::auth::{ - self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange, -}; +use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, + self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, + RoleAccessControl, }; use crate::intern::EndpointIdInt; -use crate::metrics::Metrics; -use crate::protocol2::ConnectionInfoExtra; +use 
crate::pglb::connect_compute::ComputeConnectBackend; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; -use crate::proxy::connect_compute::ComputeConnectBackend; -use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter}; +use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{scram, stream}; @@ -200,78 +194,6 @@ impl TryFrom for ComputeUserInfo { } } -#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)] -pub struct MaskedIp(IpAddr); - -impl MaskedIp { - fn new(value: IpAddr, prefix: u8) -> Self { - match value { - IpAddr::V4(v4) => Self(IpAddr::V4( - Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()), - )), - IpAddr::V6(v6) => Self(IpAddr::V6( - Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()), - )), - } - } -} - -// This can't be just per IP because that would limit some PaaS that share IP addresses -pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>; - -impl AuthenticationConfig { - pub(crate) fn check_rate_limit( - &self, - ctx: &RequestContext, - secret: AuthSecret, - endpoint: &EndpointId, - is_cleartext: bool, - ) -> auth::Result { - // we have validated the endpoint exists, so let's intern it. - let endpoint_int = EndpointIdInt::from(endpoint.normalize()); - - // only count the full hash count if password hack or websocket flow. - // in other words, if proxy needs to run the hashing - let password_weight = if is_cleartext { - match &secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => 1, - AuthSecret::Scram(s) => s.iterations + 1, - } - } else { - // validating scram takes just 1 hmac_sha_256 operation. 
- 1 - }; - - let limit_not_exceeded = self.rate_limiter.check( - ( - endpoint_int, - MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), - ), - password_weight, - ); - - if !limit_not_exceeded { - warn!( - enabled = self.rate_limiter_enabled, - "rate limiting authentication" - ); - Metrics::get().proxy.requests_auth_rate_limits_total.inc(); - Metrics::get() - .proxy - .endpoints_auth_rate_limits - .get_metric() - .measure(endpoint); - - if self.rate_limiter_enabled { - return Err(auth::AuthError::too_many_connections()); - } - } - - Ok(secret) - } -} - /// True to its name, this function encapsulates our current auth trade-offs. /// Here, we choose the appropriate auth flow based on circumstances. /// @@ -284,7 +206,7 @@ async fn auth_quirks( allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, -) -> auth::Result<(ComputeCredentials, Option>)> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. 
@@ -300,55 +222,27 @@ async fn auth_quirks( debug!("fetching authentication info and allowlists"); - // check allowed list - let allowed_ips = if config.ip_allowlist_check_enabled { - let allowed_ips = api.get_allowed_ips(ctx, &info).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - allowed_ips - } else { - Cached::new_uncached(Arc::new(vec![])) - }; + let access_controls = api + .get_endpoint_access_control(ctx, &info.endpoint, &info.user) + .await?; - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?; - if config.is_vpc_acccess_proxy { - if access_blocks.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } + access_controls.check( + ctx, + config.ip_allowlist_check_enabled, + config.is_vpc_acccess_proxy, + )?; - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(AuthError::MissingEndpointName), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. 
- if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed( - incoming_vpc_endpoint_id, - )); - } - } else if access_blocks.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { + let endpoint = EndpointIdInt::from(&info.endpoint); + let rate_limit_config = None; + if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = api.get_role_secret(ctx, &info).await?; - let (cached_entry, secret) = cached_secret.take_value(); + let role_access = api + .get_role_access_control(ctx, &info.endpoint, &info.user) + .await?; - let secret = if let Some(secret) = secret { - config.check_rate_limit( - ctx, - secret, - &info.endpoint, - unauthenticated_password.is_some() || allow_cleartext, - )? + let secret = if let Some(secret) = role_access.secret { + secret } else { // If we don't have an authentication secret, we mock one to // prevent malicious probing (possible due to missing protocol steps). @@ -368,14 +262,8 @@ async fn auth_quirks( ) .await { - Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))), - Err(e) => { - if e.is_password_failed() { - // The password could have been changed, so we invalidate the cache. 
- cached_entry.invalidate(); - } - Err(e) - } + Ok(keys) => Ok(keys), + Err(e) => Err(e), } } @@ -402,7 +290,7 @@ async fn authenticate_with_secret( }; // we have authenticated the password - client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + client.write_message(BeMessage::AuthenticationOk); return Ok(ComputeCredentials { info, keys }); } @@ -438,7 +326,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, - ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option>)> { + ) -> auth::Result> { let res = match self { Self::ControlPlane(api, user_info) => { debug!( @@ -447,17 +335,35 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let (credentials, ip_allowlist) = auth_quirks( + let auth_res = auth_quirks( ctx, &*api, - user_info, + user_info.clone(), client, allow_cleartext, config, endpoint_rate_limiter, ) - .await?; - Ok((Backend::ControlPlane(api, credentials), ip_allowlist)) + .await; + match auth_res { + Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)), + Err(e) => { + // The password could have been changed, so we invalidate the cache. + // We should only invalidate the cache if the TTL might have expired. 
+ if e.is_password_failed() { + #[allow(irrefutable_let_patterns)] + if let ControlPlaneClient::ProxyV1(api) = &*api { + if let Some(ep) = &user_info.endpoint_id { + api.caches + .project_info + .maybe_invalidate_role_secret(ep, &user_info.user); + } + } + } + + Err(e) + } + } } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")); @@ -474,44 +380,30 @@ impl Backend<'_, ComputeUserInfo> { pub(crate) async fn get_role_secret( &self, ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(None)), - } - } - - pub(crate) async fn get_allowed_ips( - &self, - ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), - } - } - - pub(crate) async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_allowed_vpc_endpoint_ids(ctx, user_info).await + api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), + Self::Local(_) => Ok(RoleAccessControl { secret: None }), } } - pub(crate) async fn get_block_public_or_vpc_access( + pub(crate) async fn get_endpoint_access_control( &self, ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_block_public_or_vpc_access(ctx, user_info).await + api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())), + Self::Local(_) => Ok(EndpointAccessControl { + allowed_ips: Arc::new(vec![]), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }), } } } @@ -540,9 +432,7 
@@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { mod tests { #![allow(clippy::unimplemented, clippy::unwrap_used)] - use std::net::IpAddr; use std::sync::Arc; - use std::time::Duration; use bytes::BytesMut; use control_plane::AuthSecret; @@ -553,18 +443,16 @@ mod tests { use postgres_protocol::message::frontend; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use super::auth_quirks; use super::jwt::JwkCache; - use super::{AuthRateLimiter, auth_quirks}; - use crate::auth::backend::MaskedIp; use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::{ - self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, }; use crate::proxy::NeonOptions; - use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo}; + use crate::rate_limiter::EndpointRateLimiter; use crate::scram::ServerSecret; use crate::scram::threadpool::ThreadPool; use crate::stream::{PqStream, Stream}; @@ -577,46 +465,34 @@ mod tests { } impl control_plane::ControlPlaneApi for Auth { - async fn get_role_secret( + async fn get_role_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(RoleAccessControl { + secret: Some(self.secret.clone()), + }) } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone()))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - 
Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new( - self.vpc_endpoint_ids.clone(), - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAccessBlockerFlags::new_uncached( - self.access_blocker_flags.clone(), - )) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(EndpointAccessControl { + allowed_ips: Arc::new(self.ips.clone()), + allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()), + flags: self.access_blocker_flags, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - _endpoint: crate::types::EndpointId, + _endpoint: &crate::types::EndpointId, ) -> Result, control_plane::errors::GetEndpointJwksError> { unimplemented!() @@ -635,9 +511,6 @@ mod tests { jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), - rate_limiter_enabled: true, - rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, @@ -654,55 +527,10 @@ mod tests { } } - #[test] - fn masked_ip() { - let ip_a = IpAddr::V4([127, 0, 0, 1].into()); - let ip_b = IpAddr::V4([127, 0, 0, 2].into()); - let ip_c = IpAddr::V4([192, 168, 1, 101].into()); - let ip_d = IpAddr::V4([192, 168, 1, 102].into()); - let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap()); - let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap()); - - assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64)); - assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32)); - assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30)); - assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30)); - - assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128)); - assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64)); - } 
- - #[test] - fn test_default_auth_rate_limit_set() { - // these values used to exceed u32::MAX - assert_eq!( - RateBucketInfo::DEFAULT_AUTH_SET, - [ - RateBucketInfo { - interval: Duration::from_secs(1), - max_rpi: 1000 * 4096, - }, - RateBucketInfo { - interval: Duration::from_secs(60), - max_rpi: 600 * 4096 * 60, - }, - RateBucketInfo { - interval: Duration::from_secs(600), - max_rpi: 300 * 4096 * 600, - } - ] - ); - - for x in RateBucketInfo::DEFAULT_AUTH_SET { - let y = x.to_string().parse().unwrap(); - assert_eq!(x, y); - } - } - #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -784,7 +612,7 @@ mod tests { #[tokio::test] async fn auth_quirks_cleartext() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -838,7 +666,7 @@ mod tests { #[tokio::test] async fn auth_quirks_password_hack() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -887,7 +715,7 @@ mod tests { .await .unwrap(); - assert_eq!(creds.0.info.endpoint, "my-endpoint"); + assert_eq!(creds.info.endpoint, "my-endpoint"); handle.await.unwrap(); } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 526d0df7f2..b51da48862 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -5,7 +5,6 @@ use std::net::IpAddr; use std::str::FromStr; use itertools::Itertools; -use pq_proto::StartupMessageParams; use thiserror::Error; use tracing::{debug, 
warn}; @@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::NeonOptions; use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI}; use crate::types::{EndpointId, RoleName}; diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 0992c6d875..8fbc4577e9 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,10 +1,8 @@ //! Main authentication flow. -use std::io; use std::sync::Arc; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; -use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; use crate::intern::EndpointIdInt; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; use crate::scram::{self}; use crate::stream::{PqStream, Stream}; use crate::tls::TlsServerEndPoint; -/// Every authentication selector is supposed to implement this trait. -pub(crate) trait AuthMethod { - /// Any authentication selector should provide initial backend message - /// containing auth method name and parameters, e.g. md5 salt. - fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; -} - -/// Initial state of [`AuthFlow`]. -pub(crate) struct Begin; - /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. 
pub(crate) struct Scram<'a>( pub(crate) &'a scram::ServerSecret, pub(crate) &'a RequestContext, ); -impl AuthMethod for Scram<'_> { +impl Scram<'_> { #[inline(always)] fn first_message(&self, channel_binding: bool) -> BeMessage<'_> { if channel_binding { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) } else { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( scram::METHODS_WITHOUT_PLUS, )) } @@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> { /// . pub(crate) struct PasswordHack; -impl AuthMethod for PasswordHack { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// Use clear-text password auth called `password` in docs /// pub(crate) struct CleartextPassword { @@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword { pub(crate) secret: AuthSecret, } -impl AuthMethod for CleartextPassword { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub(crate) struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, - /// State might contain ancillary data (see [`Self::begin`]). + /// State might contain ancillary data. state: State, tls_server_end_point: TlsServerEndPoint, } /// Initial state of the stream wrapper. -impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { +impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> { /// Create a new wrapper for client authentication. 
- pub(crate) fn new(stream: &'a mut PqStream>) -> Self { + pub(crate) fn new(stream: &'a mut PqStream>, method: M) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); Self { stream, - state: Begin, + state: method, tls_server_end_point, } } - - /// Move to the next step by sending auth method's name & params to client. - pub(crate) async fn begin(self, method: M) -> io::Result> { - self.stream - .write_message(&method.first_message(self.tls_server_end_point.supported())) - .await?; - - Ok(AuthFlow { - stream: self.stream, - state: method, - tls_server_end_point: self.tls_server_end_point, - }) - } } impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn get_password(self) -> super::Result { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -133,6 +99,10 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -147,7 +117,7 @@ impl AuthFlow<'_, S, CleartextPassword> { .await?; if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; + self.stream.write_message(BeMessage::AuthenticationOk); } Ok(outcome) @@ -159,42 +129,36 @@ impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. 
pub(crate) async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; + let channel_binding = self.tls_server_end_point; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + // send sasl message. + { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // Initial client message contains the chosen auth method's name. - let msg = self.stream.read_password_message().await?; - let sasl = sasl::FirstMessage::parse(&msg) - .ok_or(AuthError::MalformedPassword("bad sasl message"))?; - - // Currently, the only supported SASL method is SCRAM. - if !scram::METHODS.contains(&sasl.method) { - return Err(super::AuthError::bad_auth_method(sasl.method)); + let sasl = self.state.first_message(channel_binding.supported()); + self.stream.write_message(sasl); + self.stream.flush().await?; } - match sasl.method { - SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), - SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus), - _ => {} - } + // complete sasl handshake. + sasl::authenticate(ctx, self.stream, |method| { + // Currently, the only supported SASL method is SCRAM. 
+ match method { + SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus); + } + method => return Err(sasl::Error::BadAuthMethod(method.into())), + } - // TODO: make this a metric instead - info!("client chooses {}", sasl.method); + // TODO: make this a metric instead + info!("client chooses {}", method); - let outcome = sasl::SaslStream::new(self.stream, sasl.message) - .authenticate(scram::Exchange::new( - secret, - rand::random, - self.tls_server_end_point, - )) - .await?; - - if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; - } - - Ok(outcome) + Ok(scram::Exchange::new(secret, rand::random, channel_binding)) + }) + .await + .map_err(AuthError::Sasl) } } diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index a566383390..ba10fce7b4 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -32,9 +32,7 @@ use crate::ext::TaskExt; use crate::http::health_server::AppMetrics; use crate::intern::RoleNameInt; use crate::metrics::{Metrics, ThreadPoolMetrics}; -use crate::rate_limiter::{ - BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, -}; +use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; use crate::serverless::{self, GlobalConnPoolOptions}; @@ -69,15 +67,6 @@ struct LocalProxyCliArgs { /// Can be given multiple times for different bucket sizes. 
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] user_rps_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Whether to retry the connection to the compute node #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)] connect_to_compute_retry: String, @@ -282,9 +271,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), - rate_limiter_enabled: false, - rate_limiter: BucketRateLimiter::new(vec![]), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 3e87538ae7..a4f517fead 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -4,8 +4,9 @@ //! This allows connecting to pods/services running in the same Kubernetes cluster from //! the outside. Similar to an ingress controller for HTTPS. 
+use std::net::SocketAddr; use std::path::Path; -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; use anyhow::{Context, anyhow, bail, ensure}; use clap::Arg; @@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::TlsConnector; +use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; use tracing::{Instrument, error, info}; use utils::project_git_version; @@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry; use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled}; +use crate::proxy::{ + ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled, +}; use crate::stream::{PqStream, Stream}; -use crate::tls::TlsServerEndPoint; project_git_version!(GIT_VERSION); @@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> { .parse()?; // Configure TLS - let (tls_config, tls_server_end_point): (Arc, TlsServerEndPoint) = match ( + let tls_config = match ( args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { @@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, proxy_listener, cancellation_token.clone(), )) @@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(compute_tls_config), - tls_server_end_point, proxy_listener_compute_tls, cancellation_token.clone(), )) @@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> { pub(super) fn parse_tls( key_path: &Path, cert_path: &Path, -) -> anyhow::Result<(Arc, TlsServerEndPoint)> { +) -> anyhow::Result> { let key = { let key_bytes = std::fs::read(key_path).context("TLS key file")?; @@ -187,10 +189,6 @@ pub(super) fn parse_tls( })? 
}; - // needed for channel bindings - let first_cert = cert_chain.first().context("missing certificate")?; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - let tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) @@ -199,14 +197,13 @@ pub(super) fn parse_tls( .with_single_cert(cert_chain, key)? .into(); - Ok((tls_config, tls_server_end_point)) + Ok(tls_config) } pub(super) async fn task_main( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -242,15 +239,7 @@ pub(super) async fn task_main( crate::metrics::Protocol::SniRouter, "sni", ); - handle_client( - ctx, - dest_suffix, - tls_config, - compute_tls_config, - tls_server_end_point, - socket, - ) - .await + handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -269,55 +258,26 @@ pub(super) async fn task_main( Ok(()) } -const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; - async fn ssl_handshake( ctx: &RequestContext, raw_stream: S, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, -) -> anyhow::Result> { - let mut stream = PqStream::new(Stream::from_raw(raw_stream)); - - let msg = stream.read_startup_packet().await?; - use pq_proto::FeStartupPacket::SslRequest; - +) -> anyhow::Result> { + let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?; match msg { - SslRequest { direct: false } => { - stream - .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) - .await?; + FeStartupPacket::SslRequest { direct: None } => { + let raw = stream.accept_tls().await?; - // Upgrade raw stream into a secure TLS-backed stream. 
- // NOTE: We've consumed `tls`; this fact will be used later. - - let (raw, read_buf) = stream.into_inner(); - // TODO: Normally, client doesn't send any data before - // server says TLS handshake is ok and read_buf is empty. - // However, you could imagine pipelining of postgres - // SSLRequest + TLS ClientHello in one hunk similar to - // pipelining in our node js driver. We should probably - // support that by chaining read_buf with the stream. - if !read_buf.is_empty() { - bail!("data is sent before server replied with EncryptionResponse"); - } - - Ok(Stream::Tls { - tls: Box::new( - raw.upgrade(tls_config, !ctx.has_private_peer_addr()) - .await?, - ), - tls_server_end_point, - }) + Ok(raw + .upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?) } unexpected => { info!( ?unexpected, "unexpected startup packet, rejecting connection" ); - stream - .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None) - .await? + Err(stream.throw_error(TlsRequired, None).await)? } } } @@ -327,15 +287,18 @@ async fn handle_client( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain` - let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; + let sni = tls_stream + .get_ref() + .1 + .server_name() + .ok_or(anyhow!("SNI missing"))?; let dest: Vec<&str> = sni .split_once('.') .context("invalid SNI")? 
diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 5f24940985..757c1e988b 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; use crate::auth::backend::jwt::JwkCache; -use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; +use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::cancellation::{CancellationHandler, handle_cancel_messages}; use crate::config::{ self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, @@ -29,9 +29,7 @@ use crate::config::{ use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; use crate::metrics::Metrics; -use crate::rate_limiter::{ - EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter, -}; +use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; use crate::redis::{elasticache, notifications}; @@ -154,15 +152,6 @@ struct ProxyCliArgs { /// Wake compute rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] wake_compute_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. 
- #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Redis rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)] redis_rps_limit: Vec, @@ -232,8 +221,7 @@ struct ProxyCliArgs { is_private_access_proxy: bool, /// Configure whether all incoming requests have a Proxy Protocol V2 packet. - // TODO(conradludgate): switch default to rejected or required once we've updated all deployments - #[clap(value_enum, long, default_value_t = ProxyProtocolV2::Supported)] + #[clap(value_enum, long, default_value_t = ProxyProtocolV2::Rejected)] proxy_protocol_v2: ProxyProtocolV2, /// Time the proxy waits for the webauth session to be confirmed by the control plane. @@ -410,22 +398,9 @@ pub async fn run() -> anyhow::Result<()> { Some(tx_cancel), )); - // bit of a hack - find the min rps and max rps supported and turn it into - // leaky bucket config instead - let max = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .max_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.max); - let rps = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .min_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.rps); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( - LeakyBucketConfig { rps, max }, + RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit) + .unwrap_or(EndpointRateLimiter::DEFAULT), 64, )); @@ -476,8 +451,7 @@ pub async fn run() -> anyhow::Result<()> { let key_path = args.tls_key.expect("already asserted it is set"); let cert_path = args.tls_cert.expect("already asserted it is set"); - let (tls_config, tls_server_end_point) = - super::pg_sni_router::parse_tls(&key_path, &cert_path)?; + let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?; let dest = Arc::new(dest); @@ -485,7 +459,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, listen, 
cancellation_token.clone(), )); @@ -494,7 +467,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(config.connect_to_compute.tls.clone()), - tls_server_end_point, listen_tls, cancellation_token.clone(), )); @@ -681,9 +653,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { jwks_cache: JwkCache::default(), thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, - rate_limiter_enabled: args.auth_rate_limit_enabled, - rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), - rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, is_auth_broker: args.is_auth_broker, diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 60678b034d..81c88e3ddd 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,30 +1,25 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet, hash_map}; use std::convert::Infallible; -use std::sync::Arc; use std::sync::atomic::AtomicU64; use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; +use clashmap::mapref::one::Ref; use rand::{Rng, thread_rng}; -use smol_str::SmolStr; use tokio::sync::Mutex; use tokio::time::Instant; use tracing::{debug, info}; -use super::{Cache, Cached}; -use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; -use crate::control_plane::{AccessBlockerFlags, AuthSecret}; +use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec); - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: 
AccountIdInt); - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); async fn increment_active_listeners(&self); @@ -42,6 +37,10 @@ impl Entry { value, } } + + pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> { + (valid_since < self.created_at).then_some(&self.value) + } } impl From for Entry { @@ -50,101 +49,32 @@ impl From for Entry { } } -#[derive(Default)] struct EndpointInfo { - secret: std::collections::HashMap>>, - allowed_ips: Option>>>, - block_public_or_vpc_access: Option>, - allowed_vpc_endpoint_ids: Option>>>, + role_controls: HashMap>, + controls: Option>, } impl EndpointInfo { - fn check_ignore_cache(ignore_cache_since: Option, created_at: Instant) -> bool { - match ignore_cache_since { - None => false, - Some(t) => t < created_at, - } - } pub(crate) fn get_role_secret( &self, role_name: RoleNameInt, valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Option, bool)> { - if let Some(secret) = self.secret.get(&role_name) { - if valid_since < secret.created_at { - return Some(( - secret.value.clone(), - Self::check_ignore_cache(ignore_cache_since, secret.created_at), - )); - } - } - None + ) -> Option { + let controls = self.role_controls.get(&role_name)?; + controls.get(valid_since).cloned() } - pub(crate) fn get_allowed_ips( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_ips) = &self.allowed_ips { - if valid_since < allowed_ips.created_at { - return Some(( - allowed_ips.value.clone(), - Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at), - )); - } - } - None - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - 
valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids { - if valid_since < allowed_vpc_endpoint_ids.created_at { - return Some(( - allowed_vpc_endpoint_ids.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - allowed_vpc_endpoint_ids.created_at, - ), - )); - } - } - None - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(AccessBlockerFlags, bool)> { - if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access { - if valid_since < block_public_or_vpc_access.created_at { - return Some(( - block_public_or_vpc_access.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - block_public_or_vpc_access.created_at, - ), - )); - } - } - None + pub(crate) fn get_controls(&self, valid_since: Instant) -> Option { + let controls = self.controls.as_ref()?; + controls.get(valid_since).cloned() } - pub(crate) fn invalidate_allowed_ips(&mut self) { - self.allowed_ips = None; - } - pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) { - self.allowed_vpc_endpoint_ids = None; - } - pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) { - self.block_public_or_vpc_access = None; + pub(crate) fn invalidate_endpoint(&mut self) { + self.controls = None; } + pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { - self.secret.remove(&role_name); + self.role_controls.remove(&role_name); } } @@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec) { - info!( - "invalidating allowed vpc endpoint ids for projects `{}`", - project_ids - .iter() - .map(|id| id.to_string()) - .collect::>() - .join(", ") - ); - for project_id in project_ids { - let endpoints = self - .project2ep - .get(&project_id) - 
.map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { + info!("invalidating endpoint access for project `{project_id}`"); + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) { - info!( - "invalidating allowed vpc endpoint ids for org `{}`", - account_id - ); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { + info!("invalidating endpoint access for org `{account_id}`"); let endpoints = self .account2ep .get(&account_id) @@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .unwrap_or_default(); for endpoint_id in endpoints { if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) { - info!( - "invalidating block public or vpc access for project `{}`", - project_id - ); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) { - info!("invalidating allowed ips for project `{}`", project_id); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for 
endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - } fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", @@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } } } + async fn decrement_active_listeners(&self) { let mut listeners_guard = self.active_listeners_lock.lock().await; if *listeners_guard == 0 { @@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl { } } + fn get_endpoint_cache( + &self, + endpoint_id: &EndpointId, + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + self.cache.get(&endpoint_id) + } + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; + ) -> Option { + let valid_since = self.get_cache_times(); let role_name = RoleNameInt::get(role_name)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let (value, ignore_cache) = - endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_role_secret(endpoint_id, role_name), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_ips( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))), - value, - }; - return 
Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - endpoint_id: &EndpointId, - ) -> Option> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_role_secret(role_name, valid_since) } - pub(crate) fn insert_role_secret( + pub(crate) fn get_endpoint_access( &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - role_name: RoleNameInt, - secret: Option, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. 
- return; - } - self.insert_project2endpoint(project_id, endpoint_id); - let mut entry = self.cache.entry(endpoint_id).or_default(); - if entry.secret.len() < self.config.max_roles { - entry.secret.insert(role_name, secret.into()); - } + endpoint_id: &EndpointId, + ) -> Option { + let valid_since = self.get_cache_times(); + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_controls(valid_since) } - pub(crate) fn insert_allowed_ips( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - allowed_ips: Arc>, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into()); - } - pub(crate) fn insert_allowed_vpc_endpoint_ids( + + pub(crate) fn insert_endpoint_access( &self, account_id: Option, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, - allowed_vpc_endpoint_ids: Arc>, + role_name: RoleNameInt, + controls: EndpointAccessControl, + role_controls: RoleAccessControl, ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } if let Some(account_id) = account_id { self.insert_account2endpoint(account_id, endpoint_id); } self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); - } - pub(crate) fn insert_block_public_or_vpc_access( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - access_blockers: AccessBlockerFlags, - ) { + if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. 
return; } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .block_public_or_vpc_access = Some(access_blockers.into()); + + let controls = Entry::from(controls); + let role_controls = Entry::from(role_controls); + + match self.cache.entry(endpoint_id) { + clashmap::Entry::Vacant(e) => { + e.insert(EndpointInfo { + role_controls: HashMap::from_iter([(role_name, role_controls)]), + controls: Some(controls), + }); + } + clashmap::Entry::Occupied(mut e) => { + let ep = e.get_mut(); + ep.controls = Some(controls); + if ep.role_controls.len() < self.config.max_roles { + ep.role_controls.insert(role_name, role_controls); + } + } + } } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { @@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl { .insert(project_id, HashSet::from([endpoint_id])); } } + fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) { endpoints.insert(endpoint_id); @@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl { .insert(account_id, HashSet::from([endpoint_id])); } } - fn get_cache_times(&self) -> (Instant, Option) { - let mut valid_since = Instant::now() - self.config.ttl; - // Only ignore cache if ttl is disabled. 
+ + fn ignore_ttl_since(&self) -> Option { let ttl_disabled_since_us = self .ttl_disabled_since_us .load(std::sync::atomic::Ordering::Relaxed); - let ignore_cache_since = if ttl_disabled_since_us == u64::MAX { - None - } else { - let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us); + + if ttl_disabled_since_us == u64::MAX { + return None; + } + + Some(self.start_time + Duration::from_micros(ttl_disabled_since_us)) + } + + fn get_cache_times(&self) -> Instant { + let mut valid_since = Instant::now() - self.config.ttl; + if let Some(ignore_ttl_since) = self.ignore_ttl_since() { // We are fine if entry is not older than ttl or was added before we are getting notifications. - valid_since = valid_since.min(ignore_cache_since); - Some(ignore_cache_since) + valid_since = valid_since.min(ignore_ttl_since); + } + valid_since + } + + pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) { + let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else { + return; }; - (valid_since, ignore_cache_since) + let Some(role_name) = RoleNameInt::get(role_name) else { + return; + }; + + let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else { + return; + }; + + let entry = endpoint_info.role_controls.entry(role_name); + let hash_map::Entry::Occupied(role_controls) = entry else { + return; + }; + + let created_at = role_controls.get().created_at; + let expire = match self.ignore_ttl_since() { + // if ignoring TTL, we should still try and roll the password if it's old + // and we the client gave an incorrect password. There could be some lag on the redis channel. + Some(_) => created_at + self.config.ttl < Instant::now(), + // edge case: redis is down, let's be generous and invalidate the cache immediately. 
+ None => true, + }; + + if expire { + role_controls.remove(); + } } pub async fn gc_worker(&self) -> anyhow::Result { @@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl { } } -/// Lookup info for project info cache. -/// This is used to invalidate cache entries. -pub(crate) struct CachedLookupInfo { - /// Search by this key. - endpoint_id: EndpointIdInt, - lookup_type: LookupType, -} - -impl CachedLookupInfo { - pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::RoleSecret(role_name), - } - } - pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::AllowedIps, - } - } - pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::AllowedVpcEndpointIds, - } - } - pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::BlockPublicOrVpcAccess, - } - } -} - -enum LookupType { - RoleSecret(RoleNameInt), - AllowedIps, - AllowedVpcEndpointIds, - BlockPublicOrVpcAccess, -} - -impl Cache for ProjectInfoCacheImpl { - type Key = SmolStr; - // Value is not really used here, but we need to specify it. 
- type Value = SmolStr; - - type LookupInfo = CachedLookupInfo; - - fn invalidate(&self, key: &Self::LookupInfo) { - match &key.lookup_type { - LookupType::RoleSecret(role_name) => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_role_secret(*role_name); - } - } - LookupType::AllowedIps => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - LookupType::AllowedVpcEndpointIds => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } - } - LookupType::BlockPublicOrVpcAccess => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - } -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; + use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; use crate::types::ProjectId; @@ -601,6 +371,8 @@ mod tests { }); let project_id: ProjectId = "project".into(); let endpoint_id: EndpointId = "endpoint".into(); + let account_id: Option = None; + let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); @@ -609,183 +381,73 @@ mod tests { "127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap(), ]); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user1).into(), - secret1.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret1.clone(), + }, ); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - 
(&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret2.clone(), + }, ); let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret1); + assert_eq!(cached.secret, secret1); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret2); + assert_eq!(cached.secret, secret2); // Shouldn't add more than 2 roles. let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user3).into(), - secret3.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret3.clone(), + }, ); + assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, allowed_ips); + let cached = cache.get_endpoint_access(&endpoint_id).unwrap(); + assert_eq!(cached.allowed_ips, allowed_ips); tokio::time::advance(Duration::from_secs(2)).await; let cached = cache.get_role_secret(&endpoint_id, &user1); assert!(cached.is_none()); let cached = cache.get_role_secret(&endpoint_id, &user2); assert!(cached.is_none()); - let cached = cache.get_allowed_ips(&endpoint_id); + let cached = cache.get_endpoint_access(&endpoint_id); assert!(cached.is_none()); } - - #[tokio::test] - async fn test_project_info_cache_invalidations() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: 
Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_secs(2)).await; - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - - tokio::time::advance(Duration::from_secs(2)).await; - // Nothing should be invalidated. - - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - // TTL is disabled, so it should be impossible to invalidate this value. - assert!(!cached.cached()); - assert_eq!(cached.value, secret1); - - cached.invalidate(); // Shouldn't do anything. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert_eq!(cached.value, secret1); - - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, secret2); - - // The only way to invalidate this value is to invalidate via the api. 
- cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } - - #[tokio::test] - async fn test_increment_active_listeners_invalidate_added_before() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_millis(100)).await; - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - - // Added before ttl was disabled + ttl should be still cached. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - // Added after ttl was disabled + ttl should not be cached. 
- cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl still should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - // Shouldn't be invalidated. - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a6e7bf85a0..d26641db46 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -5,7 +5,6 @@ use anyhow::{Context, anyhow}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::CancelToken; use postgres_client::tls::MakeTlsConnect; -use pq_proto::CancelKeyData; use redis::{Cmd, FromRedisValue, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -13,15 +12,15 @@ use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error, info, warn}; +use crate::auth::AuthError; use crate::auth::backend::ComputeUserInfo; -use crate::auth::{AuthError, check_peer_addr_is_in_list}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::ControlPlaneApi; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; -use crate::protocol2::ConnectionInfoExtra; +use crate::pqproto::CancelKeyData; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; @@ -272,13 +271,7 @@ pub(crate) enum CancelError { #[error("rate limit exceeded")] RateLimit, - #[error("IP is not allowed")] - IpNotAllowed, - - #[error("VPC endpoint id is not allowed to connect")] - 
VpcEndpointIdNotAllowed, - - #[error("Authentication backend error")] + #[error("Authentication error")] AuthError(#[from] AuthError), #[error("key not found")] @@ -297,10 +290,7 @@ impl ReportableError for CancelError { } CancelError::Postgres(_) => crate::error::ErrorKind::Compute, CancelError::RateLimit => crate::error::ErrorKind::RateLimit, - CancelError::IpNotAllowed - | CancelError::VpcEndpointIdNotAllowed - | CancelError::NotFound => crate::error::ErrorKind::User, - CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane, + CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User, CancelError::InternalError => crate::error::ErrorKind::Service, } } @@ -422,7 +412,13 @@ impl CancellationHandler { IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), }; - if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { + + let allowed = { + let rate_limit_config = None; + let limiter = self.limiter.lock_propagate_poison(); + limiter.check(subnet_key, rate_limit_config, 1) + }; + if !allowed { // log only the subnet part of the IP address to know which subnet is rate limited tracing::warn!("Rate limit exceeded. 
Skipping cancellation message, {subnet_key}"); Metrics::get() @@ -450,52 +446,13 @@ impl CancellationHandler { return Err(CancelError::NotFound); }; - if check_ip_allowed { - let ip_allowlist = auth_backend - .get_allowed_ips(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) { - // log it here since cancel_session could be spawned in a task - tracing::warn!( - "IP is not allowed to cancel the query: {key}, address: {}", - ctx.peer_addr() - ); - return Err(CancelError::IpNotAllowed); - } - } - - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = auth_backend - .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info) + let info = &cancel_closure.user_info; + let access_controls = auth_backend + .get_endpoint_access_control(&ctx, &info.endpoint, &info.user) .await .map_err(|e| CancelError::AuthError(e.into()))?; - if check_vpc_allowed { - if access_blocks.vpc_access_blocked { - return Err(CancelError::AuthError(AuthError::NetworkNotAllowed)); - } - - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - let allowed_vpc_endpoint_ids = auth_backend - .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. 
- if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(CancelError::VpcEndpointIdNotAllowed); - } - } else if access_blocks.public_access_blocked { - return Err(CancelError::VpcEndpointIdNotAllowed); - } + access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?; Metrics::get() .proxy diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 26254beecf..2899f25129 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -8,7 +8,6 @@ use itertools::Itertools; use postgres_client::tls::MakeTlsConnect; use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use pq_proto::StartupMessageParams; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; diff --git a/proxy/src/config.rs b/proxy/src/config.rs index ad398c122c..248584a19a 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -7,7 +7,6 @@ use arc_swap::ArcSwapOption; use clap::ValueEnum; use remote_storage::RemoteStorageConfig; -use crate::auth::backend::AuthRateLimiter; use crate::auth::backend::jwt::JwkCache; use crate::control_plane::locks::ApiLocks; use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}; @@ -40,8 +39,6 @@ pub struct ComputeConfig { pub enum ProxyProtocolV2 { /// Connection will error if PROXY protocol v2 header is missing Required, - /// Connection will parse PROXY protocol v2 header, but accept the connection if it's missing. 
- Supported, /// Connection will error if PROXY protocol v2 header is provided Rejected, } @@ -65,9 +62,6 @@ pub struct HttpConfig { pub struct AuthenticationConfig { pub thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, - pub rate_limiter_enabled: bool, - pub rate_limiter: AuthRateLimiter, - pub rate_limit_ip_subnet: u8, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, pub jwks_cache: JwkCache, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index e3184e20d1..f2484b54b8 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; @@ -11,10 +11,10 @@ use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute}; +use crate::pglb::handshake::{HandshakeData, handshake}; +use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::handshake::{HandshakeData, handshake}; -use crate::proxy::passthrough::ProxyPassthrough; use crate::proxy::{ ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled, }; @@ -54,30 +54,24 @@ pub async fn task_main( debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); connections.spawn(async move { - let (socket, peer_addr) = match read_proxy_protocol(socket).await { - Err(e) => { - error!("per-client task finished with an error: {e:#}"); - return; + let (socket, conn_info) = match 
config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(socket).await { + Err(e) => { + error!("per-client task finished with an error: {e:#}"); + return; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_socket, ConnectHeader::Local)) => { + debug!("healthcheck received"); + return; + } + Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), + } } - // our load balancers will not send any more data. let's just exit immediately - Ok((_socket, ConnectHeader::Local)) => { - debug!("healthcheck received"); - return; - } - Ok((_socket, ConnectHeader::Missing)) - if config.proxy_protocol_v2 == ProxyProtocolV2::Required => - { - error!("missing required proxy protocol header"); - return; - } - Ok((_socket, ConnectHeader::Proxy(_))) - if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => - { - error!("proxy protocol header not supported"); - return; - } - Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - Ok((socket, ConnectHeader::Missing)) => ( + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. 
+ ProxyProtocolV2::Rejected => ( socket, ConnectionInfo { addr: peer_addr, @@ -86,7 +80,7 @@ pub async fn task_main( ), }; - match socket.inner.set_nodelay(true) { + match socket.set_nodelay(true) { Ok(()) => {} Err(e) => { error!( @@ -98,7 +92,7 @@ pub async fn task_main( let ctx = RequestContext::new( session_id, - peer_addr, + conn_info, crate::metrics::Protocol::Tcp, &config.region, ); @@ -159,7 +153,7 @@ pub async fn task_main( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, ctx: &RequestContext, @@ -221,12 +215,10 @@ pub(crate) async fn handle_client( .await { Ok(auth_result) => auth_result, - Err(e) => { - return stream.throw_error(e, Some(ctx)).await?; - } + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; - let mut node = connect_to_compute( + let node = connect_to_compute( ctx, &TcpMechanism { user_info, @@ -238,7 +230,7 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) + .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; let cancellation_handler_clone = Arc::clone(&cancellation_handler); @@ -246,14 +238,8 @@ pub(crate) async fn handle_client( session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. 
- let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; Ok(Some(ProxyPassthrough { client: stream, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 79aaf22990..24268997ba 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -4,7 +4,6 @@ use std::net::IpAddr; use chrono::Utc; use once_cell::sync::OnceCell; -use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio::sync::mpsc; use tracing::field::display; @@ -20,6 +19,7 @@ use crate::metrics::{ ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting, }; +use crate::pqproto::StartupMessageParams; use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra}; use crate::types::{DbName, EndpointId, RoleName}; @@ -370,6 +370,18 @@ impl RequestContext { } } + pub(crate) fn latency_timer_pause_at( + &self, + at: tokio::time::Instant, + waiting_for: Waiting, + ) -> LatencyTimerPause<'_> { + LatencyTimerPause { + ctx: self, + start: at, + waiting_for, + } + } + pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated { self.0 .try_lock() diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index f6250bcd17..c9d3905abd 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr; use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr}; use parquet::file::writer::SerializedFileWriter; use parquet::record::RecordWriter; -use pq_proto::StartupMessageParams; use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel}; use serde::ser::SerializeMap; use tokio::sync::mpsc; @@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner}; use crate::config::remote_storage_from_toml; use crate::context::LOG_CHAN_DISCONNECT; 
use crate::ext::TaskExt; +use crate::pqproto::StartupMessageParams; #[derive(clap::Args, Clone, Debug)] pub struct ParquetUploadArgs { diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 2765aaa462..da548d6b2c 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -7,7 +7,9 @@ use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; +use bytes::Bytes; use futures::TryFutureExt; +use hyper::StatusCode; use postgres_client::config::SslMode; use tokio::time::Instant; use tracing::{Instrument, debug, info, info_span, warn}; @@ -15,7 +17,6 @@ use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; -use crate::cache::Cached; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ @@ -24,12 +25,12 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, }; -use crate::metrics::{CacheOutcome, Metrics}; +use crate::metrics::Metrics; use crate::rate_limiter::WakeComputeRateLimiter; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, http, scram}; pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -66,66 +67,41 @@ impl NeonControlPlaneClient { 
self.endpoint.url().as_str() } - async fn do_get_auth_info( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - if !self - .caches - .endpoints_cache - .is_valid(ctx, &user_info.endpoint.normalize()) - { - // TODO: refactor this because it's weird - // this is a failure to authenticate but we return Ok. - info!("endpoint is not valid, skipping the request"); - return Ok(AuthInfo::default()); - } - self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx)) - .await - } - async fn do_get_auth_req( &self, - user_info: &ComputeUserInfo, - session_id: &uuid::Uuid, - ctx: Option<&RequestContext>, + ctx: &RequestContext, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { - let request_id: String = session_id.to_string(); - let application_name = if let Some(ctx) = ctx { - ctx.console_application_name() - } else { - "auth_cancellation".to_string() - }; - async { - let request = self - .endpoint - .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, &request_id) - .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", session_id)]) - .query(&[ - ("application_name", application_name.as_str()), - ("endpointish", user_info.endpoint.as_str()), - ("role", user_info.user.as_str()), - ]) - .build()?; + let response = { + let request = self + .endpoint + .get_path("get_endpoint_access_control") + .header(X_REQUEST_ID, ctx.session_id().to_string()) + .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) + .query(&[("session_id", ctx.session_id())]) + .query(&[ + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), + ]) + .build()?; - debug!(url = request.url().as_str(), "sending http request"); - let start = Instant::now(); - let response = match ctx { - Some(ctx) => { - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); - let rsp = self.endpoint.execute(request).await; - drop(pause); - rsp? 
- } - None => self.endpoint.execute(request).await?, + debug!(url = request.url().as_str(), "sending http request"); + let start = Instant::now(); + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + let response = self.endpoint.execute(request).await?; + + info!(duration = ?start.elapsed(), "received http response"); + + response }; - info!(duration = ?start.elapsed(), "received http response"); - let body = match parse_body::(response).await { + let body = match parse_body::( + response.status(), + response.bytes().await?, + ) { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. // TODO(anna): retry @@ -180,7 +156,7 @@ impl NeonControlPlaneClient { async fn do_get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { if !self .caches @@ -216,7 +192,10 @@ impl NeonControlPlaneClient { drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::( + response.status(), + response.bytes().await.map_err(ControlPlaneError::from)?, + )?; let rules = body .jwks @@ -268,7 +247,7 @@ impl NeonControlPlaneClient { let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::(response.status(), response.bytes().await?)?; // Unfortunately, ownership won't let us use `Option::ok_or` here. 
let (host, port) = match parse_host_port(&body.address) { @@ -313,225 +292,104 @@ impl NeonControlPlaneClient { impl super::ControlPlaneApi for NeonControlPlaneClient { #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - let user = &user_info.user; - if let Some(role_secret) = self + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(secret) = self .caches .project_info - .get_role_secret(normalized_ep, user) + .get_role_secret(normalized_ep, role) { - return Ok(role_secret); + return Ok(secret); } - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_ips), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_vpc_endpoint_ids), - ); - 
self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - auth_info.access_blocker_flags, + role.into(), + control, + role_control.clone(), ); ctx.set_project_id(project_id); } - // When we just got a secret, we don't need to invalidate it. - Ok(Cached::new_uncached(auth_info.secret)) + + Ok(role_control) } - async fn get_allowed_ips( + #[tracing::instrument(skip_all)] + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { - Metrics::get() - .proxy - .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats? - .inc(CacheOutcome::Hit); - return Ok(allowed_ips); + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) { + return Ok(control); } - Metrics::get() - .proxy - .allowed_ips_cache_misses - .inc(CacheOutcome::Miss); - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, 
+ }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, + role.into(), + control.clone(), + role_control, ); ctx.set_project_id(project_id); } - Ok(Cached::new_uncached(allowed_ips)) - } - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_vpc_endpoint_ids) = self - .caches - .project_info - .get_allowed_vpc_endpoint_ids(normalized_ep) - { - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Hit); - return Ok(allowed_vpc_endpoint_ids); - } - - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); 
- self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(allowed_vpc_endpoint_ids)) - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(access_blocker_flags) = self - .caches - .project_info - .get_block_public_or_vpc_access(normalized_ep) - { - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Hit); - return Ok(access_blocker_flags); - } - - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags.clone(), - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(access_blocker_flags)) + Ok(control) } #[tracing::instrument(skip_all)] async fn 
get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(ctx, endpoint).await } @@ -640,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } /// Parse http response body, taking status code into account. -async fn parse_body serde::Deserialize<'a>>( - response: http::Response, +fn parse_body serde::Deserialize<'a>>( + status: StatusCode, + body: Bytes, ) -> Result { - let status = response.status(); if status.is_success() { // We shouldn't log raw body because it may contain secrets. info!("request succeeded, processing the body"); - return Ok(response.json().await?); + return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?); } - let s = response.bytes().await?; + // Log plaintext to be able to detect, whether there are some cases not covered by the error struct. - info!("response_error plaintext: {:?}", s); + info!("response_error plaintext: {:?}", body); // Don't throw an error here because it's not as important // as the fact that the request itself has failed. 
- let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| { + let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| { warn!("failed to parse error body: {e}"); - ControlPlaneErrorMessage { + Box::new(ControlPlaneErrorMessage { error: "reason unclear (malformed error message)".into(), http_status_code: status, status: None, - } + }) }); body.http_status_code = status; warn!("console responded with an error ({status}): {body:?}"); - Err(ControlPlaneError::Message(Box::new(body))) + Err(ControlPlaneError::Message(body)) } fn parse_host_port(input: &str) -> Option<(&str, u16)> { diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index d3ab4abd0b..ece7153fce 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; use crate::context::RequestContext; -use crate::control_plane::client::{ - CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret, -}; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; use crate::control_plane::messages::MetricsAuxInfo; -use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{ + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, +}; use crate::intern::RoleNameInt; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; @@ -66,7 +66,8 @@ impl MockControlPlane { async fn do_get_auth_info( &self, - user_info: &ComputeUserInfo, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { let (secret, allowed_ips) = async { // Perhaps we could persist this connection, but then we'd have to @@ -80,7 +81,7 @@ impl MockControlPlane { let secret = if let Some(entry) = get_execute_postgres_query( 
&client, "select rolpassword from pg_catalog.pg_authid where rolname = $1", - &[&&*user_info.user], + &[&role.as_str()], "rolpassword", ) .await? @@ -89,7 +90,7 @@ impl MockControlPlane { let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) } else { - warn!("user '{}' does not exist", user_info.user); + warn!("user '{role}' does not exist"); None }; @@ -97,7 +98,7 @@ impl MockControlPlane { match get_execute_postgres_query( &client, "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", - &[&user_info.endpoint.as_str()], + &[&endpoint.as_str()], "allowed_ips", ) .await? @@ -133,7 +134,7 @@ impl MockControlPlane { async fn do_get_endpoint_jwks( &self, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { let (client, connection) = tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; @@ -222,53 +223,36 @@ async fn get_execute_postgres_query( } impl super::ControlPlaneApi for MockControlPlane { - #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached( - self.do_get_auth_info(user_info).await?.secret, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(EndpointAccessControl { + allowed_ips: Arc::new(info.allowed_ips), + allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids), + flags: info.access_blocker_flags, + }) } - async fn get_allowed_ips( + async fn get_role_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info).await?.allowed_ips, - ))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - 
Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info) - .await? - .allowed_vpc_endpoint_ids, - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached( - self.do_get_auth_info(user_info).await?.access_blocker_flags, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(RoleAccessControl { + secret: info.secret, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(endpoint).await } diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index 746595de38..9b9d1e25ea 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache; use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; -use crate::control_plane::{ - CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, - CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors, -}; +use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors}; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; use crate::types::EndpointId; +use super::{EndpointAccessControl, RoleAccessControl}; + #[non_exhaustive] #[derive(Clone)] pub enum ControlPlaneClient { @@ -40,68 +39,42 @@ pub enum ControlPlaneClient { } impl ControlPlaneApi for ControlPlaneClient { - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: 
&crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await, + Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await, + Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(_) => { + Self::Test(_api) => { unreachable!("this function should never be called in the test backend") } } } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await, + Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await, + Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(api) => api.get_allowed_ips(), - } - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => api.get_allowed_vpc_endpoint_ids(), - } - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => 
api.get_block_public_or_vpc_access(), + Self::Test(api) => api.get_access_control(), } } async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError> { match self { Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await, @@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient { pub(crate) trait TestControlPlaneClient: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips(&self) -> Result; - - fn get_allowed_vpc_endpoint_ids( - &self, - ) -> Result; - - fn get_block_public_or_vpc_access( - &self, - ) -> Result; + fn get_access_control(&self) -> Result; fn dyn_clone(&self) -> Box; } @@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient { ctx: &RequestContext, endpoint: EndpointId, ) -> Result, FetchAuthRulesError> { - self.get_endpoint_jwks(ctx, endpoint) + self.get_endpoint_jwks(ctx, &endpoint) .await .map_err(FetchAuthRulesError::GetEndpointJwks) } diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs index 850d061333..77312c89c5 100644 --- a/proxy/src/control_plane/errors.rs +++ b/proxy/src/control_plane/errors.rs @@ -99,6 +99,10 @@ pub(crate) enum GetAuthInfoError { #[error(transparent)] ApiError(ControlPlaneError), + + /// Proxy does not know about the endpoint in advance + #[error("endpoint not found in endpoint cache")] + UnknownEndpoint, } // This allows more useful interactions than `#[from]`. @@ -115,6 +119,8 @@ impl UserFacingError for GetAuthInfoError { Self::BadSecret => REQUEST_FAILED.to_owned(), // However, API might return a meaningful error. Self::ApiError(e) => e.to_string_client(), + // pretend like control plane returned an error.
+ Self::UnknownEndpoint => REQUEST_FAILED.to_owned(), } } } @@ -124,6 +130,8 @@ impl ReportableError for GetAuthInfoError { match self { Self::BadSecret => crate::error::ErrorKind::ControlPlane, Self::ApiError(_) => crate::error::ErrorKind::ControlPlane, + // we only apply endpoint filtering if control plane is under high load. + Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit, } } } diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index d592223be1..7ff093d9dc 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -11,16 +11,16 @@ pub(crate) mod errors; use std::sync::Arc; -use crate::auth::IpPattern; use crate::auth::backend::jwt::AuthRule; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::cache::project_info::ProjectInfoCacheImpl; +use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; use crate::cache::{Cached, TimedLru}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; use crate::intern::{AccountIdInt, ProjectIdInt}; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::protocol2::ConnectionInfoExtra; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, scram}; /// Various cache-related types. 
@@ -101,7 +101,7 @@ impl NodeInfo { } } -#[derive(Clone, Default, Eq, PartialEq, Debug)] +#[derive(Copy, Clone, Default)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, pub vpc_access_blocked: bool, @@ -110,47 +110,78 @@ pub(crate) struct AccessBlockerFlags { pub(crate) type NodeInfoCache = TimedLru>>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; -pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; -pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAllowedVpcEndpointIds = - Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAccessBlockerFlags = - Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>; + +#[derive(Clone)] +pub struct RoleAccessControl { + pub secret: Option, +} + +#[derive(Clone)] +pub struct EndpointAccessControl { + pub allowed_ips: Arc>, + pub allowed_vpce: Arc>, + pub flags: AccessBlockerFlags, +} + +impl EndpointAccessControl { + pub fn check( + &self, + ctx: &RequestContext, + check_ip_allowed: bool, + check_vpc_allowed: bool, + ) -> Result<(), AuthError> { + if check_ip_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &self.allowed_ips) { + return Err(AuthError::IpAddressNotAllowed(ctx.peer_addr())); + } + + // check if a VPC endpoint ID is coming in and if yes, if it's allowed + if check_vpc_allowed { + if self.flags.vpc_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + let incoming_vpc_endpoint_id = match ctx.extra() { + None => return Err(AuthError::MissingVPCEndpointId), + Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + + let vpce = &self.allowed_vpce; + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. 
+ if !vpce.is_empty() && !vpce.contains(&incoming_vpc_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed( + incoming_vpc_endpoint_id, + )); + } + } else if self.flags.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + Ok(()) + } +} /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. pub(crate) trait ControlPlaneApi { - /// Get the client's auth secret for authentication. - /// Returns option because user not found situation is special. - /// We still have to mock the scram to avoid leaking information that user doesn't exist. - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError>; /// Wake up the compute node and return the corresponding connection info. diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index 96f600d836..36607e7861 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -4,9 +4,10 @@ pub mod health_server; -use std::time::Duration; +use std::time::{Duration, Instant}; use bytes::Bytes; +use futures::FutureExt; use http::Method; use http_body_util::BodyExt; use hyper::body::Body; @@ -109,15 +110,31 @@ impl Endpoint { } /// Execute a [request](reqwest::Request). 
- pub(crate) async fn execute(&self, request: Request) -> Result { - let _timer = Metrics::get() + pub(crate) fn execute( + &self, + request: Request, + ) -> impl Future> { + let metric = Metrics::get() .proxy .console_request_latency - .start_timer(ConsoleRequest { + .with_labels(ConsoleRequest { request: request.url().path(), }); - self.client.execute(request).await + let req = self.client.execute(request).boxed(); + + async move { + let start = Instant::now(); + scopeguard::defer!({ + Metrics::get() + .proxy + .console_request_latency + .get_metric(metric) + .observe_duration_since(start); + }); + + req.await + } } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index d1f8430b8a..d65d056585 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -92,6 +92,7 @@ mod logging; mod metrics; mod parse; mod pglb; +mod pqproto; mod protocol2; mod proxy; mod rate_limiter; diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/pglb/connect_compute.rs similarity index 98% rename from proxy/src/proxy/connect_compute.rs rename to proxy/src/pglb/connect_compute.rs index e013fbbe2e..1d6ca5fbb3 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/pglb/connect_compute.rs @@ -1,9 +1,7 @@ use async_trait::async_trait; -use pq_proto::StartupMessageParams; use tokio::time; use tracing::{debug, info, warn}; -use super::retry::ShouldRetryWakeCompute; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection}; use crate::config::{ComputeConfig, RetryConfig}; @@ -15,7 +13,8 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; -use crate::proxy::retry::{CouldRetry, retry_after, should_retry}; +use crate::pqproto::StartupMessageParams; +use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; diff --git 
a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/pglb/copy_bidirectional.rs similarity index 100% rename from proxy/src/proxy/copy_bidirectional.rs rename to proxy/src/pglb/copy_bidirectional.rs diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/pglb/handshake.rs similarity index 76% rename from proxy/src/proxy/handshake.rs rename to proxy/src/pglb/handshake.rs index 54c02f2c15..6970ab8714 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/pglb/handshake.rs @@ -1,8 +1,4 @@ -use bytes::Buf; -use pq_proto::framed::Framed; -use pq_proto::{ - BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, -}; +use futures::{FutureExt, TryFutureExt}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -12,7 +8,10 @@ use crate::config::TlsConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::ERR_INSECURE_CONNECTION; +use crate::pqproto::{ + BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, +}; +use crate::proxy::TlsRequired; use crate::stream::{PqStream, Stream, StreamUpgradeError}; use crate::tls::PG_ALPN_PROTOCOL; @@ -59,7 +58,7 @@ pub(crate) enum HandshakeData { /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. 
#[tracing::instrument(skip_all)] -pub(crate) async fn handshake( +pub(crate) async fn handshake( ctx: &RequestContext, stream: S, mut tls: Option<&TlsConfig>, @@ -71,33 +70,25 @@ pub(crate) async fn handshake( const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0); const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); - let mut stream = PqStream::new(Stream::from_raw(stream)); + let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?; loop { - let msg = stream.read_startup_packet().await?; match msg { FeStartupPacket::SslRequest { direct } => match stream.get_ref() { Stream::Raw { .. } if !tried_ssl => { tried_ssl = true; - // We can't perform TLS handshake without a config - let have_tls = tls.is_some(); - if !direct { - stream - .write_message(&Be::EncryptionResponse(have_tls)) - .await?; - } else if !have_tls { - return Err(HandshakeError::ProtocolViolation); - } - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. - let Framed { - stream: raw, - read_buf, - write_buf, - } = stream.framed; + let mut read_buf; + let raw = if let Some(direct) = &direct { + read_buf = &direct[..]; + stream.accept_direct_tls() + } else { + read_buf = &[]; + stream.accept_tls().await? 
+ }; let Stream::Raw { raw } = raw else { return Err(HandshakeError::StreamUpgradeError( @@ -105,12 +96,11 @@ pub(crate) async fn handshake( )); }; - let mut read_buf = read_buf.reader(); let mut res = Ok(()); let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone()) .accept_with(raw, |session| { // push the early data to the tls session - while !read_buf.get_ref().is_empty() { + while !read_buf.is_empty() { match session.read_tls(&mut read_buf) { Ok(_) => {} Err(e) => { @@ -119,11 +109,12 @@ pub(crate) async fn handshake( } } } - }); + }) + .map_ok(Box::new) + .boxed(); res?; - let read_buf = read_buf.into_inner(); if !read_buf.is_empty() { return Err(HandshakeError::EarlyData); } @@ -157,16 +148,17 @@ pub(crate) async fn handshake( let (_, tls_server_end_point) = tls.cert_resolver.resolve(conn_info.server_name()); - stream = PqStream { - framed: Framed { - stream: Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }, - read_buf, - write_buf, - }, + let tls = Stream::Tls { + tls: tls_stream, + tls_server_end_point, }; + (stream, msg) = PqStream::parse_startup(tls).await?; + } else { + if direct.is_some() { + // client sent us a ClientHello already, we can't do anything with it. + return Err(HandshakeError::ProtocolViolation); + } + msg = stream.reject_encryption().await?; } } _ => return Err(HandshakeError::ProtocolViolation), @@ -176,7 +168,7 @@ pub(crate) async fn handshake( tried_gss = true; // Currently, we don't support GSSAPI - stream.write_message(&Be::EncryptionResponse(false)).await?; + msg = stream.reject_encryption().await?; } _ => return Err(HandshakeError::ProtocolViolation), }, @@ -186,13 +178,7 @@ pub(crate) async fn handshake( // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). 
if tls.is_some() { - return stream - .throw_error_str( - ERR_INSECURE_CONNECTION, - crate::error::ErrorKind::User, - None, - ) - .await?; + Err(stream.throw_error(TlsRequired, None).await)?; } // This log highlights the start of the connection. @@ -214,20 +200,21 @@ pub(crate) async fn handshake( // no protocol extensions are supported. // let mut unsupported = vec![]; - for (k, _) in params.iter() { + let mut supported = StartupMessageParams::default(); + + for (k, v) in params.iter() { if k.starts_with("_pq_.") { unsupported.push(k); + } else { + supported.insert(k, v); } } - // TODO: remove unsupported options so we don't send them to compute. - - stream - .write_message(&Be::NegotiateProtocolVersion { - version: PG_PROTOCOL_LATEST, - options: &unsupported, - }) - .await?; + stream.write_message(BeMessage::NegotiateProtocolVersion { + version: PG_PROTOCOL_LATEST, + options: &unsupported, + }); + stream.flush().await?; info!( ?version, @@ -235,7 +222,7 @@ pub(crate) async fn handshake( session_type = "normal", "successful handshake; unsupported minor version requested" ); - break Ok(HandshakeData::Startup(stream, params)); + break Ok(HandshakeData::Startup(stream, supported)); } FeStartupPacket::StartupMessage { version, params } => { warn!( diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index 1088859fb9..4b107142a7 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -1 +1,5 @@ +pub mod connect_compute; +pub mod copy_bidirectional; +pub mod handshake; pub mod inprocess; +pub mod passthrough; diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/pglb/passthrough.rs similarity index 96% rename from proxy/src/proxy/passthrough.rs rename to proxy/src/pglb/passthrough.rs index 8f9bd2de2d..6f651d383d 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/pglb/passthrough.rs @@ -1,3 +1,4 @@ +use futures::FutureExt; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; @@ -52,7 +53,7 @@ pub(crate) async fn 
proxy_pass( // Starting from here we only proxy the client's traffic. debug!("performing the proxy pass..."); - let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute( + let _ = crate::pglb::copy_bidirectional::copy_bidirectional_client_compute( &mut client, &mut compute, ) @@ -89,6 +90,7 @@ impl ProxyPassthrough { .compute .cancel_closure .try_cancel_query(compute_config) + .boxed() .await { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs new file mode 100644 index 0000000000..43074bf208 --- /dev/null +++ b/proxy/src/pqproto.rs @@ -0,0 +1,693 @@ +//! Postgres protocol codec +//! +//! + +use std::fmt; +use std::io::{self, Cursor}; + +use bytes::{Buf, BufMut}; +use itertools::Itertools; +use rand::distributions::{Distribution, Standard}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; + +pub type ErrorCode = [u8; 5]; + +pub const FE_PASSWORD_MESSAGE: u8 = b'p'; + +pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000"; + +/// The protocol version number. +/// +/// The most significant 16 bits are the major version number (3 for the protocol described here). +/// The least significant 16 bits are the minor version number (0 for the protocol described here). 
+/// +#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +pub struct ProtocolVersion { + major: big_endian::U16, + minor: big_endian::U16, +} + +impl ProtocolVersion { + pub const fn new(major: u16, minor: u16) -> Self { + Self { + major: big_endian::U16::new(major), + minor: big_endian::U16::new(minor), + } + } + pub const fn minor(self) -> u16 { + self.minor.get() + } + pub const fn major(self) -> u16 { + self.major.get() + } +} + +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entry(&self.major()) + .entry(&self.minor()) + .finish() + } +} + +/// read the type from the stream using zerocopy. +/// +/// not cancel safe. +macro_rules! read { + ($s:expr => $t:ty) => {{ + // cannot be implemented as a function due to lack of const-generic-expr + let mut buf = [0; size_of::<$t>()]; + $s.read_exact(&mut buf).await?; + let res: $t = zerocopy::transmute!(buf); + res + }}; +} + +pub async fn read_startup(stream: &mut S) -> io::Result +where + S: AsyncRead + Unpin, +{ + /// + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; + const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; + /// + const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); + /// + const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); + /// + const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + + /// This first reads the startup message header, which is 8 bytes. + /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. + /// + /// The length value is inclusive of the header. For example, + /// an empty message will always have length 8. 
+ #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] + #[repr(C)] + struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, + } + + let header = read!(stream => StartupHeader); + + // + // First byte indicates standard SSL handshake message + // (It can't be a Postgres startup length because in network byte order + // that would be a startup packet hundreds of megabytes long) + if header.as_bytes()[0] == 0x16 { + return Ok(FeStartupPacket::SslRequest { + // The bytes we read for the header are actually part of a TLS ClientHello. + // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here. + // In practice though, I see no world where a ClientHello is less than 8 bytes + // since it includes ephemeral keys etc. + direct: Some(zerocopy::transmute!(header)), + }); + } + + let Some(len) = (header.len.get() as usize).checked_sub(8) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 8.", + header.len, + ))); + }; + + // TODO: add a histogram for startup packet lengths + if len > MAX_STARTUP_PACKET_LENGTH { + tracing::warn!("large startup message detected: {len} bytes"); + return Err(io::Error::other(format!( + "invalid startup message length {len}" + ))); + } + + match header.version { + // + CANCEL_REQUEST_CODE => { + if len != 8 { + return Err(io::Error::other( + "CancelRequest message is malformed, backend PID / secret key missing", + )); + } + + Ok(FeStartupPacket::CancelRequest( + read!(stream => CancelKeyData), + )) + } + // + NEGOTIATE_SSL_CODE => { + // Requested upgrade to SSL (aka TLS) + Ok(FeStartupPacket::SslRequest { direct: None }) + } + NEGOTIATE_GSS_CODE => { + // Requested upgrade to GSSAPI + Ok(FeStartupPacket::GssEncRequest) + } + version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other( + format!("Unrecognized request code {version:?}"), + )), + // StartupMessage + version => { + // The protocol version number is 
followed by one or more pairs of parameter name and value strings. + // A zero byte is required as a terminator after the last name/value pair. + // Parameters can appear in any order. user is required, others are optional. + + let mut buf = vec![0; len]; + stream.read_exact(&mut buf).await?; + + if buf.pop() != Some(b'\0') { + return Err(io::Error::other( + "StartupMessage params: missing null terminator", + )); + } + + // TODO: Don't do this. + // There's no guarantee that these messages are utf8, + // but they usually happen to be simple ascii. + let params = String::from_utf8(buf) + .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?; + + Ok(FeStartupPacket::StartupMessage { + version, + params: StartupMessageParams { params }, + }) + } + } +} + +/// Read a raw postgres packet, which will respect the max length requested. +/// +/// This returns the message tag, as well as the message body. The message +/// body is written into `buf`, and it is otherwise completely overwritten. +/// +/// This is not cancel safe. +pub async fn read_message<'a, S>( + stream: &mut S, + buf: &'a mut Vec, + max: u32, +) -> io::Result<(u8, &'a mut [u8])> +where + S: AsyncRead + Unpin, +{ + /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes. + /// The first byte is a message tag, and the next 4 bytes is a big-endian length. + /// + /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example, + /// an empty message will always have length 4. + #[derive(Clone, Copy, FromBytes)] + #[repr(C)] + struct Header { + tag: u8, + len: big_endian::U32, + } + + let header = read!(stream => Header); + + // as described above, the length must be at least 4. 
+ let Some(len) = header.len.get().checked_sub(4) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 4.", + header.len, + ))); + }; + + // TODO: add a histogram for message lengths + + // check if the message exceeds our desired max. + if len > max { + tracing::warn!("large postgres message detected: {len} bytes"); + return Err(io::Error::other(format!("invalid message length {len}"))); + } + + // read in our entire message. + buf.resize(len as usize, 0); + stream.read_exact(buf).await?; + + Ok((header.tag, buf)) +} + +pub struct WriteBuf(Cursor>); + +impl Buf for WriteBuf { + #[inline] + fn remaining(&self) -> usize { + self.0.remaining() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.0.chunk() + } + + #[inline] + fn advance(&mut self, cnt: usize) { + self.0.advance(cnt); + } +} + +impl WriteBuf { + pub const fn new() -> Self { + Self(Cursor::new(Vec::new())) + } + + /// Use a heuristic to determine if we should shrink the write buffer. + #[inline] + fn should_shrink(&self) -> bool { + let n = self.0.position() as usize; + let len = self.0.get_ref().len(); + + // the unused space at the front of our buffer is larger than our filled portion. + n + n > len + } + + /// Shrink the write buffer so that subsequent writes have more spare capacity. + #[cold] + fn shrink(&mut self) { + let n = self.0.position() as usize; + let buf = self.0.get_mut(); + + // buf repr: + // [----unused------|-----filled-----|-----uninit-----] + // ^ n ^ buf.len() ^ buf.capacity() + let filled = n..buf.len(); + let filled_len = filled.len(); + buf.copy_within(filled, 0); + buf.truncate(filled_len); + self.0.set_position(0); + } + + /// clear the write buffer. + pub fn reset(&mut self) { + let buf = self.0.get_mut(); + buf.clear(); + self.0.set_position(0); + } + + /// Write a raw message to the internal buffer. + /// + /// The size_hint value is only a hint for reserving space. 
It's ok if it's incorrect, since + /// we calculate the length after the fact. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + if self.should_shrink() { + self.shrink(); + } + + let buf = self.0.get_mut(); + buf.reserve(5 + size_hint); + + buf.push(tag); + let start = buf.len(); + buf.extend_from_slice(&[0, 0, 0, 0]); + + f(buf); + + let end = buf.len(); + let len = (end - start) as u32; + buf[start..start + 4].copy_from_slice(&len.to_be_bytes()); + } + + /// Write an encryption response message. + pub fn encryption(&mut self, m: u8) { + self.0.get_mut().push(m); + } + + pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) { + self.shrink(); + + // + // + // "SERROR\0CXXXXX\0M\0\0".len() == 17 + self.write_raw(17 + msg.len(), b'E', |buf| { + // Severity: ERROR + buf.put_slice(b"SERROR\0"); + + // Code: error_code + buf.put_u8(b'C'); + buf.put_slice(&error_code); + buf.put_u8(0); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End. + buf.put_u8(0); + }); + } +} + +#[derive(Debug)] +pub enum FeStartupPacket { + CancelRequest(CancelKeyData), + SslRequest { + direct: Option<[u8; 8]>, + }, + GssEncRequest, + StartupMessage { + version: ProtocolVersion, + params: StartupMessageParams, + }, +} + +#[derive(Debug, Clone, Default)] +pub struct StartupMessageParams { + pub params: String, +} + +impl StartupMessageParams { + /// Get parameter's value by its name. + pub fn get(&self, name: &str) -> Option<&str> { + self.iter().find_map(|(k, v)| (k == name).then_some(v)) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + /// [`None`] means that there's no `options` in [`Self`]. 
+ pub fn options_raw(&self) -> Option> { + self.get("options").map(Self::parse_options_raw) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + pub fn parse_options_raw(input: &str) -> impl Iterator { + // See `postgres: pg_split_opts`. + let mut last_was_escape = false; + input + .split(move |c: char| { + // We split by non-escaped whitespace symbols. + let should_split = c.is_ascii_whitespace() && !last_was_escape; + last_was_escape = c == '\\' && !last_was_escape; + should_split + }) + .filter(|s| !s.is_empty()) + } + + /// Iterate through key-value pairs in an arbitrary order. + pub fn iter(&self) -> impl Iterator { + self.params.split_terminator('\0').tuples() + } + + // This function is mostly useful in tests. + #[cfg(test)] + pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { + let mut b = Self { + params: String::new(), + }; + for (k, v) in pairs { + b.insert(k, v); + } + b + } + + /// Set parameter's value by its name. + /// name and value must not contain a \0 byte + pub fn insert(&mut self, name: &str, value: &str) { + self.params.reserve(name.len() + value.len() + 2); + self.params.push_str(name); + self.params.push('\0'); + self.params.push_str(value); + self.params.push('\0'); + } +} + +/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just +/// opaque bytes. 
+#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)] +pub struct CancelKeyData(pub big_endian::U64); + +pub fn id_to_cancel_key(id: u64) -> CancelKeyData { + CancelKeyData(big_endian::U64::new(id)) +} + +impl fmt::Display for CancelKeyData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let id = self.0; + f.debug_tuple("CancelKeyData") + .field(&format_args!("{id:x}")) + .finish() + } +} +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> CancelKeyData { + id_to_cancel_key(rng.r#gen()) + } +} + +pub enum BeMessage<'a> { + AuthenticationOk, + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), + AuthenticationCleartextPassword, + BackendKeyData(CancelKeyData), + ParameterStatus { + name: &'a [u8], + value: &'a [u8], + }, + ReadyForQuery, + NoticeResponse(&'a str), + NegotiateProtocolVersion { + version: ProtocolVersion, + options: &'a [&'a str], + }, +} + +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + +impl BeMessage<'_> { + /// Write the message into an internal buffer + pub fn write_message(self, buf: &mut WriteBuf) { + match self { + // + BeMessage::AuthenticationOk => { + buf.write_raw(1, b'R', |buf| buf.put_i32(0)); + } + // + BeMessage::AuthenticationCleartextPassword => { + buf.write_raw(1, b'R', |buf| buf.put_i32(3)); + } + + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => { + let len: usize = methods.iter().map(|m| m.len() + 1).sum(); + buf.write_raw(len + 2, b'R', |buf| { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods { + buf.put_slice(method.as_bytes()); + buf.put_u8(0); + } + buf.put_u8(0); // zero terminator for the list + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(11); // Continue SASL auth. 
+ buf.put_slice(extra); + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + }); + } + + // + BeMessage::BackendKeyData(key_data) => { + buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes())); + } + + // + // + BeMessage::NoticeResponse(msg) => { + // 'N' signalizes NoticeResponse messages + buf.write_raw(18 + msg.len(), b'N', |buf| { + // Severity: NOTICE + buf.put_slice(b"SNOTICE\0"); + + // Code: XX000 (ignored for notice, but still required) + buf.put_slice(b"CXX000\0"); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End notice. + buf.put_u8(0); + }); + } + + // + BeMessage::ParameterStatus { name, value } => { + buf.write_raw(name.len() + value.len() + 2, b'S', |buf| { + buf.put_slice(name.as_bytes()); + buf.put_u8(0); + buf.put_slice(value.as_bytes()); + buf.put_u8(0); + }); + } + + // + BeMessage::ReadyForQuery => { + buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I')); + } + + // + BeMessage::NegotiateProtocolVersion { version, options } => { + let len: usize = options.iter().map(|o| o.len() + 1).sum(); + buf.write_raw(8 + len, b'v', |buf| { + buf.put_slice(version.as_bytes()); + buf.put_u32(options.len() as u32); + for option in options { + buf.put_slice(option.as_bytes()); + buf.put_u8(0); + } + }); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use tokio::io::{AsyncWriteExt, duplex}; + use zerocopy::IntoBytes; + + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; + + use super::ProtocolVersion; + + #[tokio::test] + async fn reject_large_startup() { + // we're going to define a v3.0 startup message with far too many parameters. + let mut payload = vec![]; + // 10001 + 8 bytes. 
+ payload.extend_from_slice(&10009_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.resize(10009, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_startup(&mut server).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid startup message length 10001"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn reject_large_password() { + // we're going to define a password message that is far too long. + let mut payload = vec![]; + payload.push(b'p'); + payload.extend_from_slice(&517_u32.to_be_bytes()); + payload.resize(518, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid message length 513"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn read_startup_message() { + let mut payload = vec![]; + payload.extend_from_slice(&17_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.extend_from_slice(b"abc\0def\0\0"); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::StartupMessage { version, params } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + + assert_eq!(version.major(), 3); + assert_eq!(version.minor(), 0); + assert_eq!(params.params, "abc\0def\0"); + } + + #[tokio::test] + async fn read_ssl_message() { + let mut payload = vec![]; + payload.extend_from_slice(&8_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes()); + + let startup = read_startup(&mut 
Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::SslRequest { direct: None } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + } + + #[tokio::test] + async fn read_tls_message() { + // sample client hello taken from + let client_hello = [ + 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02, + 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, + 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, + 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, + 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, + 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, + 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, + 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, + 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, + 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e, + 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09, + 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, + 0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01, + 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72, + 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, + 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, + 0x54, + ]; + + let mut cursor = Cursor::new(&client_hello); + + let startup = read_startup(&mut cursor).await.unwrap(); + let FeStartupPacket::SslRequest { + direct: 
Some(prefix), + } = startup + else { + panic!("unexpected startup message: {startup:?}"); + }; + + // check that no data is lost. + assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]); + assert_eq!(cursor.position(), 8); + } + + #[tokio::test] + async fn read_message_success() { + let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2"; + let mut cursor = Cursor::new(&query); + + let mut buf = vec![]; + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 1"); + + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 2"); + } +} diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 0793998639..5bec6d6ca3 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -4,60 +4,13 @@ use core::fmt; use std::io; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use bytes::{Buf, Bytes, BytesMut}; -use pin_project_lite::pin_project; +use bytes::Buf; use smol_str::SmolStr; use strum_macros::FromRepr; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncReadExt}; use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned, network_endian}; -pin_project! 
{ - /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough - pub(crate) struct ChainRW { - #[pin] - pub(crate) inner: T, - buf: BytesMut, - } -} - -impl AsyncWrite for ChainRW { - #[inline] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_shutdown(cx) - } - - #[inline] - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } -} - /// Proxy Protocol Version 2 Header const SIGNATURE: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, @@ -79,7 +32,6 @@ pub struct ConnectionInfo { #[derive(PartialEq, Eq, Clone, Debug)] pub enum ConnectHeader { - Missing, Local, Proxy(ConnectionInfo), } @@ -106,47 +58,24 @@ pub enum ConnectionInfoExtra { pub(crate) async fn read_proxy_protocol( mut read: T, -) -> std::io::Result<(ChainRW, ConnectHeader)> { - let mut buf = BytesMut::with_capacity(128); - let header = loop { - let bytes_read = read.read_buf(&mut buf).await?; - - // exit for bad header signature - let len = usize::min(buf.len(), SIGNATURE.len()); - if buf[..len] != SIGNATURE[..len] { - return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing)); - } - - // if no more bytes available then exit - if bytes_read == 0 { - return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing)); - } - - // check if we have enough bytes to continue - if let Some(header) = buf.try_get::() { - break header; - } - }; - - let remaining_length = usize::from(header.len.get()); - - while buf.len() < remaining_length 
{ - if read.read_buf(&mut buf).await? == 0 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "stream closed while waiting for proxy protocol addresses", - )); - } +) -> std::io::Result<(T, ConnectHeader)> { + let mut header = [0; size_of::()]; + read.read_exact(&mut header).await?; + let header: ProxyProtocolV2Header = zerocopy::transmute!(header); + if header.signature != SIGNATURE { + return Err(std::io::Error::other("invalid proxy protocol header")); } - let payload = buf.split_to(remaining_length); - let res = process_proxy_payload(header, payload)?; - Ok((ChainRW { inner: read, buf }, res)) + let mut payload = vec![0; usize::from(header.len.get())]; + read.read_exact(&mut payload).await?; + + let res = process_proxy_payload(header, &payload)?; + Ok((read, res)) } fn process_proxy_payload( header: ProxyProtocolV2Header, - mut payload: BytesMut, + mut payload: &[u8], ) -> std::io::Result { match header.version_and_command { // the connection was established on purpose by the proxy @@ -162,13 +91,12 @@ fn process_proxy_payload( PROXY_V2 => {} // other values are unassigned and must not be emitted by senders. Receivers // must drop connections presenting unexpected values here. - #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384 - _ => return Err(io::Error::other( - format!( + _ => { + return Err(io::Error::other(format!( "invalid proxy protocol command 0x{:02X}. 
expected local (0x20) or proxy (0x21)", header.version_and_command - ), - )), + ))); + } } let size_err = @@ -206,7 +134,7 @@ fn process_proxy_payload( } let subtype = tlv.value.get_u8(); match Pp2AwsType::from_repr(subtype) { - Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) { + Some(Pp2AwsType::VpceId) => match std::str::from_utf8(tlv.value) { Ok(s) => { extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() }); } @@ -282,65 +210,28 @@ enum Pp2AzureType { PrivateEndpointLinkId = 0x01, } -impl AsyncRead for ChainRW { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if self.buf.is_empty() { - self.project().inner.poll_read(cx, buf) - } else { - self.read_from_buf(buf) - } - } -} - -impl ChainRW { - #[cold] - fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll> { - debug_assert!(!self.buf.is_empty()); - let this = self.project(); - - let write = usize::min(this.buf.len(), buf.remaining()); - let slice = this.buf.split_to(write).freeze(); - buf.put_slice(&slice); - - // reset the allocation so it can be freed - if this.buf.is_empty() { - *this.buf = BytesMut::new(); - } - - Poll::Ready(Ok(())) - } -} - #[derive(Debug)] -struct Tlv { +struct Tlv<'a> { kind: u8, - value: Bytes, + value: &'a [u8], } -fn read_tlv(b: &mut BytesMut) -> Option { +fn read_tlv<'a>(b: &mut &'a [u8]) -> Option> { let tlv_header = b.try_get::()?; let len = usize::from(tlv_header.len.get()); - if b.len() < len { - return None; - } Some(Tlv { kind: tlv_header.kind, - value: b.split_to(len).freeze(), + value: b.split_off(..len)?, }) } trait BufExt: Sized { fn try_get(&mut self) -> Option; } -impl BufExt for BytesMut { +impl BufExt for &[u8] { fn try_get(&mut self) -> Option { - let (res, _) = T::read_from_prefix(self).ok()?; - self.advance(size_of::()); + let (res, rest) = T::read_from_prefix(self).ok()?; + *self = rest; Some(res) } } @@ -481,27 +372,19 @@ mod tests { } #[tokio::test] + 
#[should_panic = "invalid proxy protocol header"] async fn test_invalid() { let data = [0x55; 256]; - let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap(); - - let mut bytes = vec![]; - read.read_to_end(&mut bytes).await.unwrap(); - assert_eq!(bytes, data); - assert_eq!(info, ConnectHeader::Missing); + read_proxy_protocol(data.as_slice()).await.unwrap(); } #[tokio::test] + #[should_panic = "early eof"] async fn test_short() { let data = [0x55; 10]; - let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap(); - - let mut bytes = vec![]; - read.read_to_end(&mut bytes).await.unwrap(); - assert_eq!(bytes, data); - assert_eq!(info, ConnectHeader::Missing); + read_proxy_protocol(data.as_slice()).await.unwrap(); } #[tokio::test] diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0a86022e78..0e138cc0c7 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -1,36 +1,32 @@ #[cfg(test)] mod tests; -pub(crate) mod connect_compute; -mod copy_bidirectional; -pub(crate) mod handshake; -pub(crate) mod passthrough; pub(crate) mod retry; pub(crate) mod wake_compute; use std::sync::Arc; -pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; -use futures::{FutureExt, TryFutureExt}; +use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; -use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, ToSmolStr, format_smolstr}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; -use self::connect_compute::{TcpMechanism, connect_to_compute}; -use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use 
crate::context::RequestContext; -use crate::error::ReportableError; +use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute}; +pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; +use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; +use crate::pglb::passthrough::ProxyPassthrough; +use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; -use crate::proxy::handshake::{HandshakeData, handshake}; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; @@ -38,6 +34,18 @@ use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; +#[derive(Error, Debug)] +#[error("{ERR_INSECURE_CONNECTION}")] +pub struct TlsRequired; + +impl ReportableError for TlsRequired { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + +impl UserFacingError for TlsRequired {} + pub async fn run_until_cancelled( f: F, cancellation_token: &CancellationToken, @@ -90,30 +98,24 @@ pub async fn task_main( let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); connections.spawn(async move { - let (socket, conn_info) = match read_proxy_protocol(socket).await { - Err(e) => { - warn!("per-client task finished with an error: {e:#}"); - return; + let (socket, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(socket).await { + Err(e) => { + warn!("per-client task finished with an error: {e:#}"); + return; + } + // our load balancers will not send any more data. 
let's just exit immediately + Ok((_socket, ConnectHeader::Local)) => { + debug!("healthcheck received"); + return; + } + Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), + } } - // our load balancers will not send any more data. let's just exit immediately - Ok((_socket, ConnectHeader::Local)) => { - debug!("healthcheck received"); - return; - } - Ok((_socket, ConnectHeader::Missing)) - if config.proxy_protocol_v2 == ProxyProtocolV2::Required => - { - warn!("missing required proxy protocol header"); - return; - } - Ok((_socket, ConnectHeader::Proxy(_))) - if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => - { - warn!("proxy protocol header not supported"); - return; - } - Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - Ok((socket, ConnectHeader::Missing)) => ( + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. + ProxyProtocolV2::Rejected => ( socket, ConnectionInfo { addr: peer_addr, @@ -122,7 +124,7 @@ pub async fn task_main( ), }; - match socket.inner.set_nodelay(true) { + match socket.set_nodelay(true) { Ok(()) => {} Err(e) => { error!( @@ -236,7 +238,7 @@ pub(crate) enum ClientRequestError { #[error("{0}")] Cancellation(#[from] cancellation::CancelError), #[error("{0}")] - Handshake(#[from] handshake::HandshakeError), + Handshake(#[from] HandshakeError), #[error("{0}")] HandshakeTimeout(#[from] tokio::time::error::Elapsed), #[error("{0}")] @@ -258,7 +260,7 @@ impl ReportableError for ClientRequestError { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, @@ -329,11 +331,11 @@ pub(crate) async fn handle_client( let user_info = match result { Ok(user_info) => user_info, - Err(e) => stream.throw_error(e, Some(ctx)).await?, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; let user = 
user_info.get_user().to_owned(); - let (user_info, _ip_allowlist) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, @@ -349,10 +351,10 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return stream + return Err(stream .throw_error(e, Some(ctx)) .instrument(params_span) - .await?; + .await)?; } }; @@ -365,7 +367,7 @@ pub(crate) async fn handle_client( .get(NeonOptions::PARAMS_COMPAT) .is_some(); - let mut node = connect_to_compute( + let res = connect_to_compute( ctx, &TcpMechanism { user_info: compute_user_info.clone(), @@ -377,22 +379,19 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) - .await?; + .await; + + let node = match res { + Ok(node) => node, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + }; let cancellation_handler_clone = Arc::clone(&cancellation_handler); let session = cancellation_handler_clone.get_key(); session.write_cancel_key(node.cancel_closure.clone())?; - - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), @@ -413,31 +412,28 @@ pub(crate) async fn handle_client( } /// Finish client connection initialization: confirm auth success, send params, etc. 
-#[tracing::instrument(skip_all)] -pub(crate) async fn prepare_client_connection( +pub(crate) fn prepare_client_connection( node: &compute::PostgresConnection, cancel_key_data: CancelKeyData, stream: &mut PqStream, -) -> Result<(), std::io::Error> { +) { // Forward all deferred notices to the client. for notice in &node.delayed_notice { - stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?; + stream.write_raw(notice.as_bytes().len(), b'N', |buf| { + buf.extend_from_slice(notice.as_bytes()); + }); } // Forward all postgres connection params to the client. for (name, value) in &node.params { - stream.write_message_noflush(&Be::ParameterStatus { + stream.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), - })?; + }); } - stream - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&Be::ReadyForQuery) - .await?; - - Ok(()) + stream.write_message(BeMessage::BackendKeyData(cancel_key_data)); + stream.write_message(BeMessage::ReadyForQuery); } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 0879564ced..01e603ec14 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati #[cfg(test)] mod tests { - use super::ShouldRetryWakeCompute; use postgres_client::error::{DbError, SqlState}; + use super::ShouldRetryWakeCompute; + #[test] fn should_retry_wake_compute_for_db_error() { // These SQLStates should NOT trigger a wake_compute retry. 
diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 59c9ac27b8..c92ee49b8d 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use postgres_client::tls::TlsConnect; use postgres_protocol::message::frontend; -use tokio::io::{AsyncReadExt, DuplexStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio_util::codec::{Decoder, Encoder}; use super::*; @@ -49,15 +49,14 @@ async fn proxy_mitm( }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); - let (end_client, buf) = end_client.framed.into_inner(); - assert!(buf.is_empty()); + let end_client = end_client.flush_and_into_inner().await.unwrap(); let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame); // give the end_server the startup parameters let mut buf = BytesMut::new(); frontend::startup_message( &postgres_protocol::message::frontend::StartupMessageParams { - params: startup.params.into(), + params: startup.params.as_bytes().into(), }, &mut buf, ) diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index be6426a63c..e5db0013a7 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -17,7 +17,6 @@ use rustls::pki_types; use tokio::io::DuplexStream; use tracing_test::traced_test; -use super::connect_compute::ConnectMechanism; use super::retry::CouldRetry; use super::*; use crate::auth::backend::{ @@ -26,10 +25,9 @@ use crate::auth::backend::{ use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; -use crate::control_plane::{ - self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache, -}; +use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use 
crate::error::ErrorKind; +use crate::pglb::connect_compute::ConnectMechanism; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::tls::server_config::CertResolver; @@ -128,7 +126,7 @@ trait TestAuth: Sized { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - stream.write_message_noflush(&Be::AuthenticationOk)?; + stream.write_message(BeMessage::AuthenticationOk); Ok(()) } } @@ -157,9 +155,7 @@ impl TestAuth for Scram { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0, &RequestContext::test())) - .await? + let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test())) .authenticate() .await?; @@ -177,7 +173,6 @@ async fn dummy_proxy( tls: Option, auth: impl TestAuth + Send, ) -> anyhow::Result<()> { - let (client, _) = read_proxy_protocol(client).await?; let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? { HandshakeData::Startup(stream, _) => stream, HandshakeData::Cancel(_) => bail!("cancellation not supported"), @@ -185,10 +180,12 @@ async fn dummy_proxy( auth.authenticate(&mut stream).await?; - stream - .write_message_noflush(&Be::CLIENT_ENCODING)? 
- .write_message(&Be::ReadyForQuery) - .await?; + stream.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + stream.write_message(BeMessage::ReadyForQuery); + stream.flush().await?; Ok(()) } @@ -547,20 +544,9 @@ impl TestControlPlaneClient for TestConnectMechanism { } } - fn get_allowed_ips(&self) -> Result { - unimplemented!("not used in tests") - } - - fn get_allowed_vpc_endpoint_ids( + fn get_access_control( &self, - ) -> Result { - unimplemented!("not used in tests") - } - - fn get_block_public_or_vpc_access( - &self, - ) -> Result - { + ) -> Result { unimplemented!("not used in tests") } diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 9d8915e24a..06c2da58db 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,6 +1,5 @@ use tracing::{error, info}; -use super::connect_compute::ComputeConnectBackend; use crate::config::RetryConfig; use crate::context::RequestContext; use crate::control_plane::CachedNodeInfo; @@ -9,6 +8,7 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType, }; +use crate::pglb::connect_compute::ComputeConnectBackend; use crate::proxy::retry::{retry_after, should_retry}; // Use macro to retain original callsite. 
diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 4f27c6faef..0c79b5e92f 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -15,7 +15,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter; pub struct LeakyBucketRateLimiter { map: ClashMap, - config: utils::leaky_bucket::LeakyBucketConfig, + default_config: utils::leaky_bucket::LeakyBucketConfig, access_count: AtomicUsize, } @@ -28,15 +28,17 @@ impl LeakyBucketRateLimiter { pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self { Self { map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards), - config: config.into(), + default_config: config.into(), access_count: AtomicUsize::new(0), } } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub(crate) fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, config: Option, n: u32) -> bool { let now = Instant::now(); + let config = config.map_or(self.default_config, Into::into); + if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { self.do_gc(now); } @@ -46,7 +48,7 @@ impl LeakyBucketRateLimiter { .entry(key) .or_insert_with(|| LeakyBucketState { empty_at: now }); - entry.add_tokens(&self.config, now, n as f64).is_ok() + entry.add_tokens(&config, now, n as f64).is_ok() } fn do_gc(&self, now: Instant) { diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 21eaa6739b..9d700c1b52 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -15,6 +15,8 @@ use tracing::info; use crate::ext::LockExt; use crate::intern::EndpointIdInt; +use super::LeakyBucketConfig; + pub struct GlobalRateLimiter { data: Vec, info: Vec, @@ -144,19 +146,6 @@ impl RateBucketInfo { Self::new(50_000, Duration::from_secs(10)), ]; - /// All of these are per endpoint-maskedip pair. 
- /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). - /// - /// First bucket: 1000mcpus total per endpoint-ip pair - /// * 4096000 requests per second with 1 hash rounds. - /// * 1000 requests per second with 4096 hash rounds. - /// * 6.8 requests per second with 600000 hash rounds. - pub const DEFAULT_AUTH_SET: [Self; 3] = [ - Self::new(1000 * 4096, Duration::from_secs(1)), - Self::new(600 * 4096, Duration::from_secs(60)), - Self::new(300 * 4096, Duration::from_secs(600)), - ]; - pub fn rps(&self) -> f64 { (self.max_rpi as f64) / self.interval.as_secs_f64() } @@ -184,6 +173,21 @@ impl RateBucketInfo { max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32, } } + + pub fn to_leaky_bucket(this: &[Self]) -> Option { + // bit of a hack - find the min rps and max rps supported and turn it into + // leaky bucket config instead + + let mut iter = this.iter().map(|info| info.rps()); + let first = iter.next()?; + + let (min, max) = (first, first); + let (min, max) = iter.fold((min, max), |(min, max), rps| { + (f64::min(min, rps), f64::max(max, rps)) + }); + + Some(LeakyBucketConfig { rps: min, max }) + } } impl BucketRateLimiter { diff --git a/proxy/src/rate_limiter/mod.rs b/proxy/src/rate_limiter/mod.rs index 5f90102da3..112b95873a 100644 --- a/proxy/src/rate_limiter/mod.rs +++ b/proxy/src/rate_limiter/mod.rs @@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd; pub(crate) use limit_algorithm::{ DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, }; -pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; +pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 186fece4b2..6f56aeea06 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -1,10 +1,11 @@ use 
core::net::IpAddr; use std::sync::Arc; -use pq_proto::CancelKeyData; use tokio::sync::Mutex; use uuid::Uuid; +use crate::pqproto::CancelKeyData; + pub trait CancellationPublisherMut: Send + Sync + 'static { #[allow(async_fn_in_trait)] async fn try_publish( diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs index 7527bca6d0..3113bad949 100644 --- a/proxy/src/redis/keys.rs +++ b/proxy/src/redis/keys.rs @@ -1,16 +1,15 @@ use std::io::ErrorKind; use anyhow::Ok; -use pq_proto::{CancelKeyData, id_to_cancel_key}; -use serde::{Deserialize, Serialize}; + +use crate::pqproto::{CancelKeyData, id_to_cancel_key}; pub mod keyspace { pub const CANCEL_PREFIX: &str = "cancel"; } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) enum KeyPrefix { - #[serde(untagged)] Cancel(CancelKeyData), } @@ -18,9 +17,7 @@ impl KeyPrefix { pub(crate) fn build_redis_key(&self) -> String { match self { KeyPrefix::Cancel(key) => { - let hi = (key.backend_pid as u64) << 32; - let lo = (key.cancel_key as u64) & 0xffff_ffff; - let id = hi | lo; + let id = key.0.get(); let keyspace = keyspace::CANCEL_PREFIX; format!("{keyspace}:{id:x}") } @@ -63,10 +60,7 @@ mod tests { #[test] fn test_build_redis_key() { - let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }); + let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321)); let redis_key = cancel_key.build_redis_key(); assert_eq!(redis_key, "cancel:30390000d431"); @@ -77,10 +71,7 @@ mod tests { let redis_key = "cancel:30390000d431"; let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key"); - let ref_key = CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }; + let ref_key = id_to_cancel_key(12345 << 32 | 54321); assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str()); let KeyPrefix::Cancel(cancel_key) = key; diff --git a/proxy/src/redis/notifications.rs 
b/proxy/src/redis/notifications.rs index 5f9f2509e2..a9d6b40603 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -2,11 +2,9 @@ use std::convert::Infallible; use std::sync::Arc; use futures::StreamExt; -use pq_proto::CancelKeyData; use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use tokio_util::sync::CancellationToken; -use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; @@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate { role_name: RoleNameInt, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct CancelSession { - pub(crate) region_id: Option, - pub(crate) cancel_key_data: CancelKeyData, - pub(crate) session_id: Uuid, - pub(crate) peer_addr: Option, -} - fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, @@ -243,29 +233,30 @@ impl MessageHandler { fn invalidate_cache(cache: Arc, msg: Notification) { match msg { - Notification::AllowedIpsUpdate { allowed_ips_update } => { - cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); + Notification::AllowedIpsUpdate { + allowed_ips_update: AllowedIpsUpdate { project_id }, } - Notification::BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated, - } => cache.invalidate_block_public_or_vpc_access_for_project( - block_public_or_vpc_access_updated.project_id, - ), + | Notification::BlockPublicOrVpcAccessUpdated { + block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id }, + } => cache.invalidate_endpoint_access_for_project(project_id), Notification::AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_org( - allowed_vpc_endpoints_updated_for_org.account_id, - ), + allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id }, + } => 
cache.invalidate_endpoint_access_for_org(account_id), Notification::AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects( - allowed_vpc_endpoints_updated_for_projects.project_ids, - ), - Notification::PasswordUpdate { password_update } => cache - .invalidate_role_secret_for_project( - password_update.project_id, - password_update.role_name, - ), + allowed_vpc_endpoints_updated_for_projects: + AllowedVpcEndpointsUpdatedForProjects { project_ids }, + } => { + for project in project_ids { + cache.invalidate_endpoint_access_for_project(project); + } + } + Notification::PasswordUpdate { + password_update: + PasswordUpdate { + project_id, + role_name, + }, + } => cache.invalidate_role_secret_for_project(project_id, role_name), Notification::UnknownTopic => unreachable!(), } } diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index 7f2f3a761c..8d26a3f453 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -1,7 +1,5 @@ //! Definitions for SASL messages. -use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; - use crate::parse::split_cstr; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). @@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> { } } -/// A single SASL message. -/// This struct is deliberately decoupled from lower-level -/// [`BeAuthenticationSaslMessage`]. -#[derive(Debug)] -pub(super) enum ServerMessage { - /// We expect to see more steps. - Continue(T), - /// This is the final step. 
- Final(T), -} - -impl<'a> ServerMessage<&'a str> { - pub(super) fn to_reply(&self) -> BeMessage<'a> { - BeMessage::AuthenticationSasl(match self { - ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()), - ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()), - }) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/sasl/mod.rs b/proxy/src/sasl/mod.rs index f0181b404f..007b62dfd2 100644 --- a/proxy/src/sasl/mod.rs +++ b/proxy/src/sasl/mod.rs @@ -14,7 +14,7 @@ use std::io; pub(crate) use channel_binding::ChannelBinding; pub(crate) use messages::FirstMessage; -pub(crate) use stream::{Outcome, SaslStream}; +pub(crate) use stream::{Outcome, authenticate}; use thiserror::Error; use crate::error::{ReportableError, UserFacingError}; @@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] pub(crate) enum Error { + #[error("Unsupported authentication method: {0}")] + BadAuthMethod(Box), + #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), @@ -54,6 +57,7 @@ impl UserFacingError for Error { impl ReportableError for Error { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { + Error::BadAuthMethod(_) => crate::error::ErrorKind::User, Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, Error::BadClientMessage(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 46e6a439e5..52ccca58d5 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -3,61 +3,12 @@ use std::io; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::info; -use super::Mechanism; -use super::messages::ServerMessage; +use super::{Mechanism, Step}; +use crate::context::RequestContext; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use 
crate::stream::PqStream; -/// Abstracts away all peculiarities of the libpq's protocol. -pub(crate) struct SaslStream<'a, S> { - /// The underlying stream. - stream: &'a mut PqStream, - /// Current password message we received from client. - current: bytes::Bytes, - /// First SASL message produced by client. - first: Option<&'a str>, -} - -impl<'a, S> SaslStream<'a, S> { - pub(crate) fn new(stream: &'a mut PqStream, first: &'a str) -> Self { - Self { - stream, - current: bytes::Bytes::new(), - first: Some(first), - } - } -} - -impl SaslStream<'_, S> { - // Receive a new SASL message from the client. - async fn recv(&mut self) -> io::Result<&str> { - if let Some(first) = self.first.take() { - return Ok(first); - } - - self.current = self.stream.read_password_message().await?; - let s = std::str::from_utf8(&self.current) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; - - Ok(s) - } -} - -impl SaslStream<'_, S> { - // Send a SASL message to the client. - async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message(&msg.to_reply()).await?; - Ok(()) - } - - // Queue a SASL message for the client. - fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message_noflush(&msg.to_reply())?; - Ok(()) - } -} - /// SASL authentication outcome. /// It's much easier to match on those two variants /// than to peek into a noisy protocol error type. @@ -69,33 +20,63 @@ pub(crate) enum Outcome { Failure(&'static str), } -impl SaslStream<'_, S> { - /// Perform SASL message exchange according to the underlying algorithm - /// until user is either authenticated or denied access. 
- pub(crate) async fn authenticate( - mut self, - mut mechanism: M, - ) -> super::Result> { - loop { - let input = self.recv().await?; - let step = mechanism.exchange(input).map_err(|error| { - info!(?error, "error during SASL exchange"); - error - })?; +pub async fn authenticate( + ctx: &RequestContext, + stream: &mut PqStream, + mechanism: F, +) -> super::Result> +where + S: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&str) -> super::Result, + M: Mechanism, +{ + let (mut mechanism, mut input) = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - use super::Step; - return Ok(match step { - Step::Continue(moved_mechanism, reply) => { - self.send(&ServerMessage::Continue(&reply)).await?; - mechanism = moved_mechanism; - continue; - } - Step::Success(result, reply) => { - self.send_noflush(&ServerMessage::Final(&reply))?; - Outcome::Success(result) - } - Step::Failure(reason) => Outcome::Failure(reason), - }); + // Initial client message contains the chosen auth method's name. 
+ let msg = stream.read_password_message().await?; + + let sasl = super::FirstMessage::parse(msg) + .ok_or(super::Error::BadClientMessage("bad sasl message"))?; + + (mechanism(sasl.method)?, sasl.message) + }; + + loop { + match mechanism.exchange(input) { + Ok(Step::Continue(moved_mechanism, reply)) => { + mechanism = moved_mechanism; + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + drop(reply); + } + Ok(Step::Success(result, reply)) => { + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + stream.write_message(BeMessage::AuthenticationOk); + + // exit with success + break Ok(Outcome::Success(result)); + } + // exit with failure + Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)), + Err(error) => { + tracing::info!(?error, "error during SASL exchange"); + return Err(error); + } } + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 13058f08f1..748e0ce6f2 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -22,7 +22,7 @@ use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; -use crate::auth::{self, AuthError, check_peer_addr_is_in_list}; +use crate::auth::{self, AuthError}; use crate::compute; use crate::compute_ctl::{ 
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, @@ -35,8 +35,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::protocol2::ConnectionInfoExtra; -use crate::proxy::connect_compute::ConnectMechanism; +use crate::pglb::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; @@ -63,63 +62,24 @@ impl PoolingBackend { let user_info = user_info.clone(); let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); - let allowed_ips = backend.get_allowed_ips(ctx).await?; + let access_control = backend.get_endpoint_access_control(ctx).await?; + access_control.check( + ctx, + self.config.authentication_config.ip_allowlist_check_enabled, + self.config.authentication_config.is_vpc_acccess_proxy, + )?; - if self.config.authentication_config.ip_allowlist_check_enabled - && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) - { - return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - - let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?; - if self.config.authentication_config.is_vpc_acccess_proxy { - if access_blocker_flags.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - let extra = ctx.extra(); - let incoming_endpoint_id = match extra { - None => String::new(), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - if incoming_endpoint_id.is_empty() { - return Err(AuthError::MissingVPCEndpointId); - } - - let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. 
We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); - } - } else if access_blocker_flags.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !self - .endpoint_rate_limiter - .check(user_info.endpoint.clone().into(), 1) - { + let ep = EndpointIdInt::from(&user_info.endpoint); + let rate_limit_config = None; + if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = backend.get_role_secret(ctx).await?; - let secret = match cached_secret.value.clone() { - Some(secret) => self.config.authentication_config.check_rate_limit( - ctx, - secret, - &user_info.endpoint, - true, - )?, - None => { - // If we don't have an authentication secret, for the http flow we can just return an error. - info!("authentication info not found"); - return Err(AuthError::password_failed(&*user_info.user)); - } + let role_access = backend.get_role_secret(ctx).await?; + let Some(secret) = role_access.secret else { + // If we don't have an authentication secret, for the http flow we can just return an error. 
+ info!("authentication info not found"); + return Err(AuthError::password_failed(&*user_info.user)); }; - let ep = EndpointIdInt::from(&user_info.endpoint); let auth_outcome = crate::auth::validate_password_and_exchange( &self.config.authentication_config.thread_pool, ep, @@ -222,7 +182,7 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); let backend = self.auth_backend.as_ref().map(|()| keys); - crate::proxy::connect_compute::connect_to_compute( + crate::pglb::connect_compute::connect_to_compute( ctx, &TokioMechanism { conn_id, @@ -266,7 +226,7 @@ impl PoolingBackend { }, keys: crate::auth::backend::ComputeCredentialKeys::None, }); - crate::proxy::connect_compute::connect_to_compute( + crate::pglb::connect_compute::connect_to_compute( ctx, &HyperMechanism { conn_id, diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 2a7069b1c2..f6f681ac45 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -49,7 +49,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::ext::TaskExt; use crate::metrics::Metrics; -use crate::protocol2::{ChainRW, ConnectHeader, ConnectionInfo, read_proxy_protocol}; +use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::proxy::run_until_cancelled; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; @@ -207,12 +207,12 @@ pub(crate) type AsyncRW = Pin>; #[async_trait] trait MaybeTlsAcceptor: Send + Sync + 'static { - async fn accept(&self, conn: ChainRW) -> std::io::Result; + async fn accept(&self, conn: TcpStream) -> std::io::Result; } #[async_trait] impl MaybeTlsAcceptor for &'static ArcSwapOption { - async fn accept(&self, conn: ChainRW) -> std::io::Result { + async fn accept(&self, conn: TcpStream) -> std::io::Result { match &*self.load() { Some(config) => Ok(Box::pin( 
TlsAcceptor::from(config.http_config.clone()) @@ -235,33 +235,30 @@ async fn connection_startup( peer_addr: SocketAddr, ) -> Option<(AsyncRW, ConnectionInfo)> { // handle PROXY protocol - let (conn, peer) = match read_proxy_protocol(conn).await { - Ok(c) => c, - Err(e) => { - tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); - return None; + let (conn, conn_info) = match config.proxy_protocol_v2 { + ProxyProtocolV2::Required => { + match read_proxy_protocol(conn).await { + Err(e) => { + warn!("per-client task finished with an error: {e:#}"); + return None; + } + // our load balancers will not send any more data. let's just exit immediately + Ok((_conn, ConnectHeader::Local)) => { + tracing::debug!("healthcheck received"); + return None; + } + Ok((conn, ConnectHeader::Proxy(info))) => (conn, info), + } } - }; - - let conn_info = match peer { - // our load balancers will not send any more data. let's just exit immediately - ConnectHeader::Local => { - tracing::debug!("healthcheck received"); - return None; - } - ConnectHeader::Missing if config.proxy_protocol_v2 == ProxyProtocolV2::Required => { - tracing::warn!("missing required proxy protocol header"); - return None; - } - ConnectHeader::Proxy(_) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => { - tracing::warn!("proxy protocol header not supported"); - return None; - } - ConnectHeader::Proxy(info) => info, - ConnectHeader::Missing => ConnectionInfo { - addr: peer_addr, - extra: None, - }, + // ignore the header - it cannot be confused for a postgres or http connection so will + // error later. 
+ ProxyProtocolV2::Rejected => ( + conn, + ConnectionInfo { + addr: peer_addr, + extra: None, + }, + ), }; let has_private_peer_addr = match conn_info.addr.ip() { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 1c5bb64480..eb80ac9ad0 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; -use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; use serde_json::value::RawValue; @@ -41,6 +40,7 @@ use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{ReadBodyError, read_body_with_limit}; use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::{NeonOptions, run_until_cancelled}; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; @@ -219,7 +219,7 @@ fn get_conn_info( let mut options = Option::None; - let mut params = StartupMessageParamsBuilder::default(); + let mut params = StartupMessageParams::default(); params.insert("user", &username); params.insert("database", &dbname); for (key, value) in pairs { diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 360550b0ac..c49a431c95 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -2,19 +2,17 @@ use std::pin::Pin; use std::sync::Arc; use std::{io, task}; -use bytes::BytesMut; -use pq_proto::framed::{ConnectionError, Framed}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; -use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_rustls::server::TlsStream; -use 
tracing::debug; -use crate::control_plane::messages::ColdStartInfo; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::pqproto::{ + BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf, + read_message, read_startup, +}; use crate::tls::TlsServerEndPoint; /// Stream wrapper which implements libpq's protocol. @@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint; /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying /// to pass random malformed bytes through the connection). pub struct PqStream { - pub(crate) framed: Framed, + stream: S, + read: Vec, + write: WriteBuf, } impl PqStream { - /// Construct a new libpq protocol wrapper. - pub fn new(stream: S) -> Self { + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Construct a new libpq protocol wrapper over a stream without the first startup message. + #[cfg(test)] + pub fn new_skip_handshake(stream: S) -> Self { Self { - framed: Framed::new(stream), + stream, + read: Vec::new(), + write: WriteBuf::new(), } } - - /// Extract the underlying stream and read buffer. - pub fn into_inner(self) -> (S, BytesMut) { - self.framed.into_inner() - } - - /// Get a shared reference to the underlying stream. - pub(crate) fn get_ref(&self) -> &S { - self.framed.get_ref() - } } -fn err_connection() -> io::Error { - io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") +impl PqStream { + /// Construct a new libpq protocol wrapper and read the first startup message. + /// + /// This is not cancel safe. + pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> { + let startup = read_startup(&mut stream).await?; + Ok(( + Self { + stream, + read: Vec::new(), + write: WriteBuf::new(), + }, + startup, + )) + } + + /// Tell the client that encryption is not supported. + /// + /// This is not cancel safe + pub async fn reject_encryption(&mut self) -> io::Result { + // N for No. 
+ self.write.encryption(b'N'); + self.flush().await?; + read_startup(&mut self.stream).await + } } impl PqStream { - /// Receive [`FeStartupPacket`], which is a first packet sent by a client. - pub async fn read_startup_packet(&mut self) -> io::Result { - self.framed - .read_startup_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - async fn read_message(&mut self) -> io::Result { - self.framed - .read_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - pub(crate) async fn read_password_message(&mut self) -> io::Result { - match self.read_message().await? { - FeMessage::PasswordMessage(msg) => Ok(msg), - bad => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("unexpected message type: {bad:?}"), - )), + /// Read a raw postgres packet, which will respect the max length requested. + /// This is not cancel safe. + async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> { + let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; + if actual_tag != tag { + return Err(io::Error::other(format!( + "incorrect message tag, expected {:?}, got {:?}", + tag as char, actual_tag as char, + ))); } + Ok(msg) + } + + /// Read a postgres password message, which will respect the max length requested. + /// This is not cancel safe. + pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> { + // passwords are usually pretty short + // and SASL SCRAM messages are no longer than 256 bytes in my testing + // (a few hashes and random bytes, encoded into base64). 
+ const MAX_PASSWORD_LENGTH: u32 = 512; + self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) + .await } } @@ -84,6 +101,16 @@ pub struct ReportedError { error_kind: ErrorKind, } +impl ReportedError { + pub fn new(e: (impl UserFacingError + Into)) -> Self { + let error_kind = e.get_error_kind(); + Self { + source: e.into(), + error_kind, + } + } +} + impl std::fmt::Display for ReportedError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.source.fmt(f) @@ -102,109 +129,65 @@ impl ReportableError for ReportedError { } } -#[derive(Serialize, Deserialize, Debug)] -enum ErrorTag { - #[serde(rename = "proxy")] - Proxy, - #[serde(rename = "compute")] - Compute, - #[serde(rename = "client")] - Client, - #[serde(rename = "controlplane")] - ControlPlane, - #[serde(rename = "other")] - Other, -} - -impl From for ErrorTag { - fn from(error_kind: ErrorKind) -> Self { - match error_kind { - ErrorKind::User => Self::Client, - ErrorKind::ClientDisconnect => Self::Client, - ErrorKind::RateLimit => Self::Proxy, - ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI - ErrorKind::Quota => Self::Proxy, - ErrorKind::Service => Self::Proxy, - ErrorKind::ControlPlane => Self::ControlPlane, - ErrorKind::Postgres => Self::Other, - ErrorKind::Compute => Self::Compute, - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "snake_case")] -struct ProbeErrorData { - tag: ErrorTag, - msg: String, - cold_start_info: Option, -} - impl PqStream { - /// Write the message into an internal buffer, but don't flush the underlying stream. - pub(crate) fn write_message_noflush( - &mut self, - message: &BeMessage<'_>, - ) -> io::Result<&mut Self> { - self.framed - .write_message(message) - .map_err(ProtocolError::into_io_error)?; - Ok(self) + /// Tell the client that we are willing to accept SSL. + /// This is not cancel safe + pub async fn accept_tls(mut self) -> io::Result { + // S for SSL. 
+ self.write.encryption(b'S'); + self.flush().await?; + Ok(self.stream) } - /// Write the message into an internal buffer and flush it. - pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush().await?; - Ok(self) + /// Assert that we are using direct TLS. + pub fn accept_direct_tls(self) -> S { + self.stream + } + + /// Write a raw message to the internal buffer. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + self.write.write_raw(size_hint, tag, f); + } + + /// Write the message into an internal buffer + pub fn write_message(&mut self, message: BeMessage<'_>) { + message.write_message(&mut self.write); } /// Flush the output buffer into the underlying stream. - pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> { - self.framed.flush().await?; - Ok(self) + /// + /// This is cancel safe. + pub async fn flush(&mut self) -> io::Result<()> { + self.stream.write_all_buf(&mut self.write).await?; + self.write.reset(); + + self.stream.flush().await?; + + Ok(()) } - /// Writes message with the given error kind to the stream. - /// Used only for probe queries - async fn write_format_message( - &mut self, - msg: &str, - error_kind: ErrorKind, - ctx: Option<&crate::context::RequestContext>, - ) -> String { - let formatted_msg = match ctx { - Some(ctx) if ctx.get_testodrome_id().is_some() => { - serde_json::to_string(&ProbeErrorData { - tag: ErrorTag::from(error_kind), - msg: msg.to_string(), - cold_start_info: Some(ctx.cold_start_info()), - }) - .unwrap_or_default() - } - _ => msg.to_string(), - }; - - // already error case, ignore client IO error - self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None)) - .await - .inspect_err(|e| debug!("write_message failed: {e}")) - .ok(); - - formatted_msg + /// Flush the output buffer into the underlying stream. + /// + /// This is cancel safe. 
+ pub async fn flush_and_into_inner(mut self) -> io::Result { + self.flush().await?; + Ok(self.stream) } - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Allowing string literals is safe under the assumption they might not contain any runtime info. - /// This method exists due to `&str` not implementing `Into`. + /// Write the error message to the client, then re-throw it. + /// + /// Trait [`UserFacingError`] acts as an allowlist for error types. /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub async fn throw_error_str( + pub(crate) async fn throw_error( &mut self, - msg: &'static str, - error_kind: ErrorKind, + error: E, ctx: Option<&crate::context::RequestContext>, - ) -> Result { - self.write_format_message(msg, error_kind, ctx).await; + ) -> ReportedError + where + E: UserFacingError + Into, + { + let error_kind = error.get_error_kind(); + let msg = error.to_string_client(); if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { tracing::info!( @@ -214,39 +197,39 @@ impl PqStream { ); } - Err(ReportedError { - source: anyhow::anyhow!(msg), - error_kind, - }) - } - - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Trait [`UserFacingError`] acts as an allowlist for error types. - /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. 
- pub(crate) async fn throw_error( - &mut self, - error: E, - ctx: Option<&crate::context::RequestContext>, - ) -> Result - where - E: UserFacingError + Into, - { - let error_kind = error.get_error_kind(); - let msg = error.to_string_client(); - self.write_format_message(&msg, error_kind, ctx).await; - if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { - tracing::info!( - kind=error_kind.to_metric_label(), - error=%error, - msg, - "forwarding error to user", - ); + let probe_msg; + let mut msg = &*msg; + if let Some(ctx) = ctx { + if ctx.get_testodrome_id().is_some() { + let tag = match error_kind { + ErrorKind::User => "client", + ErrorKind::ClientDisconnect => "client", + ErrorKind::RateLimit => "proxy", + ErrorKind::ServiceRateLimit => "proxy", + ErrorKind::Quota => "proxy", + ErrorKind::Service => "proxy", + ErrorKind::ControlPlane => "controlplane", + ErrorKind::Postgres => "other", + ErrorKind::Compute => "compute", + }; + probe_msg = typed_json::json!({ + "tag": tag, + "msg": msg, + "cold_start_info": ctx.cold_start_info(), + }) + .to_string(); + msg = &probe_msg; + } } - Err(ReportedError { - source: anyhow::anyhow!(error), - error_kind, - }) + // TODO: either preserve the error code from postgres, or assign error codes to proxy errors. 
+ self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR); + + self.flush() + .await + .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}")); + + ReportedError::new(error) } } diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index f09e916a1d..013b307f0b 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -31,7 +31,9 @@ mod private { type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + Pin::new(&mut self.inner) + .poll(cx) + .map_ok(|s| RustlsStream(Box::new(s))) } } @@ -57,7 +59,7 @@ mod private { } } - pub struct RustlsStream(TlsStream); + pub struct RustlsStream(Box>); impl postgres_client::tls::TlsStream for RustlsStream where diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index ab4885ce6b..db3f080261 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -357,31 +357,6 @@ class PgProtocol: return TimelineId(cast("str", self.safe_psql("show neon.timeline_id")[0][0])) -class PageserverWalReceiverProtocol(StrEnum): - VANILLA = "vanilla" - INTERPRETED = "interpreted" - - @staticmethod - def to_config_key_value(proto) -> tuple[str, dict[str, Any]]: - if proto == PageserverWalReceiverProtocol.VANILLA: - return ( - "wal_receiver_protocol", - { - "type": "vanilla", - }, - ) - elif proto == PageserverWalReceiverProtocol.INTERPRETED: - return ( - "wal_receiver_protocol", - { - "type": "interpreted", - "args": {"format": "protobuf", "compression": {"zstd": {"level": 1}}}, - }, - ) - else: - raise ValueError(f"Unknown protocol type: {proto}") - - @dataclass class PageserverTracingConfig: sampling_ratio: tuple[int, int] @@ -423,6 +398,7 @@ class PageserverImportConfig: "import_job_concurrency": self.import_job_concurrency, "import_job_soft_size_limit": self.import_job_soft_size_limit, 
"import_job_checkpoint_threshold": self.import_job_checkpoint_threshold, + "import_job_max_byte_range_size": 4 * 1024 * 1024, # Pageserver default } return ("timeline_import_config", value) @@ -474,7 +450,6 @@ class NeonEnvBuilder: safekeeper_extra_opts: list[str] | None = None, storage_controller_port_override: int | None = None, pageserver_virtual_file_io_mode: str | None = None, - pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None, pageserver_get_vectored_concurrent_io: str | None = None, pageserver_tracing_config: PageserverTracingConfig | None = None, pageserver_import_config: PageserverImportConfig | None = None, @@ -551,11 +526,6 @@ class NeonEnvBuilder: self.pageserver_virtual_file_io_mode = pageserver_virtual_file_io_mode - if pageserver_wal_receiver_protocol is not None: - self.pageserver_wal_receiver_protocol = pageserver_wal_receiver_protocol - else: - self.pageserver_wal_receiver_protocol = PageserverWalReceiverProtocol.INTERPRETED - assert test_name.startswith("test_"), ( "Unexpectedly instantiated from outside a test function" ) @@ -1201,7 +1171,6 @@ class NeonEnv: self.pageserver_virtual_file_io_engine = config.pageserver_virtual_file_io_engine self.pageserver_virtual_file_io_mode = config.pageserver_virtual_file_io_mode - self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol self.pageserver_get_vectored_concurrent_io = config.pageserver_get_vectored_concurrent_io self.pageserver_tracing_config = config.pageserver_tracing_config if config.pageserver_import_config is None: @@ -1333,13 +1302,6 @@ class NeonEnv: for key, value in override.items(): ps_cfg[key] = value - if self.pageserver_wal_receiver_protocol is not None: - key, value = PageserverWalReceiverProtocol.to_config_key_value( - self.pageserver_wal_receiver_protocol - ) - if key not in ps_cfg: - ps_cfg[key] = value - if self.pageserver_tracing_config is not None: key, value = self.pageserver_tracing_config.to_config_key_value() @@ -4710,7 
+4672,7 @@ class EndpointFactory: origin: Endpoint, endpoint_id: str | None = None, config_lines: list[str] | None = None, - ): + ) -> Endpoint: branch_name = origin.branch_name assert origin in self.endpoints assert branch_name is not None @@ -4729,7 +4691,7 @@ class EndpointFactory: origin: Endpoint, endpoint_id: str | None = None, config_lines: list[str] | None = None, - ): + ) -> Endpoint: branch_name = origin.branch_name assert origin in self.endpoints assert branch_name is not None diff --git a/test_runner/performance/test_sharded_ingest.py b/test_runner/performance/test_sharded_ingest.py index 293026d40a..364fcf3737 100644 --- a/test_runner/performance/test_sharded_ingest.py +++ b/test_runner/performance/test_sharded_ingest.py @@ -15,19 +15,10 @@ from fixtures.neon_fixtures import ( @pytest.mark.timeout(1200) @pytest.mark.parametrize("shard_count", [1, 8, 32]) -@pytest.mark.parametrize( - "wal_receiver_protocol", - [ - "vanilla", - "interpreted-bincode-compressed", - "interpreted-protobuf-compressed", - ], -) def test_sharded_ingest( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, shard_count: int, - wal_receiver_protocol: str, ): """ Benchmarks sharded ingestion throughput, by ingesting a large amount of WAL into a Safekeeper @@ -39,36 +30,6 @@ def test_sharded_ingest( neon_env_builder.num_pageservers = shard_count env = neon_env_builder.init_configs() - for ps in env.pageservers: - if wal_receiver_protocol == "vanilla": - ps.patch_config_toml_nonrecursive( - { - "wal_receiver_protocol": { - "type": "vanilla", - } - } - ) - elif wal_receiver_protocol == "interpreted-bincode-compressed": - ps.patch_config_toml_nonrecursive( - { - "wal_receiver_protocol": { - "type": "interpreted", - "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, - } - } - ) - elif wal_receiver_protocol == "interpreted-protobuf-compressed": - ps.patch_config_toml_nonrecursive( - { - "wal_receiver_protocol": { - "type": "interpreted", - "args": 
{"format": "protobuf", "compression": {"zstd": {"level": 1}}}, - } - } - ) - else: - raise AssertionError("Test must use explicit wal receiver protocol config") - env.start() # Create a sharded tenant and timeline, and migrate it to the respective pageservers. Ensure diff --git a/test_runner/regress/test_attach_tenant_config.py b/test_runner/regress/test_attach_tenant_config.py index 3eb6b7193c..dc44fc77db 100644 --- a/test_runner/regress/test_attach_tenant_config.py +++ b/test_runner/regress/test_attach_tenant_config.py @@ -182,10 +182,6 @@ def test_fully_custom_config(positive_env: NeonEnv): "lsn_lease_length": "1m", "lsn_lease_length_for_ts": "5s", "timeline_offloading": False, - "wal_receiver_protocol_override": { - "type": "interpreted", - "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, - }, "rel_size_v2_enabled": True, "relsize_snapshot_cache_capacity": 10000, "gc_compaction_enabled": True, diff --git a/test_runner/regress/test_basebackup.py b/test_runner/regress/test_basebackup.py index b083c394c7..2d42be4051 100644 --- a/test_runner/regress/test_basebackup.py +++ b/test_runner/regress/test_basebackup.py @@ -26,6 +26,10 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): ps = env.pageserver ps_http = ps.http_client() + storcon_managed_timelines = (env.storage_controller_config or {}).get( + "timelines_onto_safekeepers", False + ) + # 1. Check that we always hit the cache after compute restart. for i in range(3): ep.start() @@ -33,15 +37,26 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): def check_metrics(i=i): metrics = ps_http.get_metrics() - # Never miss. - # The first time compute_ctl sends `get_basebackup` with lsn=None, we do not cache such requests. 
- # All other requests should be a hit - assert ( - metrics.query_one( - "pageserver_basebackup_cache_read_total", {"result": "miss"} - ).value - == 0 - ) + if storcon_managed_timelines: + # We do not cache the initial basebackup yet, + # so the first compute startup should be a miss. + assert ( + metrics.query_one( + "pageserver_basebackup_cache_read_total", {"result": "miss"} + ).value + == 1 + ) + else: + # If the timeline is not initialized on safekeeprs, + # the compute_ctl sends `get_basebackup` with lsn=None for the first startup. + # We do not use cache for such requests, so it's niether a hit nor a miss. + assert ( + metrics.query_one( + "pageserver_basebackup_cache_read_total", {"result": "miss"} + ).value + == 0 + ) + # All but the first requests are hits. assert ( metrics.query_one("pageserver_basebackup_cache_read_total", {"result": "hit"}).value diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 370f57b19d..1570d40ae9 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -10,7 +10,6 @@ import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, - PageserverWalReceiverProtocol, generate_uploads_and_deletions, ) from fixtures.pageserver.http import PageserverApiException @@ -68,14 +67,9 @@ PREEMPT_GC_COMPACTION_TENANT_CONF = { @skip_in_debug_build("only run with release build") -@pytest.mark.parametrize( - "wal_receiver_protocol", - [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], -) @pytest.mark.timeout(900) def test_pageserver_compaction_smoke( neon_env_builder: NeonEnvBuilder, - wal_receiver_protocol: PageserverWalReceiverProtocol, ): """ This is a smoke test that compaction kicks in. The workload repeatedly churns @@ -85,8 +79,6 @@ def test_pageserver_compaction_smoke( observed bounds. 
""" - neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol - # Effectively disable the page cache to rely only on image layers # to shorten reads. neon_env_builder.pageserver_config_override = """ diff --git a/test_runner/regress/test_compute_metrics.py b/test_runner/regress/test_compute_metrics.py index 2cb2ee7b58..c751a3e7cc 100644 --- a/test_runner/regress/test_compute_metrics.py +++ b/test_runner/regress/test_compute_metrics.py @@ -466,8 +466,13 @@ def test_perf_counters(neon_simple_env: NeonEnv): # # 1.5 is the minimum version to contain these views. cur.execute("CREATE EXTENSION neon VERSION '1.5'") + cur.execute("set neon.monitor_query_exec_time = on") cur.execute("SELECT * FROM neon_perf_counters") cur.execute("SELECT * FROM neon_backend_perf_counters") + cur.execute( + "select value from neon_backend_perf_counters where metric='query_time_seconds_count' and pid=pg_backend_pid()" + ) + assert cur.fetchall()[0][0] == 2 def collect_metric( diff --git a/test_runner/regress/test_crafted_wal_end.py b/test_runner/regress/test_crafted_wal_end.py index 6b9dcbba07..89ff873ca3 100644 --- a/test_runner/regress/test_crafted_wal_end.py +++ b/test_runner/regress/test_crafted_wal_end.py @@ -1,9 +1,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from fixtures.log_helper import log from fixtures.neon_cli import WalCraft -from fixtures.neon_fixtures import NeonEnvBuilder, PageserverWalReceiverProtocol + +if TYPE_CHECKING: + from fixtures.neon_fixtures import NeonEnvBuilder # Restart nodes with WAL end having specially crafted shape, like last record # crossing segment boundary, to test decoding issues. 
@@ -19,17 +23,10 @@ from fixtures.neon_fixtures import NeonEnvBuilder, PageserverWalReceiverProtocol "wal_record_crossing_segment_followed_by_small_one", ], ) -@pytest.mark.parametrize( - "wal_receiver_protocol", - [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], -) def test_crafted_wal_end( neon_env_builder: NeonEnvBuilder, wal_type: str, - wal_receiver_protocol: PageserverWalReceiverProtocol, ): - neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol - env = neon_env_builder.init_start() env.create_branch("test_crafted_wal_end") env.pageserver.allowed_errors.extend( diff --git a/test_runner/regress/test_download_extensions.py b/test_runner/regress/test_download_extensions.py index 24ba0713d2..fe3b220c67 100644 --- a/test_runner/regress/test_download_extensions.py +++ b/test_runner/regress/test_download_extensions.py @@ -159,7 +159,8 @@ def test_remote_extensions( # Setup a mock nginx S3 gateway which will return our test extension. (host, port) = httpserver_listen_address - extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway" + remote_ext_base_url = f"http://{host}:{port}/pg-ext-s3-gateway" + log.info(f"remote extensions base URL: {remote_ext_base_url}") extension.build(pg_config, test_output_dir) tarball = extension.package(test_output_dir) @@ -221,7 +222,7 @@ def test_remote_extensions( endpoint.create_remote_extension_spec(spec) - endpoint.start(remote_ext_base_url=extensions_endpoint) + endpoint.start(remote_ext_base_url=remote_ext_base_url) with endpoint.connect() as conn: with conn.cursor() as cur: @@ -249,7 +250,7 @@ def test_remote_extensions( # Remove the extension files to force a redownload of the extension. extension.remove(test_output_dir, pg_version) - endpoint.start(remote_ext_base_url=extensions_endpoint) + endpoint.start(remote_ext_base_url=remote_ext_base_url) # Test that ALTER EXTENSION UPDATE statements also fetch remote extensions. 
with endpoint.connect() as conn: diff --git a/test_runner/regress/test_hot_standby.py b/test_runner/regress/test_hot_standby.py index 4044f25b37..1ff61ce8dc 100644 --- a/test_runner/regress/test_hot_standby.py +++ b/test_runner/regress/test_hot_standby.py @@ -74,8 +74,9 @@ def test_hot_standby(neon_simple_env: NeonEnv): for query in queries: with s_con.cursor() as secondary_cursor: secondary_cursor.execute(query) - response = secondary_cursor.fetchone() - assert response is not None + res = secondary_cursor.fetchone() + assert res is not None + response = res assert response == responses[query] # Check for corrupted WAL messages which might otherwise go unnoticed if @@ -164,7 +165,7 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): s_cur.execute("SELECT COUNT(*) FROM test") res = s_cur.fetchone() - assert res[0] == 10000 + assert res == (10000,) # Clear the cache in the standby, so that when we # re-execute the query, it will make GetPage @@ -195,7 +196,7 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): s_cur.execute("SELECT COUNT(*) FROM test") log_replica_lag(primary, secondary) res = s_cur.fetchone() - assert res[0] == 10000 + assert res == (10000,) def run_pgbench(connstr: str, pg_bin: PgBin): diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index e5908de363..8d18311f3d 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, ".*downloading failed, possibly for shutdown", # {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. 
Current state: Stopping\n' ".*page_service.*will not become active.*", + # the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state + ".*wal_connection_manager.*layer file download failed: No file found.*", + ".*wal_connection_manager.*could not ingest record.*", ] ) diff --git a/test_runner/regress/test_replica_promotes.py b/test_runner/regress/test_replica_promotes.py new file mode 100644 index 0000000000..e378d37635 --- /dev/null +++ b/test_runner/regress/test_replica_promotes.py @@ -0,0 +1,133 @@ +""" +File with secondary->primary promotion testing. + +So far, it only contains a test that we don't break and that the data is persisted. +""" + +import psycopg2 +from fixtures.log_helper import log +from fixtures.neon_fixtures import Endpoint, NeonEnv, wait_replica_caughtup +from fixtures.pg_version import PgVersion +from pytest import raises + + +def test_replica_promotes(neon_simple_env: NeonEnv, pg_version: PgVersion): + """ + Test that a replica safely promotes, and can commit data updates which + show up when the primary boots up after the promoted secondary endpoint + shut down. + """ + + # Initialize the primary, a test table, and a helper function to create lots + # of subtransactions.
+ env: NeonEnv = neon_simple_env + primary: Endpoint = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + secondary: Endpoint = env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") + + with primary.connect() as primary_conn: + primary_cur = primary_conn.cursor() + primary_cur.execute( + "create table t(pk bigint GENERATED ALWAYS AS IDENTITY, payload integer)" + ) + primary_cur.execute("INSERT INTO t(payload) SELECT generate_series(1, 100)") + primary_cur.execute( + """ + SELECT pg_current_wal_insert_lsn(), + pg_current_wal_lsn(), + pg_current_wal_flush_lsn() + """ + ) + log.info(f"Primary: Current LSN after workload is {primary_cur.fetchone()}") + primary_cur.execute("show neon.safekeepers") + safekeepers = primary_cur.fetchall()[0][0] + + wait_replica_caughtup(primary, secondary) + + with secondary.connect() as secondary_conn: + secondary_cur = secondary_conn.cursor() + secondary_cur.execute("select count(*) from t") + + assert secondary_cur.fetchone() == (100,) + + with raises(psycopg2.Error): + secondary_cur.execute("INSERT INTO t (payload) SELECT generate_series(101, 200)") + secondary_conn.commit() + + secondary_conn.rollback() + secondary_cur.execute("select count(*) from t") + assert secondary_cur.fetchone() == (100,) + + primary.stop_and_destroy(mode="immediate") + + # Reconnect to the secondary to make sure we get a read-write connection + promo_conn = secondary.connect() + promo_cur = promo_conn.cursor() + promo_cur.execute(f"alter system set neon.safekeepers='{safekeepers}'") + promo_cur.execute("select pg_reload_conf()") + + promo_cur.execute("SELECT * FROM pg_promote()") + assert promo_cur.fetchone() == (True,) + promo_cur.execute( + """ + SELECT pg_current_wal_insert_lsn(), + pg_current_wal_lsn(), + pg_current_wal_flush_lsn() + """ + ) + log.info(f"Secondary: LSN after promotion is {promo_cur.fetchone()}") + + # Reconnect to the secondary to make sure we get a read-write connection + with secondary.connect() as 
new_primary_conn: + new_primary_cur = new_primary_conn.cursor() + new_primary_cur.execute("select count(*) from t") + assert new_primary_cur.fetchone() == (100,) + + new_primary_cur.execute( + "INSERT INTO t (payload) SELECT generate_series(101, 200) RETURNING payload" + ) + assert new_primary_cur.fetchall() == [(it,) for it in range(101, 201)] + + new_primary_cur = new_primary_conn.cursor() + new_primary_cur.execute("select payload from t") + assert new_primary_cur.fetchall() == [(it,) for it in range(1, 201)] + + new_primary_cur.execute("select count(*) from t") + assert new_primary_cur.fetchone() == (200,) + new_primary_cur.execute( + """ + SELECT pg_current_wal_insert_lsn(), + pg_current_wal_lsn(), + pg_current_wal_flush_lsn() + """ + ) + log.info(f"Secondary: LSN after workload is {new_primary_cur.fetchone()}") + + with secondary.connect() as second_viewpoint_conn: + new_primary_cur = second_viewpoint_conn.cursor() + new_primary_cur.execute("select payload from t") + assert new_primary_cur.fetchall() == [(it,) for it in range(1, 201)] + + # wait_for_last_flush_lsn(env, secondary, env.initial_tenant, env.initial_timeline) + + secondary.stop_and_destroy() + + primary = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + + with primary.connect() as new_primary: + new_primary_cur = new_primary.cursor() + new_primary_cur.execute( + """ + SELECT pg_current_wal_insert_lsn(), + pg_current_wal_lsn(), + pg_current_wal_flush_lsn() + """ + ) + log.info(f"New primary: Boot LSN is {new_primary_cur.fetchone()}") + + new_primary_cur.execute("select count(*) from t") + assert new_primary_cur.fetchone() == (200,) + new_primary_cur.execute("INSERT INTO t (payload) SELECT generate_series(201, 300)") + new_primary_cur.execute("select count(*) from t") + assert new_primary_cur.fetchone() == (300,) + + primary.stop(mode="immediate") diff --git a/test_runner/regress/test_subxacts.py b/test_runner/regress/test_subxacts.py index b235da0bc7..92a21007c8 100644 --- 
a/test_runner/regress/test_subxacts.py +++ b/test_runner/regress/test_subxacts.py @@ -1,9 +1,7 @@ from __future__ import annotations -import pytest from fixtures.neon_fixtures import ( NeonEnvBuilder, - PageserverWalReceiverProtocol, check_restored_datadir_content, ) @@ -14,13 +12,7 @@ from fixtures.neon_fixtures import ( # maintained in the pageserver, so subtransactions are not very exciting for # Neon. They are included in the commit record though and updated in the # CLOG. -@pytest.mark.parametrize( - "wal_receiver_protocol", - [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], -) -def test_subxacts(neon_env_builder: NeonEnvBuilder, test_output_dir, wal_receiver_protocol): - neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol - +def test_subxacts(neon_env_builder: NeonEnvBuilder, test_output_dir): env = neon_env_builder.init_start() endpoint = env.endpoints.create_start("main") diff --git a/test_runner/regress/test_tenant_conf.py b/test_runner/regress/test_tenant_conf.py index de6bdc0aec..d78b9d8817 100644 --- a/test_runner/regress/test_tenant_conf.py +++ b/test_runner/regress/test_tenant_conf.py @@ -348,7 +348,6 @@ def test_tenant_config_patch(neon_env_builder: NeonEnvBuilder, ps_managed_by: st def assert_tenant_conf_semantically_equal(lhs, rhs): """ - Storcon returns None for fields that are not set while the pageserver does not. Compare two tenant's config overrides semantically, by dropping the None values. 
""" lhs = {k: v for k, v in lhs.items() if v is not None} @@ -375,10 +374,7 @@ def test_tenant_config_patch(neon_env_builder: NeonEnvBuilder, ps_managed_by: st patch: dict[str, Any | None] = { "gc_period": "3h", - "wal_receiver_protocol_override": { - "type": "interpreted", - "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, - }, + "gc_compaction_ratio_percent": 10, } api.patch_tenant_config(env.initial_tenant, patch) tenant_conf_after_patch = api.tenant_config(env.initial_tenant).tenant_specific_overrides @@ -391,7 +387,7 @@ def test_tenant_config_patch(neon_env_builder: NeonEnvBuilder, ps_managed_by: st assert_tenant_conf_semantically_equal(tenant_conf_after_patch, crnt_tenant_conf | patch) crnt_tenant_conf = tenant_conf_after_patch - patch = {"gc_period": "5h", "wal_receiver_protocol_override": None} + patch = {"gc_period": "5h", "gc_compaction_ratio_percent": None} api.patch_tenant_config(env.initial_tenant, patch) tenant_conf_after_patch = api.tenant_config(env.initial_tenant).tenant_specific_overrides if ps_managed_by == "storcon": diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index 4070f99568..d8a7dc2a2b 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -14,7 +14,6 @@ from fixtures.neon_fixtures import ( Endpoint, NeonEnv, NeonEnvBuilder, - PageserverWalReceiverProtocol, Safekeeper, ) from fixtures.remote_storage import RemoteStorageKind @@ -751,15 +750,8 @@ async def run_segment_init_failure(env: NeonEnv): # Test (injected) failure during WAL segment init. 
# https://github.com/neondatabase/neon/issues/6401 # https://github.com/neondatabase/neon/issues/6402 -@pytest.mark.parametrize( - "wal_receiver_protocol", - [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], -) -def test_segment_init_failure( - neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: PageserverWalReceiverProtocol -): +def test_segment_init_failure(neon_env_builder: NeonEnvBuilder): neon_env_builder.num_safekeepers = 1 - neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol env = neon_env_builder.init_start() asyncio.run(run_segment_init_failure(env)) diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 55c0d45abe..6770bc2513 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 55c0d45abe6467c02084c2192bca117eda6ce1e7 +Subproject commit 6770bc251301ef40c66f7ecb731741dc435b5051 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index de7640f55d..8c3249f36c 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit de7640f55da07512834d5cc40c4b3fb376b5f04f +Subproject commit 8c3249f36c7df6ac0efb8ee9f1baf4aa1b83e5c9 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 0bf96bd6d7..7a4c0eacae 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 0bf96bd6d70301a0b43b0b3457bb3cf8fb43c198 +Subproject commit 7a4c0eacaeb9b97416542fa19103061c166460b1 diff --git a/vendor/postgres-v17 b/vendor/postgres-v17 index 8be779fd3a..db424d42d7 160000 --- a/vendor/postgres-v17 +++ b/vendor/postgres-v17 @@ -1 +1 @@ -Subproject commit 8be779fd3ab9e87206da96a7e4842ef1abf04f44 +Subproject commit db424d42d748f8ad91ac00e28db2c7f2efa42f7f diff --git a/vendor/revisions.json b/vendor/revisions.json index 3e999760f4..12d5499ddb 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,18 +1,18 @@ { "v17": [ "17.5", - "8be779fd3ab9e87206da96a7e4842ef1abf04f44" + 
"db424d42d748f8ad91ac00e28db2c7f2efa42f7f" ], "v16": [ "16.9", - "0bf96bd6d70301a0b43b0b3457bb3cf8fb43c198" + "7a4c0eacaeb9b97416542fa19103061c166460b1" ], "v15": [ "15.13", - "de7640f55da07512834d5cc40c4b3fb376b5f04f" + "8c3249f36c7df6ac0efb8ee9f1baf4aa1b83e5c9" ], "v14": [ "14.18", - "55c0d45abe6467c02084c2192bca117eda6ce1e7" + "6770bc251301ef40c66f7ecb731741dc435b5051" ] }