diff --git a/Cargo.lock b/Cargo.lock index 58821b37e0..fe8732628f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,17 +253,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8ab6b55fe97976e46f91ddbed8d147d966475dc29b2032757ba47e02376fbc3" -[[package]] -name = "atomic_enum" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e1aca718ea7b89985790c94aad72d77533063fe00bc497bb79a7c2dae6a661" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -698,40 +687,13 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum" -version = "0.7.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" -dependencies = [ - "async-trait", - "axum-core 0.4.5", - "bytes", - "futures-util", - "http 1.1.0", - "http-body 1.0.0", - "http-body-util", - "itoa", - "matchit 0.7.3", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "sync_wrapper 1.0.1", - "tower 0.5.2", - "tower-layer", - "tower-service", -] - [[package]] name = "axum" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ - "axum-core 0.5.0", + "axum-core", "base64 0.22.1", "bytes", "form_urlencoded", @@ -742,7 +704,7 @@ dependencies = [ "hyper 1.6.0", "hyper-util", "itoa", - "matchit 0.8.4", + "matchit", "memchr", "mime", "percent-encoding", @@ -762,26 +724,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum-core" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" -dependencies = [ - "async-trait", - "bytes", - "futures-util", - "http 1.1.0", - "http-body 1.0.0", - "http-body-util", - "mime", - "pin-project-lite", - "rustversion", - "sync_wrapper 1.0.1", - "tower-layer", - "tower-service", -] - [[package]] name = "axum-core" version = "0.5.0" @@ -808,8 +750,8 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fc6f625a1f7705c6cf62d0d070794e94668988b1c38111baeec177c715f7b" dependencies = [ - "axum 0.8.1", - "axum-core 0.5.0", + "axum", + "axum-core", "bytes", "futures-util", "headers", @@ -1144,25 +1086,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "cbindgen" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eadd868a2ce9ca38de7eeafdcec9c7065ef89b42b32f0839278d55f35c54d1ff" -dependencies = [ - "clap", - "heck 0.4.1", - "indexmap 2.9.0", - "log", - "proc-macro2", - "quote", - "serde", - "serde_json", - "syn 2.0.100", - "tempfile", - "toml", -] - [[package]] name = "cc" version = "1.2.16" @@ -1289,7 +1212,7 @@ version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.100", @@ -1347,34 +1270,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "communicator" -version = "0.1.0" -dependencies = [ - "atomic_enum", - "axum 0.8.1", - "bytes", - "cbindgen", - "clashmap", - "http 1.1.0", - "libc", - "metrics", - "neon-shmem", - "nix 0.30.1", - "pageserver_client_grpc", - "pageserver_page_api", - "prometheus", - "prost 0.13.5", - "thiserror 1.0.69", - "tokio", - "tokio-pipe", - "tonic 0.12.3", - "tracing", - "tracing-subscriber", - "uring-common", - "utils", -] - [[package]] name = "compute_api" version = "0.1.0" @@ -1400,7 +1295,7 @@ dependencies = [ "aws-sdk-kms", "aws-sdk-s3", "aws-smithy-types", - "axum 0.8.1", + "axum", "axum-extra", "base64 0.13.1", "bytes", @@ -2041,7 +1936,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc" dependencies = [ "darling", "either", - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.100", @@ -2155,7 +2050,7 @@ name = "endpoint_storage" version = "0.0.1" dependencies = [ "anyhow", - "axum 0.8.1", + "axum", "axum-extra", "camino", "camino-tempfile", @@ -2829,12 +2724,6 @@ dependencies = [ "http 1.1.0", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -3726,12 +3615,6 @@ dependencies = [ "regex-automata 0.1.10", ] -[[package]] -name = "matchit" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" - [[package]] name = "matchit" version = "0.8.4" @@ -3777,7 +3660,7 @@ version = "0.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.100", @@ -3936,17 +3819,6 @@ dependencies = [ "workspace_hack", ] -[[package]] -name = "neonart" -version = "0.1.0" -dependencies = [ - "crossbeam-utils", - "rand 0.9.1", - "rand_distr 0.5.1", - "spin", - "tracing", -] - [[package]] name = "never-say-never" version = "6.6.666" @@ -4379,7 +4251,8 @@ name = "pagebench" version = "0.1.0" dependencies = [ "anyhow", - "axum 0.8.1", + "async-trait", + "axum", "camino", "clap", "futures", @@ -4399,6 +4272,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", + "tonic 0.13.1", "tracing", "utils", "workspace_hack", @@ -4454,6 +4328,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", + "http 1.1.0", "http-utils", "humantime", "humantime-serde", @@ -4480,6 +4355,7 @@ dependencies = [ "postgres_connection", "postgres_ffi", "postgres_initdb", + "posthog_client_lite", "pprof", "pq_proto", "procfs", @@ -4516,6 +4392,8 @@ dependencies = [ "tokio-util", "toml_edit", "tonic 0.13.1", + "tonic-reflection", + "tower 0.5.2", "tracing", "tracing-utils", "twox-hash", @@ -4632,7 +4510,10 @@ name = "pageserver_page_api" version = "0.1.0" dependencies = [ "bytes", + "pageserver_api", + "postgres_ffi", "prost 0.13.5", + "smallvec", "thiserror 1.0.69", "tonic 0.13.1", "tonic-build", @@ -5086,11 +4967,16 @@ name = "posthog_client_lite" version = "0.1.0" dependencies = [ "anyhow", + "arc-swap", "reqwest", "serde", "serde_json", "sha2", "thiserror 1.0.69", + "tokio", + "tokio-util", + "tracing", + "tracing-utils", "workspace_hack", ] @@ -5270,7 +5156,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck 0.5.0", + "heck", "itertools 0.12.1", "log", "multimap", @@ -5291,7 +5177,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", - "heck 0.5.0", + "heck", "itertools 0.12.1", "log", "multimap", @@ -5951,7 +5837,7 @@ dependencies = [ "async-trait", "getrandom 0.2.11", "http 1.1.0", - "matchit 0.8.4", + "matchit", "opentelemetry", "reqwest", "reqwest-middleware", @@ -7126,7 +7012,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", @@ -7551,16 +7437,6 @@ dependencies = [ "syn 2.0.100", ] -[[package]] -name = "tokio-pipe" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f213a84bffbd61b8fa0ba8a044b4bbe35d471d0b518867181e82bd5c15542784" -dependencies = [ - "libc", - "tokio", -] - [[package]] name = "tokio-postgres" version = "0.7.10" @@ -7755,25 +7631,16 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" dependencies = [ - "async-stream", "async-trait", - "axum 0.7.9", "base64 0.22.1", "bytes", - "h2 0.4.4", "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.6.0", - "hyper-timeout", - "hyper-util", "percent-encoding", "pin-project", "prost 0.13.5", - "socket2", - "tokio", "tokio-stream", - "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -7786,7 +7653,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" dependencies = [ "async-trait", - "axum 0.8.1", + "axum", "base64 0.22.1", "bytes", "flate2", @@ -7825,6 +7692,19 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tonic-reflection" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9687bd5bfeafebdded2356950f278bba8226f0b32109537c4253406e09aafe1" +dependencies = [ + "prost 0.13.5", + "prost-types 0.13.3", + "tokio", + "tokio-stream", + "tonic 0.13.1", +] + [[package]] name = "tower" version = "0.4.13" @@ -7833,13 +7713,9 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand 0.8.5", - "slab", "tokio", - "tokio-util", "tower-layer", "tower-service", "tracing", @@ -8326,7 +8202,7 @@ name = "vm_monitor" version = "0.1.0" dependencies = [ "anyhow", - "axum 0.8.1", + "axum", "cgroups-rs", "clap", "futures", @@ -8819,6 +8695,8 @@ dependencies = [ "ahash", "anstream", "anyhow", + "axum", + "axum-core", "base64 0.13.1", "base64 0.21.7", "base64ct", @@ -8841,10 +8719,8 @@ dependencies = [ "fail", "form_urlencoded", "futures-channel", - "futures-core", "futures-executor", "futures-io", - "futures-task", "futures-util", "generic-array", "getrandom 0.2.11", @@ -8874,7 +8750,6 @@ dependencies = [ "once_cell", "p256 0.13.2", "parquet", - "percent-encoding", "prettyplease", "proc-macro2", "prost 0.13.5", diff --git a/Cargo.toml b/Cargo.toml index 06e5bb0f7c..4863afe142 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -205,6 +205,7 @@ tokio-util = { version = "0.7.10", features = ["io", "rt"] } toml = "0.8" toml_edit = "0.22" tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "gzip", "prost", "router", "server", "tls-ring", "tls-native-roots"] } +tonic-reflection = { version = "0.13.1", features = ["server"] } tower = { version = "0.5.2", default-features = false } tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] } @@ -254,6 +255,7 @@ azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rus ## Local libraries compute_api = { version = "0.1", path = "./libs/compute_api/" } consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" } +desim = { version = "0.1", path = "./libs/desim" } endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" } http-utils = { version = "0.1", path = "./libs/http-utils/" } metrics = { version = "0.1", path = "./libs/metrics/" } @@ -269,19 +271,19 @@ postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" } postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" } postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } postgres_initdb = { path = "./libs/postgres_initdb" } +posthog_client_lite = { version = "0.1", path = "./libs/posthog_client_lite" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } remote_storage = { version = "0.1", path = "./libs/remote_storage/" } safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" } safekeeper_client = { path = "./safekeeper/client" } -desim = { version = "0.1", path = "./libs/desim" } storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy. storage_controller_client = { path = "./storage_controller/client" } tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" } tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" } utils = { version = "0.1", path = "./libs/utils/" } vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" } -walproposer = { version = "0.1", path = "./libs/walproposer/" } wal_decoder = { version = "0.1", path = "./libs/wal_decoder" } +walproposer = { version = "0.1", path = "./libs/walproposer/" } ## Common library dependency workspace_hack = { version = "0.1", path = "./workspace_hack/" } diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index 9d4c93e1cd..f97f04968e 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux . "$HOME/.cargo/env" && \ cargo --version && rustup --version && \ rustup component add llvm-tools rustfmt clippy && \ - cargo install rustfilt --version ${RUSTFILT_VERSION} && \ - cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \ - cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \ - cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \ - cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \ - cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \ - cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \ + cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \ + cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \ + cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \ + cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \ + cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \ + cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \ + cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \ --features postgres-bundled --no-default-features && \ rm -rf /home/nonroot/.cargo/registry && \ rm -rf /home/nonroot/.cargo/git diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index f4a5593b71..2afdde0cfa 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -1180,14 +1180,14 @@ RUN cd exts/rag && \ RUN cd exts/rag_bge_small_en_v15 && \ sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \ - REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/bge_small_en_v15.onnx \ + REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/bge_small_en_v15.onnx \ cargo pgrx install --release --features remote_onnx && \ echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_bge_small_en_v15.control RUN cd exts/rag_jina_reranker_v1_tiny_en && \ sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \ ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \ - REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/jina_reranker_v1_tiny_en.onnx \ + REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/jina_reranker_v1_tiny_en.onnx \ cargo pgrx install --release --features remote_onnx && \ echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_jina_reranker_v1_tiny_en.control @@ -1847,7 +1847,7 @@ COPY docker-compose/ext-src/ /ext-src/ COPY --from=pg-build /postgres /postgres #COPY --from=postgis-src /ext-src/ /ext-src/ COPY --from=plv8-src /ext-src/ /ext-src/ -#COPY --from=h3-pg-src /ext-src/ /ext-src/ +COPY --from=h3-pg-src /ext-src/h3-pg-src /ext-src/h3-pg-src COPY --from=postgresql-unit-src /ext-src/ /ext-src/ COPY --from=pgvector-src /ext-src/ /ext-src/ COPY --from=pgjwt-src /ext-src/ /ext-src/ diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 20b5e567a8..db6835da61 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -57,21 +57,6 @@ use tracing::{error, info}; use url::Url; use utils::failpoint_support; -// Compatibility hack: if the control plane specified any remote-ext-config -// use the default value for extension storage proxy gateway. -// Remove this once the control plane is updated to pass the gateway URL -fn parse_remote_ext_base_url(arg: &str) -> Result { - const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str = - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"; - - Ok(if arg.starts_with("http") { - arg - } else { - FALLBACK_PG_EXT_GATEWAY_BASE_URL - } - .to_owned()) -} - #[derive(Parser)] #[command(rename_all = "kebab-case")] struct Cli { @@ -79,9 +64,8 @@ struct Cli { pub pgbin: String, /// The base URL for the remote extension storage proxy gateway. - /// Should be in the form of `http(s)://[:]`. - #[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")] - pub remote_ext_base_url: Option, + #[arg(short = 'r', long)] + pub remote_ext_base_url: Option, /// The port to bind the external listening HTTP server to. Clients running /// outside the compute will talk to the compute through this port. Keep @@ -136,6 +120,10 @@ struct Cli { requires = "compute-id" )] pub control_plane_uri: Option, + + /// Interval in seconds for collecting installed extensions statistics + #[arg(long, default_value = "3600")] + pub installed_extensions_collection_interval: u64, } fn main() -> Result<()> { @@ -179,6 +167,7 @@ fn main() -> Result<()> { cgroup: cli.cgroup, #[cfg(target_os = "linux")] vm_monitor_addr: cli.vm_monitor_addr, + installed_extensions_collection_interval: cli.installed_extensions_collection_interval, }, config, )?; @@ -271,18 +260,4 @@ mod test { fn verify_cli() { Cli::command().debug_assert() } - - #[test] - fn parse_pg_ext_gateway_base_url() { - let arg = "http://pg-ext-s3-gateway2"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!(result, arg); - - let arg = "pg-ext-s3-gateway"; - let result = super::parse_remote_ext_base_url(arg).unwrap(); - assert_eq!( - result, - "http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local" - ); - } } diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 78acd78585..e65c210b23 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -339,6 +339,8 @@ async fn run_dump_restore( destination_connstring: String, ) -> Result<(), anyhow::Error> { let dumpdir = workdir.join("dumpdir"); + let num_jobs = num_cpus::get().to_string(); + info!("using {num_jobs} jobs for dump/restore"); let common_args = [ // schema mapping (prob suffices to specify them on one side) @@ -354,7 +356,7 @@ async fn run_dump_restore( "directory".to_string(), // concurrency "--jobs".to_string(), - num_cpus::get().to_string(), + num_jobs, // progress updates "--verbose".to_string(), ]; diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index f494e2444a..d678b7d670 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -31,6 +31,7 @@ use std::time::{Duration, Instant}; use std::{env, fs}; use tokio::spawn; use tracing::{Instrument, debug, error, info, instrument, warn}; +use url::Url; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; use utils::measured_stream::MeasuredReader; @@ -96,7 +97,10 @@ pub struct ComputeNodeParams { pub internal_http_port: u16, /// the address of extension storage proxy gateway - pub remote_ext_base_url: Option, + pub remote_ext_base_url: Option, + + /// Interval for installed extensions collection + pub installed_extensions_collection_interval: u64, } /// Compute node info shared across several `compute_ctl` threads. @@ -695,25 +699,18 @@ impl ComputeNode { let log_directory_path = Path::new(&self.params.pgdata).join("log"); let log_directory_path = log_directory_path.to_string_lossy().to_string(); - // Add project_id,endpoint_id tag to identify the logs. + // Add project_id,endpoint_id to identify the logs. // // These ids are passed from cplane, - // for backwards compatibility (old computes that don't have them), - // we set them to None. - // TODO: Clean up this code when all computes have them. - let tag: Option = match ( - pspec.spec.project_id.as_deref(), - pspec.spec.endpoint_id.as_deref(), - ) { - (Some(project_id), Some(endpoint_id)) => { - Some(format!("{project_id}/{endpoint_id}")) - } - (Some(project_id), None) => Some(format!("{project_id}/None")), - (None, Some(endpoint_id)) => Some(format!("None,{endpoint_id}")), - (None, None) => None, - }; + let endpoint_id = pspec.spec.endpoint_id.as_deref().unwrap_or(""); + let project_id = pspec.spec.project_id.as_deref().unwrap_or(""); - configure_audit_rsyslog(log_directory_path.clone(), tag, &remote_endpoint)?; + configure_audit_rsyslog( + log_directory_path.clone(), + endpoint_id, + project_id, + &remote_endpoint, + )?; // Launch a background task to clean up the audit logs launch_pgaudit_gc(log_directory_path); @@ -749,17 +746,7 @@ impl ComputeNode { let conf = self.get_tokio_conn_conf(None); tokio::task::spawn(async { - let res = get_installed_extensions(conf).await; - match res { - Ok(extensions) => { - info!( - "[NEON_EXT_STAT] {}", - serde_json::to_string(&extensions) - .expect("failed to serialize extensions list") - ); - } - Err(err) => error!("could not get installed extensions: {err:?}"), - } + let _ = installed_extensions(conf).await; }); } @@ -789,6 +776,9 @@ impl ComputeNode { // Log metrics so that we can search for slow operations in logs info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished"); + // Spawn the extension stats background task + self.spawn_extension_stats_task(); + if pspec.spec.prewarm_lfc_on_startup { self.prewarm_lfc(); } @@ -2199,6 +2189,41 @@ LIMIT 100", info!("Pageserver config changed"); } } + + pub fn spawn_extension_stats_task(&self) { + let conf = self.tokio_conn_conf.clone(); + let installed_extensions_collection_interval = + self.params.installed_extensions_collection_interval; + tokio::spawn(async move { + // An initial sleep is added to ensure that two collections don't happen at the same time. + // The first collection happens during compute startup. + tokio::time::sleep(tokio::time::Duration::from_secs( + installed_extensions_collection_interval, + )) + .await; + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs( + installed_extensions_collection_interval, + )); + loop { + interval.tick().await; + let _ = installed_extensions(conf.clone()).await; + } + }); + } +} + +pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> { + let res = get_installed_extensions(conf).await; + match res { + Ok(extensions) => { + info!( + "[NEON_EXT_STAT] {}", + serde_json::to_string(&extensions).expect("failed to serialize extensions list") + ); + } + Err(err) => error!("could not get installed extensions: {err:?}"), + } + Ok(()) } pub fn forward_termination_signal() { diff --git a/compute_tools/src/config_template/compute_audit_rsyslog_template.conf b/compute_tools/src/config_template/compute_audit_rsyslog_template.conf index 9ca7e36738..48b1a6f5c3 100644 --- a/compute_tools/src/config_template/compute_audit_rsyslog_template.conf +++ b/compute_tools/src/config_template/compute_audit_rsyslog_template.conf @@ -2,10 +2,24 @@ module(load="imfile") # Input configuration for log files in the specified directory -# Replace {log_directory} with the directory containing the log files -input(type="imfile" File="{log_directory}/*.log" Tag="{tag}" Severity="info" Facility="local0") +# The messages can be multiline. The start of the message is a timestamp +# in "%Y-%m-%d %H:%M:%S.%3N GMT" (so timezone hardcoded). +# Replace log_directory with the directory containing the log files +input(type="imfile" File="{log_directory}/*.log" + Tag="pgaudit_log" Severity="info" Facility="local5" + startmsg.regex="^[[:digit:]]{{4}}-[[:digit:]]{{2}}-[[:digit:]]{{2}} [[:digit:]]{{2}}:[[:digit:]]{{2}}:[[:digit:]]{{2}}.[[:digit:]]{{3}} GMT,") + # the directory to store rsyslog state files global(workDirectory="/var/log/rsyslog") -# Forward logs to remote syslog server -*.* @@{remote_endpoint} +# Construct json, endpoint_id and project_id as additional metadata +set $.json_log!endpoint_id = "{endpoint_id}"; +set $.json_log!project_id = "{project_id}"; +set $.json_log!msg = $msg; + +# Template suitable for rfc5424 syslog format +template(name="PgAuditLog" type="string" + string="<%PRI%>1 %TIMESTAMP:::date-rfc3339% %HOSTNAME% - - - - %$.json_log%") + +# Forward to remote syslog receiver (@@:;format +local5.info @@{remote_endpoint};PgAuditLog diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 3439383699..1857afa08c 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -83,6 +83,7 @@ use reqwest::StatusCode; use tar::Archive; use tracing::info; use tracing::log::warn; +use url::Url; use zstd::stream::read::Decoder; use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS}; @@ -158,14 +159,14 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion { pub async fn download_extension( ext_name: &str, ext_path: &RemotePath, - remote_ext_base_url: &str, + remote_ext_base_url: &Url, pgbin: &str, ) -> Result { info!("Download extension {:?} from {:?}", ext_name, ext_path); // TODO add retry logic let download_buffer = - match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await { + match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await { Ok(buffer) => buffer, Err(error_message) => { return Err(anyhow::anyhow!( diff --git a/compute_tools/src/rsyslog.rs b/compute_tools/src/rsyslog.rs index c873697623..3bc2e72b19 100644 --- a/compute_tools/src/rsyslog.rs +++ b/compute_tools/src/rsyslog.rs @@ -84,13 +84,15 @@ fn restart_rsyslog() -> Result<()> { pub fn configure_audit_rsyslog( log_directory: String, - tag: Option, + endpoint_id: &str, + project_id: &str, remote_endpoint: &str, ) -> Result<()> { let config_content: String = format!( include_str!("config_template/compute_audit_rsyslog_template.conf"), log_directory = log_directory, - tag = tag.unwrap_or("".to_string()), + endpoint_id = endpoint_id, + project_id = project_id, remote_endpoint = remote_endpoint ); diff --git a/control_plane/safekeepers.conf b/control_plane/safekeepers.conf index 576cc4a3a9..a73e274dfa 100644 --- a/control_plane/safekeepers.conf +++ b/control_plane/safekeepers.conf @@ -2,8 +2,10 @@ [pageserver] listen_pg_addr = '127.0.0.1:64000' listen_http_addr = '127.0.0.1:9898' +listen_grpc_addr = '127.0.0.1:51051' pg_auth_type = 'Trust' http_auth_type = 'Trust' +grpc_auth_type = 'Trust' [[safekeepers]] id = 1 diff --git a/control_plane/simple.conf b/control_plane/simple.conf index 0ad90a4618..1eb21f846e 100644 --- a/control_plane/simple.conf +++ b/control_plane/simple.conf @@ -4,8 +4,10 @@ id=1 listen_pg_addr = '127.0.0.1:64000' listen_http_addr = '127.0.0.1:9898' +listen_grpc_addr = '127.0.0.1:51051' pg_auth_type = 'Trust' http_auth_type = 'Trust' +grpc_auth_type = 'Trust' [[safekeepers]] id = 1 diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 98ab6e5657..ef6985d697 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -32,6 +32,7 @@ use control_plane::storage_controller::{ }; use nix::fcntl::{Flock, FlockArg}; use pageserver_api::config::{ + DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT, DEFAULT_HTTP_LISTEN_PORT as DEFAULT_PAGESERVER_HTTP_PORT, DEFAULT_PG_LISTEN_PORT as DEFAULT_PAGESERVER_PG_PORT, }; @@ -1007,13 +1008,16 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result { let pageserver_id = NodeId(DEFAULT_PAGESERVER_ID.0 + i as u64); let pg_port = DEFAULT_PAGESERVER_PG_PORT + i; let http_port = DEFAULT_PAGESERVER_HTTP_PORT + i; + let grpc_port = DEFAULT_PAGESERVER_GRPC_PORT + i; NeonLocalInitPageserverConf { id: pageserver_id, listen_pg_addr: format!("127.0.0.1:{pg_port}"), listen_http_addr: format!("127.0.0.1:{http_port}"), listen_https_addr: None, + listen_grpc_addr: Some(format!("127.0.0.1:{grpc_port}")), pg_auth_type: AuthType::Trust, http_auth_type: AuthType::Trust, + grpc_auth_type: AuthType::Trust, other: Default::default(), // Typical developer machines use disks with slow fsync, and we don't care // about data integrity: disable disk syncs. @@ -1275,6 +1279,7 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re mode: pageserver_api::models::TimelineCreateRequestMode::Branch { ancestor_timeline_id, ancestor_start_lsn: start_lsn, + read_only: false, pg_version: None, }, }; diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 4a8892c6de..47b77f0720 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -278,8 +278,10 @@ pub struct PageServerConf { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub pg_auth_type: AuthType, pub http_auth_type: AuthType, + pub grpc_auth_type: AuthType, pub no_sync: bool, } @@ -290,8 +292,10 @@ impl Default for PageServerConf { listen_pg_addr: String::new(), listen_http_addr: String::new(), listen_https_addr: None, + listen_grpc_addr: None, pg_auth_type: AuthType::Trust, http_auth_type: AuthType::Trust, + grpc_auth_type: AuthType::Trust, no_sync: false, } } @@ -306,8 +310,10 @@ pub struct NeonLocalInitPageserverConf { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub pg_auth_type: AuthType, pub http_auth_type: AuthType, + pub grpc_auth_type: AuthType, #[serde(default, skip_serializing_if = "std::ops::Not::not")] pub no_sync: bool, #[serde(flatten)] @@ -321,8 +327,10 @@ impl From<&NeonLocalInitPageserverConf> for PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, other: _, } = conf; @@ -331,7 +339,9 @@ impl From<&NeonLocalInitPageserverConf> for PageServerConf { listen_pg_addr: listen_pg_addr.clone(), listen_http_addr: listen_http_addr.clone(), listen_https_addr: listen_https_addr.clone(), + listen_grpc_addr: listen_grpc_addr.clone(), pg_auth_type: *pg_auth_type, + grpc_auth_type: *grpc_auth_type, http_auth_type: *http_auth_type, no_sync: *no_sync, } @@ -707,8 +717,10 @@ impl LocalEnv { listen_pg_addr: String, listen_http_addr: String, listen_https_addr: Option, + listen_grpc_addr: Option, pg_auth_type: AuthType, http_auth_type: AuthType, + grpc_auth_type: AuthType, #[serde(default)] no_sync: bool, } @@ -732,8 +744,10 @@ impl LocalEnv { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, } = config_toml; let IdentityTomlSubset { @@ -750,8 +764,10 @@ impl LocalEnv { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, }; pageservers.push(conf); diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index 756f2b02db..29314dab9e 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -129,7 +129,9 @@ impl PageServerNode { )); } - if conf.http_auth_type != AuthType::Trust || conf.pg_auth_type != AuthType::Trust { + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type] + .contains(&AuthType::NeonJWT) + { // Keys are generated in the toplevel repo dir, pageservers' workdirs // are one level below that, so refer to keys with ../ overrides.push("auth_validation_public_key_path='../auth_public_key.pem'".to_owned()); diff --git a/docker-compose/compute_wrapper/shell/compute.sh b/docker-compose/compute_wrapper/shell/compute.sh index 20a1ffb7a0..ab8d74d355 100755 --- a/docker-compose/compute_wrapper/shell/compute.sh +++ b/docker-compose/compute_wrapper/shell/compute.sh @@ -20,7 +20,7 @@ first_path="$(ldconfig --verbose 2>/dev/null \ | grep --invert-match ^$'\t' \ | cut --delimiter=: --fields=1 \ | head --lines=1)" -test "$first_path" == '/usr/local/lib' || true # Remove the || true in a follow-up PR. Needed for backwards compat. +test "$first_path" == '/usr/local/lib' echo "Waiting pageserver become ready." while ! nc -z pageserver 6400; do diff --git a/docker-compose/ext-src/h3-pg-src/neon-test.sh b/docker-compose/ext-src/h3-pg-src/neon-test.sh new file mode 100755 index 0000000000..e2ab22f03e --- /dev/null +++ b/docker-compose/ext-src/h3-pg-src/neon-test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -ex +cd "$(dirname "${0}")" +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +dropdb --if-exists contrib_regression +createdb contrib_regression +cd h3_postgis/test +psql -d contrib_regression -c "CREATE EXTENSION postgis" -c "CREATE EXTENSION postgis_raster" -c "CREATE EXTENSION h3" -c "CREATE EXTENSION h3_postgis" +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS} +cd ../../h3/test +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +dropdb --if-exists contrib_regression +createdb contrib_regression +psql -d contrib_regression -c "CREATE EXTENSION h3" +${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS} diff --git a/docker-compose/ext-src/h3-pg-src/test-upgrade.sh b/docker-compose/ext-src/h3-pg-src/test-upgrade.sh new file mode 100755 index 0000000000..72d7040966 --- /dev/null +++ b/docker-compose/ext-src/h3-pg-src/test-upgrade.sh @@ -0,0 +1,7 @@ +#!/bin/sh +set -ex +cd "$(dirname ${0})" +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +cd h3/test +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS} \ No newline at end of file diff --git a/docker-compose/ext-src/online_advisor-src/neon-test.sh b/docker-compose/ext-src/online_advisor-src/neon-test.sh new file mode 100755 index 0000000000..db5c2821fa --- /dev/null +++ b/docker-compose/ext-src/online_advisor-src/neon-test.sh @@ -0,0 +1,6 @@ +#!/bin/sh +set -ex +cd "$(dirname "${0}")" +if [ -f Makefile ]; then + make installcheck +fi diff --git a/docker-compose/ext-src/online_advisor-src/regular-test.sh b/docker-compose/ext-src/online_advisor-src/regular-test.sh new file mode 100755 index 0000000000..e94f03aa70 --- /dev/null +++ b/docker-compose/ext-src/online_advisor-src/regular-test.sh @@ -0,0 +1,9 @@ +#!/bin/sh +set -ex +cd "$(dirname ${0})" +[ -f Makefile ] || exit 0 +dropdb --if-exist contrib_regression +createdb contrib_regression +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS} diff --git a/docker-compose/test_extensions_upgrade.sh b/docker-compose/test_extensions_upgrade.sh index 51d1e40802..f1cf17f531 100755 --- a/docker-compose/test_extensions_upgrade.sh +++ b/docker-compose/test_extensions_upgrade.sh @@ -82,7 +82,8 @@ EXTENSIONS='[ {"extname": "pg_ivm", "extdir": "pg_ivm-src"}, {"extname": "pgjwt", "extdir": "pgjwt-src"}, {"extname": "pgtap", "extdir": "pgtap-src"}, -{"extname": "pg_repack", "extdir": "pg_repack-src"} +{"extname": "pg_repack", "extdir": "pg_repack-src"}, +{"extname": "h3", "extdir": "h3-pg-src"} ]' EXTNAMES=$(echo ${EXTENSIONS} | jq -r '.[].extname' | paste -sd ' ' -) COMPUTE_TAG=${NEW_COMPUTE_TAG} docker compose --profile test-extensions up --quiet-pull --build -d diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 93f6a2b7cc..1a7d7a7e44 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -107,7 +107,7 @@ impl MetricType for HyperLogLogState { } impl HyperLogLogState { - pub fn measure(&self, item: &impl Hash) { + pub fn measure(&self, item: &(impl Hash + ?Sized)) { // changing the hasher will break compatibility with previous measurements. self.record(BuildHasherDefault::::default().hash_one(item)); } diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 4df8d7bc51..5d028ee041 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -27,6 +27,7 @@ pub use prometheus::{ pub mod launch_timestamp; mod wrappers; +pub use prometheus; pub use wrappers::{CountedReader, CountedWriter}; mod hll; pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec}; diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 0fb2ff38ff..444983bd18 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -8,6 +8,8 @@ pub const DEFAULT_PG_LISTEN_PORT: u16 = 64000; pub const DEFAULT_PG_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_PG_LISTEN_PORT}"); pub const DEFAULT_HTTP_LISTEN_PORT: u16 = 9898; pub const DEFAULT_HTTP_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_HTTP_LISTEN_PORT}"); +// TODO: gRPC is disabled by default for now, but the port is used in neon_local. +pub const DEFAULT_GRPC_LISTEN_PORT: u16 = 51051; // storage-broker already uses 50051 use std::collections::HashMap; use std::num::{NonZeroU64, NonZeroUsize}; @@ -43,6 +45,21 @@ pub struct NodeMetadata { pub other: HashMap, } +/// PostHog integration config. +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct PostHogConfig { + /// PostHog project ID + pub project_id: String, + /// Server-side (private) API key + pub server_api_key: String, + /// Client-side (public) API key + pub client_api_key: String, + /// Private API URL + pub private_api_url: String, + /// Public API URL + pub public_api_url: String, +} + /// `pageserver.toml` /// /// We use serde derive with `#[serde(default)]` to generate a deserializer @@ -104,6 +121,7 @@ pub struct ConfigToml { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub ssl_key_file: Utf8PathBuf, pub ssl_cert_file: Utf8PathBuf, #[serde(with = "humantime_serde")] @@ -123,6 +141,7 @@ pub struct ConfigToml { pub http_auth_type: AuthType, #[serde_as(as = "serde_with::DisplayFromStr")] pub pg_auth_type: AuthType, + pub grpc_auth_type: AuthType, pub auth_validation_public_key_path: Option, pub remote_storage: Option, pub tenant_config: TenantConfigToml, @@ -182,6 +201,8 @@ pub struct ConfigToml { pub tracing: Option, pub enable_tls_page_service_api: bool, pub dev_mode: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub posthog_config: Option, pub timeline_import_config: TimelineImportConfig, #[serde(skip_serializing_if = "Option::is_none")] pub basebackup_cache_config: Option, @@ -588,6 +609,7 @@ impl Default for ConfigToml { listen_pg_addr: (DEFAULT_PG_LISTEN_ADDR.to_string()), listen_http_addr: (DEFAULT_HTTP_LISTEN_ADDR.to_string()), listen_https_addr: (None), + listen_grpc_addr: None, // TODO: default to 127.0.0.1:51051 ssl_key_file: Utf8PathBuf::from(DEFAULT_SSL_KEY_FILE), ssl_cert_file: Utf8PathBuf::from(DEFAULT_SSL_CERT_FILE), ssl_cert_reload_period: Duration::from_secs(60), @@ -604,6 +626,7 @@ impl Default for ConfigToml { pg_distrib_dir: None, // Utf8PathBuf::from("./pg_install"), // TODO: formely, this was std::env::current_dir() http_auth_type: (AuthType::Trust), pg_auth_type: (AuthType::Trust), + grpc_auth_type: (AuthType::Trust), auth_validation_public_key_path: (None), remote_storage: None, broker_endpoint: (storage_broker::DEFAULT_ENDPOINT @@ -690,11 +713,12 @@ impl Default for ConfigToml { enable_tls_page_service_api: false, dev_mode: false, timeline_import_config: TimelineImportConfig { - import_job_concurrency: NonZeroUsize::new(128).unwrap(), - import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(), - import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(), + import_job_concurrency: NonZeroUsize::new(32).unwrap(), + import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(), + import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(), }, basebackup_cache_config: None, + posthog_config: None, } } } diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 383939a13f..28ced4a368 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -354,6 +354,9 @@ pub struct ShardImportProgressV1 { pub completed: usize, /// Hash of the plan pub import_plan_hash: u64, + /// Soft limit for the job size + /// This needs to remain constant throughout the import + pub job_soft_size_limit: usize, } impl ShardImportStatus { @@ -402,6 +405,8 @@ pub enum TimelineCreateRequestMode { // using a flattened enum, so, it was an accepted field, and // we continue to accept it by having it here. pg_version: Option, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + read_only: bool, }, ImportPgdata { import_pgdata: TimelineCreateRequestModeImportPgdata, @@ -1929,7 +1934,7 @@ pub enum PagestreamFeMessage { } // Wrapped in libpq CopyData -#[derive(strum_macros::EnumProperty)] +#[derive(Debug, strum_macros::EnumProperty)] pub enum PagestreamBeMessage { Exists(PagestreamExistsResponse), Nblocks(PagestreamNblocksResponse), @@ -2040,7 +2045,7 @@ pub enum PagestreamProtocolVersion { pub type RequestId = u64; -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct PagestreamRequest { pub reqid: RequestId, pub request_lsn: Lsn, @@ -2059,7 +2064,7 @@ pub struct PagestreamNblocksRequest { pub rel: RelTag, } -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct PagestreamGetPageRequest { pub hdr: PagestreamRequest, pub rel: RelTag, diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index 473a44dbf9..e0dd4fdfe8 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize}; // FIXME: should move 'forknum' as last field to keep this consistent with Postgres. // Then we could replace the custom Ord and PartialOrd implementations below with // deriving them. This will require changes in walredoproc.c. -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)] +#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)] pub struct RelTag { pub forknum: u8, pub spcnode: Oid, @@ -184,12 +184,12 @@ pub enum SlruKind { MultiXactOffsets, } -impl SlruKind { - pub fn to_str(&self) -> &'static str { +impl fmt::Display for SlruKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Clog => "pg_xact", - Self::MultiXactMembers => "pg_multixact/members", - Self::MultiXactOffsets => "pg_multixact/offsets", + Self::Clog => write!(f, "pg_xact"), + Self::MultiXactMembers => write!(f, "pg_multixact/members"), + Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"), } } } diff --git a/libs/posthog_client_lite/Cargo.toml b/libs/posthog_client_lite/Cargo.toml index 7c19bf2ccb..05a3a9774e 100644 --- a/libs/posthog_client_lite/Cargo.toml +++ b/libs/posthog_client_lite/Cargo.toml @@ -6,9 +6,14 @@ license.workspace = true [dependencies] anyhow.workspace = true +arc-swap.workspace = true reqwest.workspace = true -serde.workspace = true serde_json.workspace = true +serde.workspace = true sha2.workspace = true -workspace_hack.workspace = true thiserror.workspace = true +tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] } +tokio-util.workspace = true +tracing-utils.workspace = true +tracing.workspace = true +workspace_hack.workspace = true diff --git a/libs/posthog_client_lite/src/background_loop.rs b/libs/posthog_client_lite/src/background_loop.rs new file mode 100644 index 0000000000..a05f6096b1 --- /dev/null +++ b/libs/posthog_client_lite/src/background_loop.rs @@ -0,0 +1,64 @@ +//! A background loop that fetches feature flags from PostHog and updates the feature store. + +use std::{sync::Arc, time::Duration}; + +use arc_swap::ArcSwap; +use tokio_util::sync::CancellationToken; +use tracing::{Instrument, info_span}; + +use crate::{FeatureStore, PostHogClient, PostHogClientConfig}; + +/// A background loop that fetches feature flags from PostHog and updates the feature store. +pub struct FeatureResolverBackgroundLoop { + posthog_client: PostHogClient, + feature_store: ArcSwap, + cancel: CancellationToken, +} + +impl FeatureResolverBackgroundLoop { + pub fn new(config: PostHogClientConfig, shutdown_pageserver: CancellationToken) -> Self { + Self { + posthog_client: PostHogClient::new(config), + feature_store: ArcSwap::new(Arc::new(FeatureStore::new())), + cancel: shutdown_pageserver, + } + } + + pub fn spawn(self: Arc, handle: &tokio::runtime::Handle, refresh_period: Duration) { + let this = self.clone(); + let cancel = self.cancel.clone(); + handle.spawn( + async move { + tracing::info!("Starting PostHog feature resolver"); + let mut ticker = tokio::time::interval(refresh_period); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + loop { + tokio::select! { + _ = ticker.tick() => {} + _ = cancel.cancelled() => break + } + let resp = match this + .posthog_client + .get_feature_flags_local_evaluation() + .await + { + Ok(resp) => resp, + Err(e) => { + tracing::warn!("Cannot get feature flags: {}", e); + continue; + } + }; + let feature_store = FeatureStore::new_with_flags(resp.flags); + this.feature_store.store(Arc::new(feature_store)); + tracing::info!("Feature flag updated"); + } + tracing::info!("PostHog feature resolver stopped"); + } + .instrument(info_span!("posthog_feature_resolver")), + ); + } + + pub fn feature_store(&self) -> Arc { + self.feature_store.load_full() + } +} diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index 53deb26ab7..ff12051196 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -1,5 +1,9 @@ //! A lite version of the PostHog client that only supports local evaluation of feature flags. +mod background_loop; + +pub use background_loop::FeatureResolverBackgroundLoop; + use std::collections::HashMap; use serde::{Deserialize, Serialize}; @@ -20,8 +24,7 @@ pub enum PostHogEvaluationError { #[derive(Deserialize)] pub struct LocalEvaluationResponse { - #[allow(dead_code)] - flags: Vec, + pub flags: Vec, } #[derive(Deserialize)] @@ -34,7 +37,7 @@ pub struct LocalEvaluationFlag { #[derive(Deserialize)] pub struct LocalEvaluationFlagFilters { groups: Vec, - multivariate: LocalEvaluationFlagMultivariate, + multivariate: Option, } #[derive(Deserialize)] @@ -94,6 +97,12 @@ impl FeatureStore { } } + pub fn new_with_flags(flags: Vec) -> Self { + let mut store = Self::new(); + store.set_flags(flags); + store + } + pub fn set_flags(&mut self, flags: Vec) { self.flags.clear(); for flag in flags { @@ -245,7 +254,7 @@ impl FeatureStore { } } - /// Evaluate a multivariate feature flag. Returns `None` if the flag is not available or if there are errors + /// Evaluate a multivariate feature flag. Returns an error if the flag is not available or if there are errors /// during the evaluation. /// /// The parsing logic is as follows: @@ -263,10 +272,15 @@ impl FeatureStore { /// Example: we have a multivariate flag with 3 groups of the configured global rollout percentage: A (10%), B (20%), C (70%). /// There is a single group with a condition that has a rollout percentage of 10% and it does not have a variant override. /// Then, we will have 1% of the users evaluated to A, 2% to B, and 7% to C. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. pub fn evaluate_multivariate( &self, flag_key: &str, user_id: &str, + properties: &HashMap, ) -> Result { let hash_on_global_rollout_percentage = Self::consistent_hash(user_id, flag_key, "multivariate"); @@ -276,10 +290,39 @@ impl FeatureStore { flag_key, hash_on_global_rollout_percentage, hash_on_group_rollout_percentage, - &HashMap::new(), + properties, ) } + /// Evaluate a boolean feature flag. Returns an error if the flag is not available or if there are errors + /// during the evaluation. + /// + /// The parsing logic is as follows: + /// + /// * Generate a consistent hash for the tenant-feature. + /// * Match each filter group. + /// - If a group is matched, it will first determine whether the user is in the range of the rollout + /// percentage. + /// - If the hash falls within the group's rollout percentage, return true. + /// * Otherwise, continue with the next group until all groups are evaluated and no group is within the + /// rollout percentage. + /// * If there are no matching groups, return an error. + /// + /// Returns `Ok(())` if the feature flag evaluates to true. In the future, it will return a payload. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. + pub fn evaluate_boolean( + &self, + flag_key: &str, + user_id: &str, + properties: &HashMap, + ) -> Result<(), PostHogEvaluationError> { + let hash_on_global_rollout_percentage = Self::consistent_hash(user_id, flag_key, "boolean"); + self.evaluate_boolean_inner(flag_key, hash_on_global_rollout_percentage, properties) + } + /// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID /// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests /// and avoid duplicate computations. @@ -306,6 +349,11 @@ impl FeatureStore { flag_key ))); } + let Some(ref multivariate) = flag_config.filters.multivariate else { + return Err(PostHogEvaluationError::Internal(format!( + "No multivariate available, should use evaluate_boolean?: {flag_key}" + ))); + }; // TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog // Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it // does not matter. @@ -314,7 +362,7 @@ impl FeatureStore { GroupEvaluationResult::MatchedAndOverride(variant) => return Ok(variant), GroupEvaluationResult::MatchedAndEvaluate => { let mut percentage = 0; - for variant in &flag_config.filters.multivariate.variants { + for variant in &multivariate.variants { percentage += variant.rollout_percentage; if self .evaluate_percentage(hash_on_global_rollout_percentage, percentage) @@ -342,6 +390,89 @@ impl FeatureStore { ))) } } + + /// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID + /// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests + /// and avoid duplicate computations. + /// + /// Use a different consistent hash for evaluating the group rollout percentage. + /// The behavior: if the condition is set to rolling out to 10% of the users, and + /// we set the variant A to 20% in the global config, then 2% of the total users will + /// be evaluated to variant A. + /// + /// Note that the hash to determine group rollout percentage is shared across all groups. So if we have two + /// exactly-the-same conditions with 10% and 20% rollout percentage respectively, a total of 20% of the users + /// will be evaluated (versus 30% if group evaluation is done independently). + pub(crate) fn evaluate_boolean_inner( + &self, + flag_key: &str, + hash_on_global_rollout_percentage: f64, + properties: &HashMap, + ) -> Result<(), PostHogEvaluationError> { + if let Some(flag_config) = self.flags.get(flag_key) { + if !flag_config.active { + return Err(PostHogEvaluationError::NotAvailable(format!( + "The feature flag is not active: {}", + flag_key + ))); + } + if flag_config.filters.multivariate.is_some() { + return Err(PostHogEvaluationError::Internal(format!( + "This looks like a multivariate flag, should use evaluate_multivariate?: {flag_key}" + ))); + }; + // TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog + // Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it + // does not matter. + for group in &flag_config.filters.groups { + match self.evaluate_group(group, hash_on_global_rollout_percentage, properties)? { + GroupEvaluationResult::MatchedAndOverride(_) => { + return Err(PostHogEvaluationError::Internal(format!( + "Boolean flag cannot have overrides: {}", + flag_key + ))); + } + GroupEvaluationResult::MatchedAndEvaluate => { + return Ok(()); + } + GroupEvaluationResult::Unmatched => continue, + } + } + // If no group is matched, the feature is not available, and up to the caller to decide what to do. + Err(PostHogEvaluationError::NoConditionGroupMatched) + } else { + // The feature flag is not available yet + Err(PostHogEvaluationError::NotAvailable(format!( + "Not found in the local evaluation spec: {}", + flag_key + ))) + } + } + + /// Infer whether a feature flag is a boolean flag by checking if it has a multivariate filter. + pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result { + if let Some(flag_config) = self.flags.get(flag_key) { + Ok(flag_config.filters.multivariate.is_none()) + } else { + Err(PostHogEvaluationError::NotAvailable(format!( + "Not found in the local evaluation spec: {}", + flag_key + ))) + } + } +} + +pub struct PostHogClientConfig { + /// The server API key. + pub server_api_key: String, + /// The client API key. + pub client_api_key: String, + /// The project ID. + pub project_id: String, + /// The private API URL. + pub private_api_url: String, + /// The public API URL. + pub public_api_url: String, } /// A lite PostHog client. @@ -360,37 +491,16 @@ impl FeatureStore { /// want to report the feature flag usage back to PostHog. The current plan is to use PostHog only as an UI to /// configure feature flags so it is very likely that the client API will not be used. pub struct PostHogClient { - /// The server API key. - server_api_key: String, - /// The client API key. - client_api_key: String, - /// The project ID. - project_id: String, - /// The private API URL. - private_api_url: String, - /// The public API URL. - public_api_url: String, + /// The config. + config: PostHogClientConfig, /// The HTTP client. client: reqwest::Client, } impl PostHogClient { - pub fn new( - server_api_key: String, - client_api_key: String, - project_id: String, - private_api_url: String, - public_api_url: String, - ) -> Self { + pub fn new(config: PostHogClientConfig) -> Self { let client = reqwest::Client::new(); - Self { - server_api_key, - client_api_key, - project_id, - private_api_url, - public_api_url, - client, - } + Self { config, client } } pub fn new_with_us_region( @@ -398,13 +508,13 @@ impl PostHogClient { client_api_key: String, project_id: String, ) -> Self { - Self::new( + Self::new(PostHogClientConfig { server_api_key, client_api_key, project_id, - "https://us.posthog.com".to_string(), - "https://us.i.posthog.com".to_string(), - ) + private_api_url: "https://us.posthog.com".to_string(), + public_api_url: "https://us.i.posthog.com".to_string(), + }) } /// Fetch the feature flag specs from the server. @@ -422,15 +532,23 @@ impl PostHogClient { // with bearer token of self.server_api_key let url = format!( "{}/api/projects/{}/feature_flags/local_evaluation", - self.private_api_url, self.project_id + self.config.private_api_url, self.config.project_id ); let response = self .client .get(url) - .bearer_auth(&self.server_api_key) + .bearer_auth(&self.config.server_api_key) .send() .await?; + let status = response.status(); let body = response.text().await?; + if !status.is_success() { + return Err(anyhow::anyhow!( + "Failed to get feature flags: {}, {}", + status, + body + )); + } Ok(serde_json::from_str(&body)?) } @@ -446,11 +564,11 @@ impl PostHogClient { ) -> anyhow::Result<()> { // PUBLIC_URL/capture/ // with bearer token of self.client_api_key - let url = format!("{}/capture/", self.public_api_url); + let url = format!("{}/capture/", self.config.public_api_url); self.client .post(url) .body(serde_json::to_string(&json!({ - "api_key": self.client_api_key, + "api_key": self.config.client_api_key, "distinct_id": distinct_id, "event": event, "properties": properties, @@ -467,95 +585,162 @@ mod tests { fn data() -> &'static str { r#"{ - "flags": [ - { - "id": 132794, - "team_id": 152860, - "name": "", - "key": "gc-compaction", - "filters": { - "groups": [ - { - "variant": "enabled-stage-2", - "properties": [ - { - "key": "plan_type", - "type": "person", - "value": [ - "free" - ], - "operator": "exact" - }, - { - "key": "pageserver_remote_size", - "type": "person", - "value": "10000000", - "operator": "lt" - } - ], - "rollout_percentage": 50 - }, - { - "properties": [ - { - "key": "plan_type", - "type": "person", - "value": [ - "free" - ], - "operator": "exact" - }, - { - "key": "pageserver_remote_size", - "type": "person", - "value": "10000000", - "operator": "lt" - } - ], - "rollout_percentage": 80 - } - ], - "payloads": {}, - "multivariate": { - "variants": [ - { - "key": "disabled", - "name": "", - "rollout_percentage": 90 - }, - { - "key": "enabled-stage-1", - "name": "", - "rollout_percentage": 10 - }, - { - "key": "enabled-stage-2", - "name": "", - "rollout_percentage": 0 - }, - { - "key": "enabled-stage-3", - "name": "", - "rollout_percentage": 0 - }, - { - "key": "enabled", - "name": "", - "rollout_percentage": 0 - } - ] - } - }, - "deleted": false, - "active": true, - "ensure_experience_continuity": false, - "has_encrypted_payloads": false, - "version": 6 - } + "flags": [ + { + "id": 141807, + "team_id": 152860, + "name": "", + "key": "image-compaction-boundary", + "filters": { + "groups": [ + { + "variant": null, + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + } ], - "group_type_mapping": {}, - "cohorts": {} - }"# + "rollout_percentage": 40 + }, + { + "variant": null, + "properties": [], + "rollout_percentage": 10 + } + ], + "payloads": {}, + "multivariate": null + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 1 + }, + { + "id": 135586, + "team_id": 152860, + "name": "", + "key": "boolean-flag", + "filters": { + "groups": [ + { + "variant": null, + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + } + ], + "rollout_percentage": 47 + } + ], + "payloads": {}, + "multivariate": null + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 1 + }, + { + "id": 132794, + "team_id": 152860, + "name": "", + "key": "gc-compaction", + "filters": { + "groups": [ + { + "variant": "enabled-stage-2", + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + }, + { + "key": "pageserver_remote_size", + "type": "person", + "value": "10000000", + "operator": "lt" + } + ], + "rollout_percentage": 50 + }, + { + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + }, + { + "key": "pageserver_remote_size", + "type": "person", + "value": "10000000", + "operator": "lt" + } + ], + "rollout_percentage": 80 + } + ], + "payloads": {}, + "multivariate": { + "variants": [ + { + "key": "disabled", + "name": "", + "rollout_percentage": 90 + }, + { + "key": "enabled-stage-1", + "name": "", + "rollout_percentage": 10 + }, + { + "key": "enabled-stage-2", + "name": "", + "rollout_percentage": 0 + }, + { + "key": "enabled-stage-3", + "name": "", + "rollout_percentage": 0 + }, + { + "key": "enabled", + "name": "", + "rollout_percentage": 0 + } + ] + } + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 7 + } + ], + "group_type_mapping": {}, + "cohorts": {} +}"# } #[test] @@ -631,4 +816,125 @@ mod tests { Err(PostHogEvaluationError::NoConditionGroupMatched) ),); } + + #[test] + fn evaluate_boolean_1() { + // The `boolean-flag` feature flag only has one group that matches on the free user. + + let mut store = FeatureStore::new(); + let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap(); + store.set_flags(response.flags); + + // This lacks the required properties and cannot be evaluated. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new()); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NotAvailable(_)) + ),); + + let properties_unmatched = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("paid".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // This does not match any group so there will be an error. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties_unmatched); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + + let properties = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("free".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // It matches the first group as 0.10 <= 0.50 and the properties are matched. Then it gets evaluated to the variant override. + let variant = store.evaluate_boolean_inner("boolean-flag", 0.10, &properties); + assert!(variant.is_ok()); + + // It matches the group conditions but not the group rollout percentage. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + } + + #[test] + fn evaluate_boolean_2() { + // The `image-compaction-boundary` feature flag has one group that matches on the free user and a group that matches on all users. + + let mut store = FeatureStore::new(); + let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap(); + store.set_flags(response.flags); + + // This lacks the required properties and cannot be evaluated. + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &HashMap::new()); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NotAvailable(_)) + ),); + + let properties_unmatched = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("paid".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // This does not match the filtered group but the all user group. + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties_unmatched); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 0.05, &properties_unmatched); + assert!(variant.is_ok()); + + let properties = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("free".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // It matches the first group as 0.30 <= 0.40 and the properties are matched. Then it gets evaluated to the variant override. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.30, &properties); + assert!(variant.is_ok()); + + // It matches the group conditions but not the group rollout percentage. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + + // It matches the second "all" group conditions. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.09, &properties); + assert!(variant.is_ok()); + } } diff --git a/libs/utils/src/leaky_bucket.rs b/libs/utils/src/leaky_bucket.rs index 2398f92766..17e96bd0a9 100644 --- a/libs/utils/src/leaky_bucket.rs +++ b/libs/utils/src/leaky_bucket.rs @@ -28,6 +28,7 @@ use std::time::Duration; use tokio::sync::Notify; use tokio::time::Instant; +#[derive(Clone, Copy)] pub struct LeakyBucketConfig { /// This is the "time cost" of a single request unit. /// Should loosely represent how long it takes to handle a request unit in active resource time. diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index 206b8bbd8f..11f787562c 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -73,6 +73,7 @@ pub mod error; /// async timeout helper pub mod timeout; +pub mod span; pub mod sync; pub mod failpoint_support; diff --git a/libs/utils/src/span.rs b/libs/utils/src/span.rs new file mode 100644 index 0000000000..4dbc99044b --- /dev/null +++ b/libs/utils/src/span.rs @@ -0,0 +1,19 @@ +//! Tracing span helpers. + +/// Records the given fields in the current span, as a single call. The fields must already have +/// been declared for the span (typically with empty values). +#[macro_export] +macro_rules! span_record { + ($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)}; +} + +/// Records the given fields in the given span, as a single call. The fields must already have been +/// declared for the span (typically with empty values). +#[macro_export] +macro_rules! span_record_in { + ($span:expr, $($tokens:tt)*) => { + if let Some(meta) = $span.metadata() { + $span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*)); + } + }; +} diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index 4e50c21fca..e95494297c 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -1,6 +1,7 @@ #![allow(clippy::todo)] use std::ffi::CString; +use std::str::FromStr; use postgres_ffi::WAL_SEGMENT_SIZE; use utils::id::TenantTimelineId; @@ -173,6 +174,8 @@ pub struct Config { pub ttid: TenantTimelineId, /// List of safekeepers in format `host:port` pub safekeepers_list: Vec, + /// libpq connection info options + pub safekeeper_conninfo_options: String, /// Safekeeper reconnect timeout in milliseconds pub safekeeper_reconnect_timeout: i32, /// Safekeeper connection timeout in milliseconds @@ -202,6 +205,9 @@ impl Wrapper { .into_bytes_with_nul(); assert!(safekeepers_list_vec.len() == safekeepers_list_vec.capacity()); let safekeepers_list = safekeepers_list_vec.as_mut_ptr() as *mut std::ffi::c_char; + let safekeeper_conninfo_options = CString::from_str(&config.safekeeper_conninfo_options) + .unwrap() + .into_raw(); let callback_data = Box::into_raw(Box::new(api)) as *mut ::std::os::raw::c_void; @@ -209,6 +215,7 @@ impl Wrapper { neon_tenant, neon_timeline, safekeepers_list, + safekeeper_conninfo_options, safekeeper_reconnect_timeout: config.safekeeper_reconnect_timeout, safekeeper_connection_timeout: config.safekeeper_connection_timeout, wal_segment_size: WAL_SEGMENT_SIZE as i32, // default 16MB @@ -576,6 +583,7 @@ mod tests { let config = crate::walproposer::Config { ttid, safekeepers_list: vec!["localhost:5000".to_string()], + safekeeper_conninfo_options: String::new(), safekeeper_reconnect_timeout: 1000, safekeeper_connection_timeout: 10000, sync_safekeepers: true, diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 5500d4ec8d..aac454acf2 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -17,53 +17,72 @@ anyhow.workspace = true arc-swap.workspace = true async-compression.workspace = true async-stream.workspace = true -bit_field.workspace = true bincode.workspace = true +bit_field.workspace = true byteorder.workspace = true bytes.workspace = true -camino.workspace = true camino-tempfile.workspace = true +camino.workspace = true chrono = { workspace = true, features = ["serde"] } clap = { workspace = true, features = ["string"] } consumption_metrics.workspace = true crc32c.workspace = true either.workspace = true +enum-map.workspace = true +enumset = { workspace = true, features = ["serde"]} fail.workspace = true futures.workspace = true hashlink.workspace = true hex.workspace = true -humantime.workspace = true +http.workspace = true +http-utils.workspace = true humantime-serde.workspace = true +humantime.workspace = true hyper0.workspace = true itertools.workspace = true jsonwebtoken.workspace = true md5.workspace = true +metrics.workspace = true nix.workspace = true -# hack to get the number of worker threads tokio uses -num_cpus.workspace = true +num_cpus.workspace = true # hack to get the number of worker threads tokio uses num-traits.workspace = true once_cell.workspace = true +pageserver_api.workspace = true +pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that +pageserver_compaction.workspace = true +pageserver_page_api.workspace = true peekable.workspace = true +pem.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true +postgres_connection.workspace = true +postgres_ffi.workspace = true +postgres_initdb.workspace = true postgres-protocol.workspace = true postgres-types.workspace = true -postgres_initdb.workspace = true +posthog_client_lite.workspace = true pprof.workspace = true +pq_proto.workspace = true prost.workspace = true rand.workspace = true range-set-blaze = { version = "0.1.16", features = ["alloc"] } regex.workspace = true +remote_storage.workspace = true +reqwest.workspace = true +rpds.workspace = true rustls.workspace = true scopeguard.workspace = true send-future.workspace = true -serde.workspace = true serde_json = { workspace = true, features = ["raw_value"] } serde_path_to_error.workspace = true serde_with.workspace = true +serde.workspace = true +smallvec.workspace = true +storage_broker.workspace = true +strum_macros.workspace = true +strum.workspace = true sysinfo.workspace = true -tokio-tar.workspace = true -tonic.workspace = true +tenant_size_model.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] } @@ -72,35 +91,19 @@ tokio-io-timeout.workspace = true tokio-postgres.workspace = true tokio-rustls.workspace = true tokio-stream.workspace = true +tokio-tar.workspace = true tokio-util.workspace = true toml_edit = { workspace = true, features = [ "serde" ] } +tonic.workspace = true +tonic-reflection.workspace = true +tower.workspace = true tracing.workspace = true tracing-utils.workspace = true url.workspace = true -walkdir.workspace = true -metrics.workspace = true -pageserver_api.workspace = true -pageserver_page_api.workspace = true -pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that -pageserver_compaction.workspace = true -pem.workspace = true -postgres_connection.workspace = true -postgres_ffi.workspace = true -pq_proto.workspace = true -remote_storage.workspace = true -storage_broker.workspace = true -tenant_size_model.workspace = true -http-utils.workspace = true utils.workspace = true -workspace_hack.workspace = true -reqwest.workspace = true -rpds.workspace = true -enum-map.workspace = true -enumset = { workspace = true, features = ["serde"]} -strum.workspace = true -strum_macros.workspace = true wal_decoder.workspace = true -smallvec.workspace = true +walkdir.workspace = true +workspace_hack.workspace = true twox-hash.workspace = true [target.'cfg(target_os = "linux")'.dependencies] diff --git a/pageserver/benches/bench_metrics.rs b/pageserver/benches/bench_metrics.rs index 38025124e1..e0428f6372 100644 --- a/pageserver/benches/bench_metrics.rs +++ b/pageserver/benches/bench_metrics.rs @@ -264,10 +264,56 @@ mod propagation_of_cached_label_value { } } +criterion_group!(histograms, histograms::bench_bucket_scalability); +mod histograms { + use std::time::Instant; + + use criterion::{BenchmarkId, Criterion}; + use metrics::core::Collector; + + pub fn bench_bucket_scalability(c: &mut Criterion) { + let mut g = c.benchmark_group("bucket_scalability"); + + for n in [1, 4, 8, 16, 32, 64, 128, 256] { + g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| { + b.iter_custom(|iters| { + let buckets: Vec = (0..*n).map(|i| i as f64 * 100.0).collect(); + let histo = metrics::Histogram::with_opts( + metrics::prometheus::HistogramOpts::new("name", "help") + .buckets(buckets.clone()), + ) + .unwrap(); + let start = Instant::now(); + for i in 0..usize::try_from(iters).unwrap() { + histo.observe(buckets[i % buckets.len()]); + } + let elapsed = start.elapsed(); + // self-test + let mfs = histo.collect(); + assert_eq!(mfs.len(), 1); + let metrics = mfs[0].get_metric(); + assert_eq!(metrics.len(), 1); + let histo = metrics[0].get_histogram(); + let buckets = histo.get_bucket(); + assert!( + buckets + .iter() + .enumerate() + .all(|(i, b)| b.get_cumulative_count() + >= i as u64 * (iters / buckets.len() as u64)) + ); + elapsed + }) + }); + } + } +} + criterion_main!( label_values, single_metric_multicore_scalability, - propagation_of_cached_label_value + propagation_of_cached_label_value, + histograms, ); /* @@ -290,6 +336,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 ns 214.44 ns propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns] +bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns] +bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns] +bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns] +bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns] +bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns] +bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns] +bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns] +bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns] Results on an i3en.3xlarge instance @@ -344,6 +398,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4 propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns] +bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns] +bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns] +bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns] +bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns] +bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns] +bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns] +bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns] +bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns] Results on a Hetzner AX102 AMD Ryzen 9 7950X3D 16-Core Processor @@ -362,5 +424,13 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1 propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns] propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns] +bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns] +bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns] +bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns] +bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns] +bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns] +bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns] +bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns] +bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns] */ diff --git a/pageserver/client_grpc/src/lib.rs b/pageserver/client_grpc/src/lib.rs index d005cddc3f..3d6de1f900 100644 --- a/pageserver/client_grpc/src/lib.rs +++ b/pageserver/client_grpc/src/lib.rs @@ -13,8 +13,8 @@ use futures::{Stream, StreamExt}; use thiserror::Error; use tonic::metadata::AsciiMetadataValue; -use pageserver_page_api::model::*; use pageserver_page_api::proto; +use pageserver_page_api::*; use pageserver_page_api::proto::PageServiceClient; use utils::shard::ShardIndex; @@ -146,7 +146,7 @@ impl PageserverClient { } pub async fn process_check_rel_exists_request( &self, - request: &CheckRelExistsRequest, + request: CheckRelExistsRequest, ) -> Result { // Current sharding model assumes that all metadata is present only at shard 0. let shard = ShardIndex::unsharded(); @@ -156,7 +156,7 @@ impl PageserverClient { let mut client = PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - let request = proto::CheckRelExistsRequest::from(request); + let request = proto::CheckRelExistsRequest::try_from(request)?; let response = client.check_rel_exists(tonic::Request::new(request)).await; match response { @@ -173,7 +173,7 @@ impl PageserverClient { pub async fn process_get_rel_size_request( &self, - request: &GetRelSizeRequest, + request: GetRelSizeRequest, ) -> Result { // Current sharding model assumes that all metadata is present only at shard 0. let shard = ShardIndex::unsharded(); @@ -183,7 +183,7 @@ impl PageserverClient { let mut client = PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - let request = proto::GetRelSizeRequest::from(request); + let request = proto::GetRelSizeRequest::try_from(request)?; let response = client.get_rel_size(tonic::Request::new(request)).await; match response { @@ -203,7 +203,7 @@ impl PageserverClient { // TODO: This opens a new gRPC stream for every request, which is extremely inefficient pub async fn get_page( &self, - request: &GetPageRequest, + request: GetPageRequest, ) -> Result, PageserverClientError> { // FIXME: calculate the shard number correctly let shard = ShardIndex::unsharded(); @@ -213,7 +213,7 @@ impl PageserverClient { let mut client = PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - let request = proto::GetPageRequest::from(request); + let request = proto::GetPageRequest::try_from(request)?; let request_stream = futures::stream::once(std::future::ready(request)); @@ -245,8 +245,8 @@ impl PageserverClient { } Ok(resp) => { pooled_client.finish(Ok(())).await; // Pass success to finish - let response: GetPageResponse = resp.try_into()?; - return Ok(response.page_image); + let response: GetPageResponse = resp.into(); + return Ok(response.page_images.to_vec()); } } } @@ -286,7 +286,7 @@ impl PageserverClient { /// Process a request to get the size of a database. pub async fn process_get_dbsize_request( &self, - request: &GetDbSizeRequest, + request: GetDbSizeRequest, ) -> Result { // Current sharding model assumes that all metadata is present only at shard 0. let shard = ShardIndex::unsharded(); @@ -296,7 +296,7 @@ impl PageserverClient { let mut client = PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard)); - let request = proto::GetDbSizeRequest::from(request); + let request = proto::GetDbSizeRequest::try_from(request)?; let response = client.get_db_size(tonic::Request::new(request)).await; match response { @@ -313,7 +313,7 @@ impl PageserverClient { /// Process a request to get the size of a database. pub async fn get_base_backup( &self, - request: &GetBaseBackupRequest, + request: GetBaseBackupRequest, gzip: bool, ) -> std::result::Result< tonic::Response>, @@ -331,7 +331,7 @@ impl PageserverClient { client = client.accept_compressed(tonic::codec::CompressionEncoding::Gzip); } - let request = proto::GetBaseBackupRequest::from(request); + let request = proto::GetBaseBackupRequest::try_from(request)?; let response = client.get_base_backup(tonic::Request::new(request)).await; match response { diff --git a/pageserver/page_api/Cargo.toml b/pageserver/page_api/Cargo.toml index 3a17981a78..4f62c77eb2 100644 --- a/pageserver/page_api/Cargo.toml +++ b/pageserver/page_api/Cargo.toml @@ -6,10 +6,12 @@ license.workspace = true [dependencies] bytes.workspace = true +pageserver_api.workspace = true +postgres_ffi.workspace = true prost.workspace = true +smallvec.workspace = true thiserror.workspace = true tonic.workspace = true - utils.workspace = true workspace_hack.workspace = true diff --git a/pageserver/page_api/proto/page_service.proto b/pageserver/page_api/proto/page_service.proto index f6acb3eeeb..44976084bf 100644 --- a/pageserver/page_api/proto/page_service.proto +++ b/pageserver/page_api/proto/page_service.proto @@ -54,9 +54,9 @@ service PageService { // RPCs use regular unary requests, since they are not as frequent and // performance-critical, and this simplifies implementation. // - // NB: a status response (e.g. errors) will terminate the stream. The stream - // may be shared by e.g. multiple Postgres backends, so we should avoid this. - // Most errors are therefore sent as GetPageResponse.status instead. + // NB: a gRPC status response (e.g. errors) will terminate the stream. The + // stream may be shared by multiple Postgres backends, so we avoid this by + // sending them as GetPageResponse.status_code instead. rpc GetPages (stream GetPageRequest) returns (stream GetPageResponse); // Returns the size of a relation, as # of blocks. @@ -159,8 +159,8 @@ message GetPageRequest { // A GetPageRequest class. Primarily intended for observability, but may also be // used for prioritization in the future. enum GetPageClass { - // Unknown class. For forwards compatibility: used when the client sends a - // class that the server doesn't know about. + // Unknown class. For backwards compatibility: used when an older client version sends a class + // that a newer server version has removed. GET_PAGE_CLASS_UNKNOWN = 0; // A normal request. This is the default. GET_PAGE_CLASS_NORMAL = 1; @@ -180,31 +180,37 @@ message GetPageResponse { // The original request's ID. uint64 request_id = 1; // The response status code. - GetPageStatus status = 2; + GetPageStatusCode status_code = 2; // A string describing the status, if any. string reason = 3; - // The 8KB page images, in the same order as the request. Empty if status != OK. + // The 8KB page images, in the same order as the request. Empty if status_code != OK. repeated bytes page_image = 4; } -// A GetPageResponse status code. Since we use a bidirectional stream, we don't -// want to send errors as gRPC statuses, since this would terminate the stream. -enum GetPageStatus { - // Unknown status. For forwards compatibility: used when the server sends a - // status code that the client doesn't know about. - GET_PAGE_STATUS_UNKNOWN = 0; +// A GetPageResponse status code. +// +// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream +// (potentially shared by many backends), and a gRPC status response would terminate the stream so +// we send GetPageResponse messages with these codes instead. +enum GetPageStatusCode { + // Unknown status. For forwards compatibility: used when an older client version receives a new + // status code from a newer server version. + GET_PAGE_STATUS_CODE_UNKNOWN = 0; // The request was successful. - GET_PAGE_STATUS_OK = 1; + GET_PAGE_STATUS_CODE_OK = 1; // The page did not exist. The tenant/timeline/shard has already been // validated during stream setup. - GET_PAGE_STATUS_NOT_FOUND = 2; + GET_PAGE_STATUS_CODE_NOT_FOUND = 2; // The request was invalid. - GET_PAGE_STATUS_INVALID = 3; + GET_PAGE_STATUS_CODE_INVALID_REQUEST = 3; + // The request failed due to an internal server error. + GET_PAGE_STATUS_CODE_INTERNAL_ERROR = 4; // The tenant is rate limited. Slow down and retry later. - GET_PAGE_STATUS_SLOW_DOWN = 4; - // TODO: consider adding a GET_PAGE_STATUS_LAYER_DOWNLOAD in the case of a - // layer download. This could free up the server task to process other - // requests while the layer download is in progress. + GET_PAGE_STATUS_CODE_SLOW_DOWN = 5; + // NB: shutdown errors are emitted as a gRPC Unavailable status. + // + // TODO: consider adding a GET_PAGE_STATUS_CODE_LAYER_DOWNLOAD in the case of a layer download. + // This could free up the server task to process other requests while the download is in progress. } // Fetches the size of a relation at a given LSN, as # of blocks. Only valid on diff --git a/pageserver/page_api/src/lib.rs b/pageserver/page_api/src/lib.rs index 4cbaf40763..f515f27f3e 100644 --- a/pageserver/page_api/src/lib.rs +++ b/pageserver/page_api/src/lib.rs @@ -5,8 +5,6 @@ //! //! This crate is used by both the client and the server. Try to keep it slim. -pub mod model; - // Code generated by protobuf. pub mod proto { tonic::include_proto!("page_api"); @@ -19,3 +17,7 @@ pub mod proto { pub use page_service_client::PageServiceClient; pub use page_service_server::{PageService, PageServiceServer}; } + +mod model; + +pub use model::*; diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index 6d24d6e2ba..7e7d0eb32d 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -1,364 +1,621 @@ //! Structs representing the canonical page service API. //! -//! These mirror the pageserver APIs and the structs automatically generated -//! from the protobuf specification. The differences are: +//! These mirror the autogenerated Protobuf types. The differences are: //! //! - Types that are in fact required by the API are not Options. The protobuf "required" //! attribute is deprecated and 'prost' marks a lot of members as optional because of that. -//! (See https://github.com/tokio-rs/prost/issues/800 for a gripe on this) +//! (See for a gripe on this) //! //! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits. //! -//! TODO: these types should be used in the Pageserver for actual processing, -//! instead of being cast into internal mirror types. +//! - Validate protocol invariants, via try_from() and try_into(). + +use std::fmt::Display; use bytes::Bytes; +use postgres_ffi::Oid; +use smallvec::SmallVec; +// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid +// pulling in all of their other crate dependencies when building the client. use utils::lsn::Lsn; use crate::proto; -#[derive(Clone, Debug)] +/// A protocol error. Typically returned via try_from() or try_into(). +#[derive(thiserror::Error, Debug)] +pub enum ProtocolError { + #[error("field '{0}' has invalid value '{1}'")] + Invalid(&'static str, String), + #[error("required field '{0}' is missing")] + Missing(&'static str), +} + +impl ProtocolError { + /// Helper to generate a new ProtocolError::Invalid for the given field and value. + pub fn invalid(field: &'static str, value: impl std::fmt::Debug) -> Self { + Self::Invalid(field, format!("{value:?}")) + } +} + +impl From for tonic::Status { + fn from(err: ProtocolError) -> Self { + tonic::Status::invalid_argument(format!("{err}")) + } +} + +/// The LSN a request should read at. +#[derive(Clone, Copy, Debug)] pub struct ReadLsn { + /// The request's read LSN. pub request_lsn: Lsn, - pub not_modified_since_lsn: Lsn, + /// If given, the caller guarantees that the page has not been modified since this LSN. Must be + /// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page + /// without waiting for the request LSN to arrive. If not given, the request will read at the + /// request_lsn and wait for it to arrive if necessary. Valid for all request types. + /// + /// It is undefined behaviour to make a request such that the page was, in fact, modified + /// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an + /// error, or it might return the old page version or the new page version. Setting + /// not_modified_since_lsn equal to request_lsn is always safe, but can lead to unnecessary + /// waiting. + pub not_modified_since_lsn: Option, } -#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)] -pub struct RelTag { - pub spc_oid: u32, - pub db_oid: u32, - pub rel_number: u32, - pub fork_number: u8, +impl Display for ReadLsn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let req_lsn = self.request_lsn; + if let Some(mod_lsn) = self.not_modified_since_lsn { + write!(f, "{req_lsn}>={mod_lsn}") + } else { + req_lsn.fmt(f) + } + } } -#[derive(Clone, Debug)] +impl ReadLsn { + /// Validates the ReadLsn. + pub fn validate(&self) -> Result<(), ProtocolError> { + if self.request_lsn == Lsn::INVALID { + return Err(ProtocolError::invalid("request_lsn", self.request_lsn)); + } + if self.not_modified_since_lsn > Some(self.request_lsn) { + return Err(ProtocolError::invalid( + "not_modified_since_lsn", + self.not_modified_since_lsn, + )); + } + Ok(()) + } +} + +impl TryFrom for ReadLsn { + type Error = ProtocolError; + + fn try_from(pb: proto::ReadLsn) -> Result { + let read_lsn = Self { + request_lsn: Lsn(pb.request_lsn), + not_modified_since_lsn: match pb.not_modified_since_lsn { + 0 => None, + lsn => Some(Lsn(lsn)), + }, + }; + read_lsn.validate()?; + Ok(read_lsn) + } +} + +impl TryFrom for proto::ReadLsn { + type Error = ProtocolError; + + fn try_from(read_lsn: ReadLsn) -> Result { + read_lsn.validate()?; + Ok(Self { + request_lsn: read_lsn.request_lsn.0, + not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0, + }) + } +} + +// RelTag is defined in pageserver_api::reltag. +pub type RelTag = pageserver_api::reltag::RelTag; + +impl TryFrom for RelTag { + type Error = ProtocolError; + + fn try_from(pb: proto::RelTag) -> Result { + Ok(Self { + spcnode: pb.spc_oid, + dbnode: pb.db_oid, + relnode: pb.rel_number, + forknum: pb + .fork_number + .try_into() + .map_err(|_| ProtocolError::invalid("fork_number", pb.fork_number))?, + }) + } +} + +impl From for proto::RelTag { + fn from(rel_tag: RelTag) -> Self { + Self { + spc_oid: rel_tag.spcnode, + db_oid: rel_tag.dbnode, + rel_number: rel_tag.relnode, + fork_number: rel_tag.forknum as u32, + } + } +} + +/// Checks whether a relation exists, at the given LSN. Only valid on shard 0, other shards error. +#[derive(Clone, Copy, Debug)] pub struct CheckRelExistsRequest { pub read_lsn: ReadLsn, pub rel: RelTag, } +impl TryFrom for CheckRelExistsRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::CheckRelExistsRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + }) + } +} + +impl TryFrom for proto::CheckRelExistsRequest { + type Error = ProtocolError; + + fn try_from(request: CheckRelExistsRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + rel: Some(request.rel.into()), + }) + } +} + +pub type CheckRelExistsResponse = bool; + +impl From for CheckRelExistsResponse { + fn from(pb: proto::CheckRelExistsResponse) -> Self { + pb.exists + } +} + +impl From for proto::CheckRelExistsResponse { + fn from(exists: CheckRelExistsResponse) -> Self { + Self { exists } + } +} + +/// Requests a base backup at a given LSN. +#[derive(Clone, Copy, Debug)] +pub struct GetBaseBackupRequest { + /// The LSN to fetch a base backup at. + pub read_lsn: ReadLsn, + /// If true, logical replication slots will not be created. + pub replica: bool, +} + +impl TryFrom for GetBaseBackupRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetBaseBackupRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + replica: pb.replica, + }) + } +} + +impl TryFrom for proto::GetBaseBackupRequest { + type Error = ProtocolError; + + fn try_from(request: GetBaseBackupRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + replica: request.replica, + }) + } +} + +pub type GetBaseBackupResponseChunk = Bytes; + +impl TryFrom for GetBaseBackupResponseChunk { + type Error = ProtocolError; + + fn try_from(pb: proto::GetBaseBackupResponseChunk) -> Result { + if pb.chunk.is_empty() { + return Err(ProtocolError::Missing("chunk")); + } + Ok(pb.chunk) + } +} + +impl TryFrom for proto::GetBaseBackupResponseChunk { + type Error = ProtocolError; + + fn try_from(chunk: GetBaseBackupResponseChunk) -> Result { + if chunk.is_empty() { + return Err(ProtocolError::Missing("chunk")); + } + Ok(Self { chunk }) + } +} + +/// Requests the size of a database, as # of bytes. Only valid on shard 0, other shards will error. +#[derive(Clone, Copy, Debug)] +pub struct GetDbSizeRequest { + pub read_lsn: ReadLsn, + pub db_oid: Oid, +} + +impl TryFrom for GetDbSizeRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetDbSizeRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + db_oid: pb.db_oid, + }) + } +} + +impl TryFrom for proto::GetDbSizeRequest { + type Error = ProtocolError; + + fn try_from(request: GetDbSizeRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + db_oid: request.db_oid, + }) + } +} + +pub type GetDbSizeResponse = u64; + +impl From for GetDbSizeResponse { + fn from(pb: proto::GetDbSizeResponse) -> Self { + pb.num_bytes + } +} + +impl From for proto::GetDbSizeResponse { + fn from(num_bytes: GetDbSizeResponse) -> Self { + Self { num_bytes } + } +} + +/// Requests one or more pages. #[derive(Clone, Debug)] +pub struct GetPageRequest { + /// A request ID. Will be included in the response. Should be unique for in-flight requests on + /// the stream. + pub request_id: RequestID, + /// The request class. + pub request_class: GetPageClass, + /// The LSN to read at. + pub read_lsn: ReadLsn, + /// The relation to read from. + pub rel: RelTag, + /// Page numbers to read. Must belong to the remote shard. + /// + /// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access + /// costs and parallelizing them. This may increase the latency of any individual request, but + /// improves the overall latency and throughput of the batch as a whole. + pub block_numbers: SmallVec<[u32; 1]>, +} + +impl TryFrom for GetPageRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetPageRequest) -> Result { + if pb.block_number.is_empty() { + return Err(ProtocolError::Missing("block_number")); + } + Ok(Self { + request_id: pb.request_id, + request_class: pb.request_class.into(), + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + block_numbers: pb.block_number.into(), + }) + } +} + +impl TryFrom for proto::GetPageRequest { + type Error = ProtocolError; + + fn try_from(request: GetPageRequest) -> Result { + if request.block_numbers.is_empty() { + return Err(ProtocolError::Missing("block_number")); + } + Ok(Self { + request_id: request.request_id, + request_class: request.request_class.into(), + read_lsn: Some(request.read_lsn.try_into()?), + rel: Some(request.rel.into()), + block_number: request.block_numbers.into_vec(), + }) + } +} + +/// A GetPage request ID. +pub type RequestID = u64; + +/// A GetPage request class. +#[derive(Clone, Copy, Debug)] +pub enum GetPageClass { + /// Unknown class. For backwards compatibility: used when an older client version sends a class + /// that a newer server version has removed. + Unknown, + /// A normal request. This is the default. + Normal, + /// A prefetch request. NB: can only be classified on pg < 18. + Prefetch, + /// A background request (e.g. vacuum). + Background, +} + +impl From for GetPageClass { + fn from(pb: proto::GetPageClass) -> Self { + match pb { + proto::GetPageClass::Unknown => Self::Unknown, + proto::GetPageClass::Normal => Self::Normal, + proto::GetPageClass::Prefetch => Self::Prefetch, + proto::GetPageClass::Background => Self::Background, + } + } +} + +impl From for GetPageClass { + fn from(class: i32) -> Self { + proto::GetPageClass::try_from(class) + .unwrap_or(proto::GetPageClass::Unknown) + .into() + } +} + +impl From for proto::GetPageClass { + fn from(class: GetPageClass) -> Self { + match class { + GetPageClass::Unknown => Self::Unknown, + GetPageClass::Normal => Self::Normal, + GetPageClass::Prefetch => Self::Prefetch, + GetPageClass::Background => Self::Background, + } + } +} + +impl From for i32 { + fn from(class: GetPageClass) -> Self { + proto::GetPageClass::from(class).into() + } +} + +/// A GetPage response. +/// +/// A batch response will contain all of the requested pages. We could eagerly emit individual pages +/// as soon as they are ready, but on a readv() Postgres holds buffer pool locks on all pages in the +/// batch and we'll only return once the entire batch is ready, so no one can make use of the +/// individual pages. +#[derive(Clone, Debug)] +pub struct GetPageResponse { + /// The original request's ID. + pub request_id: RequestID, + /// The response status code. + pub status_code: GetPageStatusCode, + /// A string describing the status, if any. + pub reason: Option, + /// The 8KB page images, in the same order as the request. Empty if status != OK. + pub page_images: SmallVec<[Bytes; 1]>, +} + +impl From for GetPageResponse { + fn from(pb: proto::GetPageResponse) -> Self { + Self { + request_id: pb.request_id, + status_code: pb.status_code.into(), + reason: Some(pb.reason).filter(|r| !r.is_empty()), + page_images: pb.page_image.into(), + } + } +} + +impl From for proto::GetPageResponse { + fn from(response: GetPageResponse) -> Self { + Self { + request_id: response.request_id, + status_code: response.status_code.into(), + reason: response.reason.unwrap_or_default(), + page_image: response.page_images.into_vec(), + } + } +} + +/// A GetPage response status code. +/// +/// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream +/// (potentially shared by many backends), and a gRPC status response would terminate the stream so +/// we send GetPageResponse messages with these codes instead. +#[derive(Clone, Copy, Debug)] +pub enum GetPageStatusCode { + /// Unknown status. For forwards compatibility: used when an older client version receives a new + /// status code from a newer server version. + Unknown, + /// The request was successful. + Ok, + /// The page did not exist. The tenant/timeline/shard has already been validated during stream + /// setup. + NotFound, + /// The request was invalid. + InvalidRequest, + /// The request failed due to an internal server error. + InternalError, + /// The tenant is rate limited. Slow down and retry later. + SlowDown, +} + +impl From for GetPageStatusCode { + fn from(pb: proto::GetPageStatusCode) -> Self { + match pb { + proto::GetPageStatusCode::Unknown => Self::Unknown, + proto::GetPageStatusCode::Ok => Self::Ok, + proto::GetPageStatusCode::NotFound => Self::NotFound, + proto::GetPageStatusCode::InvalidRequest => Self::InvalidRequest, + proto::GetPageStatusCode::InternalError => Self::InternalError, + proto::GetPageStatusCode::SlowDown => Self::SlowDown, + } + } +} + +impl From for GetPageStatusCode { + fn from(status_code: i32) -> Self { + proto::GetPageStatusCode::try_from(status_code) + .unwrap_or(proto::GetPageStatusCode::Unknown) + .into() + } +} + +impl From for proto::GetPageStatusCode { + fn from(status_code: GetPageStatusCode) -> Self { + match status_code { + GetPageStatusCode::Unknown => Self::Unknown, + GetPageStatusCode::Ok => Self::Ok, + GetPageStatusCode::NotFound => Self::NotFound, + GetPageStatusCode::InvalidRequest => Self::InvalidRequest, + GetPageStatusCode::InternalError => Self::InternalError, + GetPageStatusCode::SlowDown => Self::SlowDown, + } + } +} + +impl From for i32 { + fn from(status_code: GetPageStatusCode) -> Self { + proto::GetPageStatusCode::from(status_code).into() + } +} + +// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other +// shards will error. pub struct GetRelSizeRequest { pub read_lsn: ReadLsn, pub rel: RelTag, } -#[derive(Clone, Debug)] -pub struct GetRelSizeResponse { - pub num_blocks: u32, +impl TryFrom for GetRelSizeRequest { + type Error = ProtocolError; + + fn try_from(proto: proto::GetRelSizeRequest) -> Result { + Ok(Self { + read_lsn: proto + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + }) + } } -#[derive(Clone, Debug)] -pub struct GetPageRequest { - pub request_id: u64, - pub request_class: GetPageClass, - pub read_lsn: ReadLsn, - pub rel: RelTag, - pub block_number: Vec, +impl TryFrom for proto::GetRelSizeRequest { + type Error = ProtocolError; + + fn try_from(request: GetRelSizeRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + rel: Some(request.rel.into()), + }) + } } -#[derive(Clone, Debug, PartialEq)] -pub enum GetPageClass { - Normal, - Prefetch, - Background, +pub type GetRelSizeResponse = u32; + +impl From for GetRelSizeResponse { + fn from(proto: proto::GetRelSizeResponse) -> Self { + proto.num_blocks + } } -#[derive(Clone, Debug)] -pub struct GetPageResponse { - pub request_id: u64, - pub status: GetPageStatus, - pub reason: String, - pub page_image: Vec, +impl From for proto::GetRelSizeResponse { + fn from(num_blocks: GetRelSizeResponse) -> Self { + Self { num_blocks } + } } -#[derive(Clone, Debug, PartialEq)] -pub enum GetPageStatus { - Ok, - NotFound, - Invalid, - SlowDown, -} - -#[derive(Clone, Debug)] -pub struct GetDbSizeRequest { - pub read_lsn: ReadLsn, - pub db_oid: u32, -} - -#[derive(Clone, Debug)] -pub struct GetDbSizeResponse { - pub num_bytes: u64, -} - -#[derive(Clone, Debug)] -pub struct GetBaseBackupRequest { - pub read_lsn: ReadLsn, - pub replica: bool, -} - -#[derive(Clone, Debug)] +/// Requests an SLRU segment. Only valid on shard 0, other shards will error. pub struct GetSlruSegmentRequest { pub read_lsn: ReadLsn, - pub kind: u8, // TODO: SlruKind + pub kind: SlruKind, pub segno: u32, } -//--- Conversions to/from the generated proto types - -use thiserror::Error; - -#[derive(Error, Debug)] -pub enum ProtocolError { - #[error("the value for field `{0}` is invalid")] - InvalidValue(&'static str), - #[error("the required field `{0}` is missing ")] - Missing(&'static str), -} - -impl From for tonic::Status { - fn from(e: ProtocolError) -> Self { - match e { - ProtocolError::InvalidValue(_field) => tonic::Status::invalid_argument(e.to_string()), - ProtocolError::Missing(_field) => tonic::Status::invalid_argument(e.to_string()), - } - } -} - -impl From<&RelTag> for proto::RelTag { - fn from(value: &RelTag) -> proto::RelTag { - proto::RelTag { - spc_oid: value.spc_oid, - db_oid: value.db_oid, - rel_number: value.rel_number, - fork_number: value.fork_number as u32, - } - } -} -impl TryFrom<&proto::RelTag> for RelTag { +impl TryFrom for GetSlruSegmentRequest { type Error = ProtocolError; - fn try_from(value: &proto::RelTag) -> Result { - Ok(RelTag { - spc_oid: value.spc_oid, - db_oid: value.db_oid, - rel_number: value.rel_number, - fork_number: value - .fork_number - .try_into() - .or(Err(ProtocolError::InvalidValue("fork_number")))?, - }) - } -} - -impl From<&ReadLsn> for proto::ReadLsn { - fn from(value: &ReadLsn) -> proto::ReadLsn { - proto::ReadLsn { - request_lsn: value.request_lsn.into(), - not_modified_since_lsn: value.not_modified_since_lsn.into(), - } - } -} -impl From<&proto::ReadLsn> for ReadLsn { - fn from(value: &proto::ReadLsn) -> ReadLsn { - ReadLsn { - request_lsn: value.request_lsn.into(), - not_modified_since_lsn: value.not_modified_since_lsn.into(), - } - } -} - -impl From<&CheckRelExistsRequest> for proto::CheckRelExistsRequest { - fn from(value: &CheckRelExistsRequest) -> proto::CheckRelExistsRequest { - proto::CheckRelExistsRequest { - read_lsn: Some((&value.read_lsn).into()), - rel: Some((&value.rel).into()), - } - } -} -impl TryFrom<&proto::CheckRelExistsRequest> for CheckRelExistsRequest { - type Error = ProtocolError; - - fn try_from( - value: &proto::CheckRelExistsRequest, - ) -> Result { - Ok(CheckRelExistsRequest { - read_lsn: (&value.read_lsn.ok_or(ProtocolError::Missing("read_lsn"))?).into(), - rel: (&value.rel.ok_or(ProtocolError::Missing("rel"))?).try_into()?, - }) - } -} - -impl From<&GetRelSizeRequest> for proto::GetRelSizeRequest { - fn from(value: &GetRelSizeRequest) -> proto::GetRelSizeRequest { - proto::GetRelSizeRequest { - read_lsn: Some((&value.read_lsn).into()), - rel: Some((&value.rel).into()), - } - } -} -impl TryFrom<&proto::GetRelSizeRequest> for GetRelSizeRequest { - type Error = ProtocolError; - - fn try_from(value: &proto::GetRelSizeRequest) -> Result { - Ok(GetRelSizeRequest { - read_lsn: (&value.read_lsn.ok_or(ProtocolError::Missing("read_lsn"))?).into(), - rel: (&value.rel.ok_or(ProtocolError::Missing("rel"))?).try_into()?, - }) - } -} - -impl From<&GetPageRequest> for proto::GetPageRequest { - fn from(value: &GetPageRequest) -> proto::GetPageRequest { - proto::GetPageRequest { - request_id: value.request_id, - request_class: match value.request_class { - GetPageClass::Normal => proto::GetPageClass::Normal as i32, - GetPageClass::Prefetch => proto::GetPageClass::Prefetch as i32, - GetPageClass::Background => proto::GetPageClass::Background as i32, - }, - read_lsn: Some((&value.read_lsn).into()), - rel: Some((&value.rel).into()), - block_number: value.block_number.clone(), - } - } -} -impl TryFrom<&proto::GetPageRequest> for GetPageRequest { - type Error = ProtocolError; - - fn try_from(value: &proto::GetPageRequest) -> Result { - Ok(GetPageRequest { - request_id: value.request_id, - read_lsn: (&value.read_lsn.ok_or(ProtocolError::Missing("read_lsn"))?).into(), - rel: (&value.rel.ok_or(ProtocolError::Missing("rel"))?).try_into()?, - block_number: value.block_number.clone(), - request_class: proto::GetPageClass::try_from(value.request_class) - .unwrap_or(proto::GetPageClass::Unknown) + fn try_from(pb: proto::GetSlruSegmentRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? .try_into()?, + kind: u8::try_from(pb.kind) + .ok() + .and_then(SlruKind::from_repr) + .ok_or_else(|| ProtocolError::invalid("slru_kind", pb.kind))?, + segno: pb.segno, }) } } -impl TryFrom for GetPageClass { +impl TryFrom for proto::GetSlruSegmentRequest { type Error = ProtocolError; - fn try_from(value: proto::GetPageClass) -> Result { - match value { - proto::GetPageClass::Unknown => Err(ProtocolError::InvalidValue("class")), - proto::GetPageClass::Normal => Ok(GetPageClass::Normal), - proto::GetPageClass::Prefetch => Ok(GetPageClass::Prefetch), - proto::GetPageClass::Background => Ok(GetPageClass::Background), - } - } -} - -impl From for proto::GetPageClass { - fn from(value: GetPageClass) -> proto::GetPageClass { - match value { - GetPageClass::Normal => proto::GetPageClass::Normal, - GetPageClass::Prefetch => proto::GetPageClass::Prefetch, - GetPageClass::Background => proto::GetPageClass::Background, - } - } -} - -impl TryFrom for GetPageResponse { - type Error = ProtocolError; - - fn try_from(value: proto::GetPageResponse) -> Result { - Ok(GetPageResponse { - request_id: value.request_id, - status: proto::GetPageStatus::try_from(value.status) - .unwrap_or(proto::GetPageStatus::Unknown) - .try_into()?, - reason: value.reason, - page_image: value.page_image, + fn try_from(request: GetSlruSegmentRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + kind: request.kind as u32, + segno: request.segno, }) } } -impl TryFrom for GetPageStatus { +pub type GetSlruSegmentResponse = Bytes; + +impl TryFrom for GetSlruSegmentResponse { type Error = ProtocolError; - fn try_from(value: proto::GetPageStatus) -> Result { - match value { - // Error on unknknown status -- we don't want to make any assumptions here. - // - // NB: this means that new statuses can only be used after all computes - // have been updated to understand them. Do something else instead? - proto::GetPageStatus::Unknown => Err(ProtocolError::InvalidValue("status")), - proto::GetPageStatus::Ok => Ok(GetPageStatus::Ok), - proto::GetPageStatus::NotFound => Ok(GetPageStatus::NotFound), - proto::GetPageStatus::Invalid => Ok(GetPageStatus::Invalid), - proto::GetPageStatus::SlowDown => Ok(GetPageStatus::SlowDown), + fn try_from(pb: proto::GetSlruSegmentResponse) -> Result { + if pb.segment.is_empty() { + return Err(ProtocolError::Missing("segment")); } + Ok(pb.segment) } } -impl From for proto::GetPageStatus { - fn from(value: GetPageStatus) -> proto::GetPageStatus { - match value { - GetPageStatus::Ok => proto::GetPageStatus::Ok, - GetPageStatus::NotFound => proto::GetPageStatus::NotFound, - GetPageStatus::Invalid => proto::GetPageStatus::Invalid, - GetPageStatus::SlowDown => proto::GetPageStatus::SlowDown, - } - } -} - -impl From<&GetDbSizeRequest> for proto::GetDbSizeRequest { - fn from(value: &GetDbSizeRequest) -> proto::GetDbSizeRequest { - proto::GetDbSizeRequest { - read_lsn: Some((&value.read_lsn).into()), - db_oid: value.db_oid, - } - } -} - -impl TryFrom<&proto::GetDbSizeRequest> for GetDbSizeRequest { +impl TryFrom for proto::GetSlruSegmentResponse { type Error = ProtocolError; - fn try_from(value: &proto::GetDbSizeRequest) -> Result { - Ok(GetDbSizeRequest { - read_lsn: (&value.read_lsn.ok_or(ProtocolError::Missing("read_lsn"))?).into(), - db_oid: value.db_oid, - }) - } -} - -impl From<&GetBaseBackupRequest> for proto::GetBaseBackupRequest { - fn from(value: &GetBaseBackupRequest) -> proto::GetBaseBackupRequest { - proto::GetBaseBackupRequest { - read_lsn: Some((&value.read_lsn).into()), - replica: value.replica, + fn try_from(segment: GetSlruSegmentResponse) -> Result { + // TODO: can a segment legitimately be empty? + if segment.is_empty() { + return Err(ProtocolError::Missing("segment")); } + Ok(Self { segment }) } } -impl TryFrom<&proto::GetBaseBackupRequest> for GetBaseBackupRequest { - type Error = ProtocolError; - - fn try_from( - value: &proto::GetBaseBackupRequest, - ) -> Result { - Ok(GetBaseBackupRequest { - read_lsn: (&value.read_lsn.ok_or(ProtocolError::Missing("read_lsn"))?).into(), - replica: value.replica, - }) - } -} - -impl TryFrom<&proto::GetSlruSegmentRequest> for GetSlruSegmentRequest { - type Error = ProtocolError; - - fn try_from(value: &proto::GetSlruSegmentRequest) -> Result { - Ok(GetSlruSegmentRequest { - read_lsn: (&value.read_lsn.ok_or(ProtocolError::Missing("read_lsn"))?).into(), - kind: value - .kind - .try_into() - .or(Err(ProtocolError::InvalidValue("kind")))?, - segno: value.segno, - }) - } -} +// SlruKind is defined in pageserver_api::reltag. +pub type SlruKind = pageserver_api::reltag::SlruKind; diff --git a/pageserver/pagebench/Cargo.toml b/pageserver/pagebench/Cargo.toml index 4469e1b755..4095c4818c 100644 --- a/pageserver/pagebench/Cargo.toml +++ b/pageserver/pagebench/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true [dependencies] anyhow.workspace = true +async-trait.workspace = true camino.workspace = true clap.workspace = true futures.workspace = true @@ -24,12 +25,12 @@ tokio-stream.workspace = true tokio-util.workspace = true axum.workspace = true http.workspace = true - metrics.workspace = true +tonic.workspace = true pageserver_client.workspace = true pageserver_client_grpc.workspace = true -pageserver_page_api.workspace = true pageserver_api.workspace = true +pageserver_page_api.workspace = true utils = { path = "../../libs/utils/" } workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/pageserver/pagebench/src/cmd/basebackup.rs b/pageserver/pagebench/src/cmd/basebackup.rs index af337d0c03..676f157e69 100644 --- a/pageserver/pagebench/src/cmd/basebackup.rs +++ b/pageserver/pagebench/src/cmd/basebackup.rs @@ -9,7 +9,7 @@ use anyhow::Context; use pageserver_api::shard::TenantShardId; use pageserver_client::mgmt_api::ForceAwaitLogicalSize; use pageserver_client::page_service::BasebackupRequest; -use pageserver_page_api::model::{GetBaseBackupRequest, ReadLsn}; +use pageserver_page_api::{GetBaseBackupRequest, ReadLsn}; use rand::prelude::*; use tokio::sync::Barrier; @@ -319,10 +319,10 @@ async fn client_grpc( info!("starting get_base_backup"); let mut basebackup_stream = client .get_base_backup( - &GetBaseBackupRequest { + GetBaseBackupRequest { read_lsn: ReadLsn { request_lsn: lsn, - not_modified_since_lsn: lsn, + not_modified_since_lsn: Some(lsn), }, replica: false, }, diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index 44874f2cf2..93ff6c8b64 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -7,14 +7,17 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use anyhow::Context; +use async_trait::async_trait; use camino::Utf8PathBuf; use futures::StreamExt; use futures::stream::FuturesOrdered; use pageserver_api::key::Key; use pageserver_api::keyspace::KeySpaceAccum; -use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest}; +use pageserver_api::models::{ + PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest, +}; use pageserver_api::shard::TenantShardId; -use pageserver_page_api::model::{GetPageClass, GetPageResponse, GetPageStatus}; +use pageserver_page_api::proto; use rand::prelude::*; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -38,6 +41,12 @@ use metrics::{Encoder, TextEncoder}; use crate::util::tokio_thread_local_stats::AllThreadLocalStats; use crate::util::{request_stats, tokio_thread_local_stats}; +#[derive(clap::ValueEnum, Clone, Debug)] +enum Protocol { + Libpq, + Grpc, +} + /// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace. #[derive(clap::Parser)] pub(crate) struct Args { @@ -55,6 +64,8 @@ pub(crate) struct Args { num_clients: NonZeroUsize, #[clap(long)] runtime: Option, + #[clap(long, value_enum, default_value = "libpq")] + protocol: Protocol, /// Each client sends requests at the given rate. /// /// If a request takes too long and we should be issuing a new request already, @@ -399,16 +410,20 @@ async fn main_impl( let new_value = new_metrics.clone(); Box::pin(async move { - if args.grpc_stream { - client_grpc_stream(args, worker_id, ss, cancel, rps_period, ranges, weights).await - } else if args.grpc { - client_grpc( - args, worker_id, new_value, ss, cancel, rps_period, ranges, weights, - ) - .await - } else { - client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await - } + let client: Box = match args.protocol { + Protocol::Libpq => Box::new( + LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline) + .await + .unwrap(), + ), + + Protocol::Grpc => Box::new( + GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline) + .await + .unwrap(), + ), + }; + run_worker(args, client, ss, cancel, rps_period, ranges, weights).await }) }; @@ -460,23 +475,15 @@ async fn main_impl( anyhow::Ok(()) } -async fn client_libpq( +async fn run_worker( args: &Args, - worker_id: WorkerId, + mut client: Box, shared_state: Arc, cancel: CancellationToken, rps_period: Option, ranges: Vec, weights: rand::distributions::weighted::WeightedIndex, ) { - let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone()) - .await - .unwrap(); - let mut client = client - .pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id) - .await - .unwrap(); - shared_state.start_work_barrier.wait().await; let client_start = Instant::now(); let mut ticks_processed = 0; @@ -520,12 +527,12 @@ async fn client_libpq( blkno: block_no, } }; - client.getpage_send(req).await.unwrap(); + client.send_get_page(req).await.unwrap(); inflight.push_back(start); } let start = inflight.pop_front().unwrap(); - client.getpage_recv().await.unwrap(); + client.recv_get_page().await.unwrap(); let end = Instant::now(); shared_state.live_stats.request_done(); ticks_processed += 1; @@ -548,228 +555,103 @@ async fn client_libpq( } } -#[allow(clippy::too_many_arguments)] -async fn client_grpc( - args: &Args, - worker_id: WorkerId, - client_metrics: Arc, - shared_state: Arc, - cancel: CancellationToken, - rps_period: Option, - ranges: Vec, - weights: rand::distributions::weighted::WeightedIndex, -) { - let shard_map = HashMap::from([( - ShardIndex::unsharded(), - args.page_service_connstring.clone(), - )]); - let options = pageserver_client_grpc::ClientCacheOptions { - max_consumers: args.pool_max_consumers.get(), - error_threshold: args.pool_error_threshold.get(), - connect_timeout: Duration::from_millis(args.pool_connect_timeout.get() as u64), - connect_backoff: Duration::from_millis(args.pool_connect_backoff.get() as u64), - max_idle_duration: Duration::from_millis(args.pool_max_idle_duration.get() as u64), - max_delay_ms: args.max_delay_ms as u64, - drop_rate: (args.percent_drops as f64) / 100.0, - hang_rate: (args.percent_hangs as f64) / 100.0, - }; - let client = pageserver_client_grpc::PageserverClient::new_with_config( - &worker_id.timeline.tenant_id.to_string(), - &worker_id.timeline.timeline_id.to_string(), - &None, - shard_map, - options, - Some(client_metrics.clone()), - ); +/// A benchmark client, to allow switching out the transport protocol. +/// +/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could +/// return a future that resolves when the response is received, but we don't really need it. +#[async_trait] +trait Client: Send { + /// Sends an asynchronous GetPage request to the pageserver. + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>; - let client = Arc::new(client); + /// Receives the next GetPage response from the pageserver. + async fn recv_get_page(&mut self) -> anyhow::Result; +} - shared_state.start_work_barrier.wait().await; - let client_start = Instant::now(); - let mut ticks_processed = 0; - let mut inflight = FuturesOrdered::new(); - while !cancel.is_cancelled() { - // Detect if a request took longer than the RPS rate - if let Some(period) = &rps_period { - let periods_passed_until_now = - usize::try_from(client_start.elapsed().as_micros() / period.as_micros()).unwrap(); +/// A libpq-based Pageserver client. +struct LibpqClient { + inner: pageserver_client::page_service::PagestreamClient, +} - if periods_passed_until_now > ticks_processed { - shared_state - .live_stats - .missed((periods_passed_until_now - ticks_processed) as u64); - } - ticks_processed = periods_passed_until_now; - } - - while inflight.len() < args.queue_depth.get() { - let start = Instant::now(); - let req = { - let mut rng = rand::thread_rng(); - let r = &ranges[weights.sample(&mut rng)]; - let key: i128 = rng.gen_range(r.start..r.end); - let key = Key::from_i128(key); - assert!(key.is_rel_block_key()); - let (rel_tag, block_no) = key - .to_rel_block() - .expect("we filter non-rel-block keys out above"); - pageserver_page_api::model::GetPageRequest { - request_id: 0, // TODO - request_class: GetPageClass::Normal, - read_lsn: pageserver_page_api::model::ReadLsn { - request_lsn: if rng.gen_bool(args.req_latest_probability) { - Lsn::MAX - } else { - r.timeline_lsn - }, - not_modified_since_lsn: r.timeline_lsn, - }, - rel: pageserver_page_api::model::RelTag { - spc_oid: rel_tag.spcnode, - db_oid: rel_tag.dbnode, - rel_number: rel_tag.relnode, - fork_number: rel_tag.forknum, - }, - block_number: vec![block_no], - } - }; - let client_clone = client.clone(); - let getpage_fut = async move { - let result = client_clone.get_page(&req).await; - (start, result) - }; - inflight.push_back(getpage_fut); - } - - let (start, result) = inflight.next().await.unwrap(); - result.expect("getpage request should succeed"); - let end = Instant::now(); - shared_state.live_stats.request_done(); - ticks_processed += 1; - STATS.with(|stats| { - stats - .borrow() - .lock() - .unwrap() - .observe(end.duration_since(start)) - .unwrap(); - }); - - if let Some(period) = &rps_period { - let next_at = client_start - + Duration::from_micros( - (ticks_processed) as u64 * u64::try_from(period.as_micros()).unwrap(), - ); - tokio::time::sleep_until(next_at.into()).await; - } +impl LibpqClient { + async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result { + let inner = pageserver_client::page_service::Client::new(connstring) + .await? + .pagestream(ttid.tenant_id, ttid.timeline_id) + .await?; + Ok(Self { inner }) } } -async fn client_grpc_stream( - args: &Args, - worker_id: WorkerId, - shared_state: Arc, - cancel: CancellationToken, - rps_period: Option, - ranges: Vec, - weights: rand::distributions::weighted::WeightedIndex, -) { - let shard_map = HashMap::from([( - ShardIndex::unsharded(), - args.page_service_connstring.clone(), - )]); - let client = pageserver_client_grpc::PageserverClient::new( - &worker_id.timeline.tenant_id.to_string(), - &worker_id.timeline.timeline_id.to_string(), - &None, - shard_map, - ); +#[async_trait] +impl Client for LibpqClient { + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> { + self.inner.getpage_send(req).await + } - let (request_tx, request_rx) = tokio::sync::mpsc::channel(1); - let request_stream = tokio_stream::wrappers::ReceiverStream::new(request_rx); - let mut response_stream = client.get_pages(request_stream).await.unwrap().into_inner(); - - shared_state.start_work_barrier.wait().await; - let client_start = Instant::now(); - let mut ticks_processed = 0; - let mut inflight = VecDeque::new(); - - while !cancel.is_cancelled() { - // Detect if a request took longer than the RPS rate - if let Some(period) = &rps_period { - let periods_passed_until_now = - usize::try_from(client_start.elapsed().as_micros() / period.as_micros()).unwrap(); - - if periods_passed_until_now > ticks_processed { - shared_state - .live_stats - .missed((periods_passed_until_now - ticks_processed) as u64); - } - ticks_processed = periods_passed_until_now; - } - - // Send requests until the queue depth is reached - // TODO: use batching - while inflight.len() < args.queue_depth.get() { - let start = Instant::now(); - let req = { - let mut rng = rand::thread_rng(); - let r = &ranges[weights.sample(&mut rng)]; - let key: i128 = rng.gen_range(r.start..r.end); - let key = Key::from_i128(key); - assert!(key.is_rel_block_key()); - let (rel_tag, block_no) = key - .to_rel_block() - .expect("we filter non-rel-block keys out above"); - pageserver_page_api::model::GetPageRequest { - request_id: 0, // TODO - request_class: GetPageClass::Normal, - read_lsn: pageserver_page_api::model::ReadLsn { - request_lsn: if rng.gen_bool(args.req_latest_probability) { - Lsn::MAX - } else { - r.timeline_lsn - }, - not_modified_since_lsn: r.timeline_lsn, - }, - rel: pageserver_page_api::model::RelTag { - spc_oid: rel_tag.spcnode, - db_oid: rel_tag.dbnode, - rel_number: rel_tag.relnode, - fork_number: rel_tag.forknum, - }, - block_number: vec![block_no], - } - }; - request_tx.send((&req).into()).await.unwrap(); - inflight.push_back(start); - } - - // Receive responses for the inflight requests - if let Some(response) = response_stream.next().await { - let response: GetPageResponse = response.unwrap().try_into().unwrap(); - assert_eq!(response.status, GetPageStatus::Ok); - let start = inflight.pop_front().unwrap(); - let end = Instant::now(); - shared_state.live_stats.request_done(); - ticks_processed += 1; - STATS.with(|stats| { - stats - .borrow() - .lock() - .unwrap() - .observe(end.duration_since(start)) - .unwrap(); - }); - } - - // Enforce RPS limit if specified - if let Some(period) = &rps_period { - let next_at = client_start - + Duration::from_micros( - (ticks_processed) as u64 * u64::try_from(period.as_micros()).unwrap(), - ); - tokio::time::sleep_until(next_at.into()).await; - } + async fn recv_get_page(&mut self) -> anyhow::Result { + self.inner.getpage_recv().await + } +} + +/// A gRPC client using the raw, no-frills gRPC client. +struct GrpcClient { + req_tx: tokio::sync::mpsc::Sender, + resp_rx: tonic::Streaming, +} + +impl GrpcClient { + async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result { + let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?; + + // The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the + // benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are + // buffered by Tonic and the OS too. + let (req_tx, req_rx) = tokio::sync::mpsc::channel(1); + let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx); + let mut req = tonic::Request::new(req_stream); + let metadata = req.metadata_mut(); + metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?); + metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?); + metadata.insert("neon-shard-id", "0000".try_into()?); + + let resp = client.get_pages(req).await?; + let resp_stream = resp.into_inner(); + + Ok(Self { + req_tx, + resp_rx: resp_stream, + }) + } +} + +#[async_trait] +impl Client for GrpcClient { + async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> { + let req = proto::GetPageRequest { + request_id: 0, + request_class: proto::GetPageClass::Normal as i32, + read_lsn: Some(proto::ReadLsn { + request_lsn: req.hdr.request_lsn.0, + not_modified_since_lsn: req.hdr.not_modified_since.0, + }), + rel: Some(req.rel.into()), + block_number: vec![req.blkno], + }; + self.req_tx.send(req).await?; + Ok(()) + } + + async fn recv_get_page(&mut self) -> anyhow::Result { + let resp = self.resp_rx.message().await?.unwrap(); + anyhow::ensure!( + resp.status_code == proto::GetPageStatusCode::Ok as i32, + "unexpected status code: {}", + resp.status_code + ); + Ok(PagestreamGetPageResponse { + page: resp.page_image[0].clone(), + req: PagestreamGetPageRequest::default(), // dummy + }) } } diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index 2e6990dbc4..22fe501019 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -65,6 +65,30 @@ impl From for BasebackupError { } } +impl From for postgres_backend::QueryError { + fn from(err: BasebackupError) -> Self { + use postgres_backend::QueryError; + use pq_proto::framed::ConnectionError; + match err { + BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)), + BasebackupError::Server(err) => QueryError::Other(err), + BasebackupError::Shutdown => QueryError::Shutdown, + } + } +} + +impl From for tonic::Status { + fn from(err: BasebackupError) -> Self { + use tonic::Code; + let code = match &err { + BasebackupError::Client(_, _) => Code::Cancelled, + BasebackupError::Server(_) => Code::Internal, + BasebackupError::Shutdown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + /// Create basebackup with non-rel data in it. /// Only include relational data if 'full_backup' is true. /// @@ -252,7 +276,7 @@ where async fn flush(&mut self) -> Result<(), BasebackupError> { let nblocks = self.buf.len() / BLCKSZ as usize; let (kind, segno) = self.current_segment.take().unwrap(); - let segname = format!("{}/{:>04X}", kind.to_str(), segno); + let segname = format!("{kind}/{segno:>04X}"); let header = new_tar_header(&segname, self.buf.len() as u64)?; self.ar .append(&header, self.buf.as_slice()) diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 81cd339624..ae1ac37c65 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -17,11 +17,11 @@ use metrics::launch_timestamp::{LaunchTimestamp, set_launch_timestamp_metric}; use metrics::set_build_info_metric; use nix::sys::socket::{setsockopt, sockopt}; use pageserver::basebackup_cache::BasebackupCache; -use pageserver::compute_service; use pageserver::config::{PageServerConf, PageserverIdentity, ignored_fields}; use pageserver::controller_upcall_client::StorageControllerUpcallClient; use pageserver::deletion_queue::DeletionQueue; use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task}; +use pageserver::feature_resolver::FeatureResolver; use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING}; use pageserver::task_mgr::{ BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME, @@ -31,6 +31,7 @@ use pageserver::{ CancellableTask, ConsumptionMetricsTasks, HttpEndpointListener, HttpsEndpointListener, http, page_cache, task_mgr, virtual_file, }; +use pageserver::{compute_service, page_service}; use postgres_backend::AuthType; use remote_storage::GenericRemoteStorage; use tokio::time::Instant; @@ -389,23 +390,30 @@ fn start_pageserver( // We need to release the lock file only when the process exits. std::mem::forget(lock_file); - // Bind the HTTP and libpq ports early, so that if they are in use by some other - // process, we error out early. - let http_addr = &conf.listen_http_addr; - info!("Starting pageserver http handler on {http_addr}"); - let http_listener = tcp_listener::bind(http_addr)?; + // Bind the HTTP, libpq, and gRPC ports early, to error out if they are + // already in use. + info!( + "Starting pageserver http handler on {} with auth {:#?}", + conf.listen_http_addr, conf.http_auth_type + ); + let http_listener = tcp_listener::bind(&conf.listen_http_addr)?; let https_listener = match conf.listen_https_addr.as_ref() { Some(https_addr) => { - info!("Starting pageserver https handler on {https_addr}"); + info!( + "Starting pageserver https handler on {https_addr} with auth {:#?}", + conf.http_auth_type + ); Some(tcp_listener::bind(https_addr)?) } None => None, }; - let pg_addr = &conf.listen_pg_addr; - info!("Starting pageserver pg protocol handler on {pg_addr}"); - let pageserver_listener = tcp_listener::bind(pg_addr)?; + info!( + "Starting pageserver pg protocol handler on {} with auth {:#?}", + conf.listen_pg_addr, conf.pg_auth_type, + ); + let pageserver_listener = tcp_listener::bind(&conf.listen_pg_addr)?; // Enable SO_KEEPALIVE on the socket, to detect dead connections faster. // These are configured via net.ipv4.tcp_keepalive_* sysctls. @@ -414,6 +422,15 @@ fn start_pageserver( // support enabling keepalives while using the default OS sysctls. setsockopt(&pageserver_listener, sockopt::KeepAlive, &true)?; + let mut grpc_listener = None; + if let Some(grpc_addr) = &conf.listen_grpc_addr { + info!( + "Starting pageserver gRPC handler on {grpc_addr} with auth {:#?}", + conf.grpc_auth_type + ); + grpc_listener = Some(tcp_listener::bind(grpc_addr).map_err(|e| anyhow!("{e}"))?); + } + // Launch broker client // The storage_broker::connect call needs to happen inside a tokio runtime thread. let broker_client = WALRECEIVER_RUNTIME @@ -441,7 +458,8 @@ fn start_pageserver( // Initialize authentication for incoming connections let http_auth; let pg_auth; - if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT { + let grpc_auth; + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type].contains(&AuthType::NeonJWT) { // unwrap is ok because check is performed when creating config, so path is set and exists let key_path = conf.auth_validation_public_key_path.as_ref().unwrap(); info!("Loading public key(s) for verifying JWT tokens from {key_path:?}"); @@ -449,20 +467,23 @@ fn start_pageserver( let jwt_auth = JwtAuth::from_key_path(key_path)?; let auth: Arc = Arc::new(SwappableJwtAuth::new(jwt_auth)); - http_auth = match &conf.http_auth_type { + http_auth = match conf.http_auth_type { AuthType::Trust => None, AuthType::NeonJWT => Some(auth.clone()), }; - pg_auth = match &conf.pg_auth_type { + pg_auth = match conf.pg_auth_type { + AuthType::Trust => None, + AuthType::NeonJWT => Some(auth.clone()), + }; + grpc_auth = match conf.grpc_auth_type { AuthType::Trust => None, AuthType::NeonJWT => Some(auth), }; } else { http_auth = None; pg_auth = None; + grpc_auth = None; } - info!("Using auth for http API: {:#?}", conf.http_auth_type); - info!("Using auth for pg connections: {:#?}", conf.pg_auth_type); let tls_server_config = if conf.listen_https_addr.is_some() || conf.enable_tls_page_service_api { @@ -503,6 +524,12 @@ fn start_pageserver( // Set up remote storage client let remote_storage = BACKGROUND_RUNTIME.block_on(create_remote_storage_client(conf))?; + let feature_resolver = create_feature_resolver( + conf, + shutdown_pageserver.clone(), + BACKGROUND_RUNTIME.handle(), + )?; + // Set up deletion queue let (deletion_queue, deletion_workers) = DeletionQueue::new( remote_storage.clone(), @@ -556,6 +583,7 @@ fn start_pageserver( deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, }, order, shutdown_pageserver.clone(), @@ -780,6 +808,22 @@ fn start_pageserver( basebackup_cache, ); + // Spawn a Pageserver gRPC server task. It will spawn separate tasks for + // each stream/request. + // + // TODO: this uses a separate Tokio runtime for the page service. If we want + // other gRPC services, they will need their own port and runtime. Is this + // necessary? + let mut page_service_grpc = None; + if let Some(grpc_listener) = grpc_listener { + page_service_grpc = Some(page_service::spawn_grpc( + tenant_manager.clone(), + grpc_auth, + otel_guard.as_ref().map(|g| g.dispatch.clone()), + grpc_listener, + )?); + } + // All started up! Now just sit and wait for shutdown signal. BACKGROUND_RUNTIME.block_on(async move { let signal_token = CancellationToken::new(); @@ -798,6 +842,7 @@ fn start_pageserver( http_endpoint_listener, https_endpoint_listener, compute_service, + page_service_grpc, consumption_metrics_tasks, disk_usage_eviction_task, &tenant_manager, @@ -811,6 +856,14 @@ fn start_pageserver( }) } +fn create_feature_resolver( + conf: &'static PageServerConf, + shutdown_pageserver: CancellationToken, + handle: &tokio::runtime::Handle, +) -> anyhow::Result { + FeatureResolver::spawn(conf, shutdown_pageserver, handle) +} + async fn create_remote_storage_client( conf: &'static PageServerConf, ) -> anyhow::Result { diff --git a/pageserver/src/compute_service_grpc.rs b/pageserver/src/compute_service_grpc.rs index 6e556ef04d..a5ecc91b09 100644 --- a/pageserver/src/compute_service_grpc.rs +++ b/pageserver/src/compute_service_grpc.rs @@ -33,7 +33,6 @@ use crate::tenant::storage_layer::IoConcurrency; use crate::tenant::timeline::WaitLsnTimeout; use async_stream::try_stream; use futures::Stream; -use pageserver_api::reltag::SlruKind; use tokio::io::{AsyncWriteExt, ReadHalf, SimplexStream}; use tokio::task::JoinHandle; use tokio_util::codec::{Decoder, FramedRead}; @@ -41,9 +40,9 @@ use tokio_util::sync::CancellationToken; use futures::stream::StreamExt; -use pageserver_page_api::model; use pageserver_page_api::proto::page_service_server::PageService; use pageserver_page_api::proto::page_service_server::PageServiceServer; +use pageserver_page_api::*; use anyhow::Context; use bytes::BytesMut; @@ -134,12 +133,12 @@ impl From for tonic::Status { } } -fn convert_reltag(value: &model::RelTag) -> pageserver_api::reltag::RelTag { +fn convert_reltag(value: &RelTag) -> pageserver_api::reltag::RelTag { pageserver_api::reltag::RelTag { - spcnode: value.spc_oid, - dbnode: value.db_oid, - relnode: value.rel_number, - forknum: value.fork_number, + spcnode: value.spcnode, + dbnode: value.dbnode, + relnode: value.relnode, + forknum: value.forknum, } } @@ -155,7 +154,7 @@ impl PageService for PageServiceService { ) -> std::result::Result, tonic::Status> { let ttid = self.extract_ttid(request.metadata())?; let shard = self.extract_shard(request.metadata())?; - let req: model::CheckRelExistsRequest = request.get_ref().try_into()?; + let req: CheckRelExistsRequest = request.into_inner().try_into()?; let rel = convert_reltag(&req.rel); let span = tracing::info_span!("check_rel_exists", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, rel = %rel, req_lsn = %req.read_lsn.request_lsn); @@ -167,7 +166,9 @@ impl PageService for PageServiceService { let lsn = Self::wait_or_get_last_lsn( &timeline, req.read_lsn.request_lsn, - req.read_lsn.not_modified_since_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), &latest_gc_cutoff_lsn, &ctx, ) @@ -190,7 +191,7 @@ impl PageService for PageServiceService { ) -> std::result::Result, tonic::Status> { let ttid = self.extract_ttid(request.metadata())?; let shard = self.extract_shard(request.metadata())?; - let req: model::GetRelSizeRequest = request.get_ref().try_into()?; + let req: GetRelSizeRequest = request.into_inner().try_into()?; let rel = convert_reltag(&req.rel); let span = tracing::info_span!("get_rel_size", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, rel = %rel, req_lsn = %req.read_lsn.request_lsn); @@ -202,7 +203,9 @@ impl PageService for PageServiceService { let lsn = Self::wait_or_get_last_lsn( &timeline, req.read_lsn.request_lsn, - req.read_lsn.not_modified_since_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), &latest_gc_cutoff_lsn, &ctx, ) @@ -239,7 +242,7 @@ impl PageService for PageServiceService { .enter() .or(Err(tonic::Status::unavailable("timeline is shutting down")))?; - let request: model::GetPageRequest = (&request).try_into()?; + let request: GetPageRequest = request.try_into()?; let rel = convert_reltag(&request.rel); let span = tracing::info_span!("get_pages", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, shard_id = %shard, rel = %rel, req_lsn = %request.read_lsn.request_lsn); @@ -248,7 +251,9 @@ impl PageService for PageServiceService { let lsn = Self::wait_or_get_last_lsn( &timeline, request.read_lsn.request_lsn, - request.read_lsn.not_modified_since_lsn, + request.read_lsn + .not_modified_since_lsn + .unwrap_or(request.read_lsn.request_lsn), &latest_gc_cutoff_lsn, &ctx, ) @@ -257,8 +262,8 @@ impl PageService for PageServiceService { let io_concurrency = IoConcurrency::spawn_from_conf(conf.get_vectored_concurrent_io, guard); // TODO: use get_rel_page_at_lsn_batched - let mut page_images = Vec::with_capacity(request.block_number.len()); - for blkno in request.block_number { + let mut page_images = Vec::with_capacity(request.block_numbers.len()); + for blkno in request.block_numbers { let page_image = timeline .get_rel_page_at_lsn( rel, @@ -278,7 +283,7 @@ impl PageService for PageServiceService { let page_images = result?; yield proto::GetPageResponse { request_id: request.request_id, - status: proto::GetPageStatus::Ok as i32, + status_code: proto::GetPageStatusCode::Ok as i32, reason: "".to_string(), page_image: page_images, }; @@ -296,7 +301,7 @@ impl PageService for PageServiceService { ) -> Result, tonic::Status> { let ttid = self.extract_ttid(request.metadata())?; let shard = self.extract_shard(request.metadata())?; - let req: model::GetDbSizeRequest = request.get_ref().try_into()?; + let req: GetDbSizeRequest = request.into_inner().try_into()?; let span = tracing::info_span!("get_db_size", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, db_oid = %req.db_oid, req_lsn = %req.read_lsn.request_lsn); @@ -307,7 +312,9 @@ impl PageService for PageServiceService { let lsn = Self::wait_or_get_last_lsn( &timeline, req.read_lsn.request_lsn, - req.read_lsn.not_modified_since_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), &latest_gc_cutoff_lsn, &ctx, ) @@ -331,7 +338,7 @@ impl PageService for PageServiceService { ) -> Result, tonic::Status> { let ttid = self.extract_ttid(request.metadata())?; let shard = self.extract_shard(request.metadata())?; - let req: model::GetBaseBackupRequest = request.get_ref().try_into()?; + let req: GetBaseBackupRequest = request.into_inner().try_into()?; let timeline = self.get_timeline(ttid, shard).await?; @@ -340,7 +347,9 @@ impl PageService for PageServiceService { let lsn = Self::wait_or_get_last_lsn( &timeline, req.read_lsn.request_lsn, - req.read_lsn.not_modified_since_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), &latest_gc_cutoff_lsn, &ctx, ) @@ -471,7 +480,7 @@ impl PageService for PageServiceService { ) -> Result, tonic::Status> { let ttid = self.extract_ttid(request.metadata())?; let shard = self.extract_shard(request.metadata())?; - let req: model::GetSlruSegmentRequest = request.get_ref().try_into()?; + let req: GetSlruSegmentRequest = request.into_inner().try_into()?; let span = tracing::info_span!("get_slru_segment", tenant_id = %ttid.tenant_id, timeline_id = %ttid.timeline_id, kind = %req.kind, segno = %req.segno, req_lsn = %req.read_lsn.request_lsn); @@ -482,16 +491,16 @@ impl PageService for PageServiceService { let lsn = Self::wait_or_get_last_lsn( &timeline, req.read_lsn.request_lsn, - req.read_lsn.not_modified_since_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), &latest_gc_cutoff_lsn, &ctx, ) .await?; - let kind = SlruKind::from_repr(req.kind) - .ok_or(tonic::Status::from_error("invalid SLRU kind".into()))?; let segment = timeline - .get_slru_segment(kind, req.segno, lsn, &ctx) + .get_slru_segment(req.kind, req.segno, lsn, &ctx) .await?; Ok(tonic::Response::new(proto::GetSlruSegmentResponse { diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index e8b3b7b3ab..89f7539722 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -14,7 +14,7 @@ use std::time::Duration; use anyhow::{Context, bail, ensure}; use camino::{Utf8Path, Utf8PathBuf}; use once_cell::sync::OnceCell; -use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes}; +use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig}; use pageserver_api::models::ImageCompressionAlgorithm; use pageserver_api::shard::TenantShardId; use pem::Pem; @@ -58,11 +58,16 @@ pub struct PageServerConf { pub listen_http_addr: String, /// Example: 127.0.0.1:9899 pub listen_https_addr: Option, + /// If set, expose a gRPC API on this address. + /// Example: 127.0.0.1:51051 + /// + /// EXPERIMENTAL: this protocol is unstable and under active development. + pub listen_grpc_addr: Option, - /// Path to a file with certificate's private key for https API. + /// Path to a file with certificate's private key for https and gRPC API. /// Default: server.key pub ssl_key_file: Utf8PathBuf, - /// Path to a file with a X509 certificate for https API. + /// Path to a file with a X509 certificate for https and gRPC API. /// Default: server.crt pub ssl_cert_file: Utf8PathBuf, /// Period to reload certificate and private key from files. @@ -100,6 +105,8 @@ pub struct PageServerConf { pub http_auth_type: AuthType, /// authentication method for libpq connections from compute pub pg_auth_type: AuthType, + /// authentication method for gRPC connections from compute + pub grpc_auth_type: AuthType, /// Path to a file or directory containing public key(s) for verifying JWT tokens. /// Used for both mgmt and compute auth, if enabled. pub auth_validation_public_key_path: Option, @@ -231,6 +238,9 @@ pub struct PageServerConf { /// This is insecure and should only be used in development environments. pub dev_mode: bool, + /// PostHog integration config. + pub posthog_config: Option, + pub timeline_import_config: pageserver_api::config::TimelineImportConfig, pub basebackup_cache_config: Option, @@ -355,6 +365,7 @@ impl PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, ssl_key_file, ssl_cert_file, ssl_cert_reload_period, @@ -369,6 +380,7 @@ impl PageServerConf { pg_distrib_dir, http_auth_type, pg_auth_type, + grpc_auth_type, auth_validation_public_key_path, remote_storage, broker_endpoint, @@ -412,6 +424,7 @@ impl PageServerConf { tracing, enable_tls_page_service_api, dev_mode, + posthog_config, timeline_import_config, basebackup_cache_config, } = config_toml; @@ -423,6 +436,7 @@ impl PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, ssl_key_file, ssl_cert_file, ssl_cert_reload_period, @@ -435,6 +449,7 @@ impl PageServerConf { max_file_descriptors, http_auth_type, pg_auth_type, + grpc_auth_type, auth_validation_public_key_path, remote_storage_config: remote_storage, broker_endpoint, @@ -525,13 +540,16 @@ impl PageServerConf { } None => Vec::new(), }, + posthog_config, }; // ------------------------------------------------------------ // custom validation code that covers more than one field in isolation // ------------------------------------------------------------ - if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT { + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type] + .contains(&AuthType::NeonJWT) + { let auth_validation_public_key_path = conf .auth_validation_public_key_path .get_or_insert_with(|| workdir.join("auth_public_key.pem")); diff --git a/pageserver/src/disk_usage_eviction_task.rs b/pageserver/src/disk_usage_eviction_task.rs index 13252037e5..f13b3709f5 100644 --- a/pageserver/src/disk_usage_eviction_task.rs +++ b/pageserver/src/disk_usage_eviction_task.rs @@ -837,7 +837,30 @@ async fn collect_eviction_candidates( continue; } let info = tl.get_local_layers_for_disk_usage_eviction().await; - debug!(tenant_id=%tl.tenant_shard_id.tenant_id, shard_id=%tl.tenant_shard_id.shard_slug(), timeline_id=%tl.timeline_id, "timeline resident layers count: {}", info.resident_layers.len()); + debug!( + tenant_id=%tl.tenant_shard_id.tenant_id, + shard_id=%tl.tenant_shard_id.shard_slug(), + timeline_id=%tl.timeline_id, + "timeline resident layers count: {}", info.resident_layers.len() + ); + + tenant_candidates.extend(info.resident_layers.into_iter()); + max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0)); + + if cancel.is_cancelled() { + return Ok(EvictionCandidates::Cancelled); + } + } + + // Also consider layers of timelines being imported for eviction + for tl in tenant.list_importing_timelines() { + let info = tl.timeline.get_local_layers_for_disk_usage_eviction().await; + debug!( + tenant_id=%tl.timeline.tenant_shard_id.tenant_id, + shard_id=%tl.timeline.tenant_shard_id.shard_slug(), + timeline_id=%tl.timeline.timeline_id, + "timeline resident layers count: {}", info.resident_layers.len() + ); tenant_candidates.extend(info.resident_layers.into_iter()); max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0)); diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs new file mode 100644 index 0000000000..7e31b930d0 --- /dev/null +++ b/pageserver/src/feature_resolver.rs @@ -0,0 +1,104 @@ +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use posthog_client_lite::{ + FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError, +}; +use tokio_util::sync::CancellationToken; +use utils::id::TenantId; + +use crate::config::PageServerConf; + +#[derive(Clone)] +pub struct FeatureResolver { + inner: Option>, +} + +impl FeatureResolver { + pub fn new_disabled() -> Self { + Self { inner: None } + } + + pub fn spawn( + conf: &PageServerConf, + shutdown_pageserver: CancellationToken, + handle: &tokio::runtime::Handle, + ) -> anyhow::Result { + // DO NOT block in this function: make it return as fast as possible to avoid startup delays. + if let Some(posthog_config) = &conf.posthog_config { + let inner = FeatureResolverBackgroundLoop::new( + PostHogClientConfig { + server_api_key: posthog_config.server_api_key.clone(), + client_api_key: posthog_config.client_api_key.clone(), + project_id: posthog_config.project_id.clone(), + private_api_url: posthog_config.private_api_url.clone(), + public_api_url: posthog_config.public_api_url.clone(), + }, + shutdown_pageserver, + ); + let inner = Arc::new(inner); + // TODO: make this configurable + inner.clone().spawn(handle, Duration::from_secs(60)); + Ok(FeatureResolver { inner: Some(inner) }) + } else { + Ok(FeatureResolver { inner: None }) + } + } + + /// Evaluate a multivariate feature flag. Currently, we do not support any properties. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. + pub fn evaluate_multivariate( + &self, + flag_key: &str, + tenant_id: TenantId, + ) -> Result { + if let Some(inner) = &self.inner { + inner.feature_store().evaluate_multivariate( + flag_key, + &tenant_id.to_string(), + &HashMap::new(), + ) + } else { + Err(PostHogEvaluationError::NotAvailable( + "PostHog integration is not enabled".to_string(), + )) + } + } + + /// Evaluate a boolean feature flag. Currently, we do not support any properties. + /// + /// Returns `Ok(())` if the flag is evaluated to true, otherwise returns an error. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. + pub fn evaluate_boolean( + &self, + flag_key: &str, + tenant_id: TenantId, + ) -> Result<(), PostHogEvaluationError> { + if let Some(inner) = &self.inner { + inner.feature_store().evaluate_boolean( + flag_key, + &tenant_id.to_string(), + &HashMap::new(), + ) + } else { + Err(PostHogEvaluationError::NotAvailable( + "PostHog integration is not enabled".to_string(), + )) + } + } + + pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result { + if let Some(inner) = &self.inner { + inner.feature_store().is_feature_flag_boolean(flag_key) + } else { + Err(PostHogEvaluationError::NotAvailable( + "PostHog integration is not enabled".to_string(), + )) + } + } +} diff --git a/pageserver/src/http/openapi_spec.yml b/pageserver/src/http/openapi_spec.yml index 7ea148971f..e8d1367d6c 100644 --- a/pageserver/src/http/openapi_spec.yml +++ b/pageserver/src/http/openapi_spec.yml @@ -353,6 +353,33 @@ paths: "200": description: OK + /v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/mark_invisible: + parameters: + - name: tenant_shard_id + in: path + required: true + schema: + type: string + - name: timeline_id + in: path + required: true + schema: + type: string + format: hex + put: + requestBody: + content: + application/json: + schema: + type: object + properties: + is_visible: + type: boolean + default: false + responses: + "200": + description: OK + /v1/tenant/{tenant_shard_id}/location_config: parameters: - name: tenant_shard_id @@ -626,6 +653,8 @@ paths: format: hex pg_version: type: integer + read_only: + type: boolean existing_initdb_timeline_id: type: string format: hex diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 0d6791cddd..1effa10404 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -370,6 +370,18 @@ impl From for ApiError { } } +impl From for ApiError { + fn from(err: crate::tenant::FinalizeTimelineImportError) -> ApiError { + use crate::tenant::FinalizeTimelineImportError::*; + match err { + ImportTaskStillRunning => { + ApiError::ResourceUnavailable("Import task still running".into()) + } + ShuttingDown => ApiError::ShuttingDown, + } + } +} + // Helper function to construct a TimelineInfo struct for a timeline async fn build_timeline_info( timeline: &Arc, @@ -572,6 +584,7 @@ async fn timeline_create_handler( TimelineCreateRequestMode::Branch { ancestor_timeline_id, ancestor_start_lsn, + read_only: _, pg_version: _, } => tenant::CreateTimelineParams::Branch(tenant::CreateTimelineParamsBranch { new_timeline_id, @@ -3532,10 +3545,7 @@ async fn activate_post_import_handler( tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?; - tenant - .finalize_importing_timeline(timeline_id) - .await - .map_err(ApiError::InternalServerError)?; + tenant.finalize_importing_timeline(timeline_id).await?; match tenant.get_timeline(timeline_id, false) { Ok(_timeline) => { @@ -3653,6 +3663,46 @@ async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow Ok(()) } +async fn tenant_evaluate_feature_flag( + request: Request, + _cancel: CancellationToken, +) -> Result, ApiError> { + let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?; + check_permission(&request, Some(tenant_shard_id.tenant_id))?; + + let flag: String = must_parse_query_param(&request, "flag")?; + let as_type: String = must_parse_query_param(&request, "as")?; + + let state = get_state(&request); + + async { + let tenant = state + .tenant_manager + .get_attached_tenant_shard(tenant_shard_id)?; + if as_type == "boolean" { + let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); + let result = result.map(|_| true).map_err(|e| e.to_string()); + json_response(StatusCode::OK, result) + } else if as_type == "multivariate" { + let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); + json_response(StatusCode::OK, result) + } else { + // Auto infer the type of the feature flag. + let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?; + if is_boolean { + let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); + let result = result.map(|_| true).map_err(|e| e.to_string()); + json_response(StatusCode::OK, result) + } else { + let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); + json_response(StatusCode::OK, result) + } + } + } + .instrument(info_span!("tenant_evaluate_feature_flag", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug())) + .await +} + /// Common functionality of all the HTTP API handlers. /// /// - Adds a tracing span to each request (by `request_span`) @@ -4029,5 +4079,8 @@ pub fn make_router( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/activate_post_import", |r| api_handler(r, activate_post_import_handler), ) + .get("/v1/tenant/:tenant_shard_id/feature_flag", |r| { + api_handler(r, tenant_evaluate_feature_flag) + }) .any(handler_404)) } diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index 72405a0a84..458307df25 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -10,6 +10,7 @@ pub mod context; pub mod controller_upcall_client; pub mod deletion_queue; pub mod disk_usage_eviction_task; +pub mod feature_resolver; pub mod http; pub mod import_datadir; pub mod l0_flush; @@ -86,6 +87,7 @@ pub async fn shutdown_pageserver( http_listener: HttpEndpointListener, https_listener: Option, compute_service: compute_service::Listener, + grpc_task: Option, consumption_metrics_worker: ConsumptionMetricsTasks, disk_usage_eviction_task: Option, tenant_manager: &TenantManager, @@ -179,6 +181,16 @@ pub async fn shutdown_pageserver( ) .await; + // Shut down the gRPC server task, including request handlers. + if let Some(grpc_task) = grpc_task { + timed( + grpc_task.shutdown(), + "shutdown gRPC PageRequestHandler", + Duration::from_secs(3), + ) + .await; + } + // Shut down all the tenants. This flushes everything to disk and kills // the checkpoint and GC tasks. timed( diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 3076c7f1d6..a9b2f1b7e0 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -1312,11 +1312,44 @@ impl EvictionsWithLowResidenceDuration { // // Roughly logarithmic scale. const STORAGE_IO_TIME_BUCKETS: &[f64] = &[ - 0.000030, // 30 usec - 0.001000, // 1000 usec - 0.030, // 30 ms - 1.000, // 1000 ms - 30.000, // 30000 ms + 0.00005, // 50us + 0.00006, // 60us + 0.00007, // 70us + 0.00008, // 80us + 0.00009, // 90us + 0.0001, // 100us + 0.000110, // 110us + 0.000120, // 120us + 0.000130, // 130us + 0.000140, // 140us + 0.000150, // 150us + 0.000160, // 160us + 0.000170, // 170us + 0.000180, // 180us + 0.000190, // 190us + 0.000200, // 200us + 0.000210, // 210us + 0.000220, // 220us + 0.000230, // 230us + 0.000240, // 240us + 0.000250, // 250us + 0.000300, // 300us + 0.000350, // 350us + 0.000400, // 400us + 0.000450, // 450us + 0.000500, // 500us + 0.000600, // 600us + 0.000700, // 700us + 0.000800, // 800us + 0.000900, // 900us + 0.001000, // 1ms + 0.002000, // 2ms + 0.003000, // 3ms + 0.004000, // 4ms + 0.005000, // 5ms + 0.01000, // 10ms + 0.02000, // 20ms + 0.05000, // 50ms ]; /// VirtualFile fs operation variants. @@ -2234,8 +2267,10 @@ impl BasebackupQueryTimeOngoingRecording<'_> { // If you want to change categorize of a specific error, also change it in `log_query_error`. let metric = match res { Ok(_) => &self.parent.ok, - Err(QueryError::Shutdown) => { - // Do not observe ok/err for shutdown + Err(QueryError::Shutdown) | Err(QueryError::Reconnect) => { + // Do not observe ok/err for shutdown/reconnect. + // Reconnect error might be raised when the operation is waiting for LSN and the tenant shutdown interrupts + // the operation. A reconnect error will be issued and the client will retry. return; } Err(QueryError::Disconnected(ConnectionError::Io(io_error))) diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 7412750d65..0d62115467 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -1,17 +1,22 @@ //! The Page Service listens for client connections and serves their GetPage@LSN //! requests. +use std::any::Any; use std::borrow::Cow; use std::num::NonZeroUsize; use std::os::fd::AsRawFd; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; use std::{io, str}; -use anyhow::{Context, bail}; +use anyhow::{Context as _, anyhow, bail}; use async_compression::tokio::write::GzipEncoder; -use bytes::Buf; +use bytes::{Buf, BytesMut}; +use futures::future::BoxFuture; +use futures::{FutureExt, Stream}; use itertools::Itertools; use jsonwebtoken::TokenData; use once_cell::sync::OnceCell; @@ -29,6 +34,8 @@ use pageserver_api::models::{ }; use pageserver_api::reltag::SlruKind; use pageserver_api::shard::TenantShardId; +use pageserver_page_api as page_api; +use pageserver_page_api::proto; use postgres_backend::{ AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error, }; @@ -36,25 +43,31 @@ use postgres_ffi::BLCKSZ; use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID; use pq_proto::framed::ConnectionError; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, RowDescriptor}; +use smallvec::{SmallVec, smallvec}; use strum_macros::IntoStaticStr; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; +use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, BufWriter}; +use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; +use tonic::service::Interceptor as _; use tracing::*; use utils::auth::{Claims, Scope, SwappableJwtAuth}; -use utils::failpoint_support; -use utils::id::{TenantId, TimelineId}; +use utils::id::{TenantId, TenantTimelineId, TimelineId}; use utils::logging::log_slow; use utils::lsn::Lsn; +use utils::shard::ShardIndex; use utils::simple_rcu::RcuReadGuard; -use utils::sync::gate::GateGuard; +use utils::sync::gate::{Gate, GateGuard}; use utils::sync::spsc_fold; +use utils::{failpoint_support, span_record}; -use crate::PERF_TRACE_TARGET; use crate::auth::check_permission; -use crate::basebackup::BasebackupError; +use crate::basebackup::{self, BasebackupError}; use crate::basebackup_cache::BasebackupCache; +use crate::compute_service::page_service_conn_main; use crate::config::PageServerConf; -use crate::context::{PerfInstrumentFutureExt, RequestContext, RequestContextBuilder}; +use crate::context::{ + DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, +}; use crate::metrics::{ self, COMPUTE_COMMANDS_COUNTERS, ComputeCommandKind, GetPageBatchBreakReason, LIVE_CONNECTIONS, SmgrOpTimer, TimelineMetrics, @@ -64,13 +77,15 @@ use crate::span::{ debug_assert_current_span_has_tenant_and_timeline_id, debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id, }; +use crate::task_mgr::{self, COMPUTE_REQUEST_RUNTIME, TaskKind}; use crate::tenant::mgr::{ GetActiveTenantError, GetTenantError, ShardResolveResult, ShardSelector, TenantManager, }; use crate::tenant::storage_layer::IoConcurrency; -use crate::tenant::timeline::{self, WaitLsnError}; +use crate::tenant::timeline::handle::{Handle, HandleUpgradeError, WeakHandle}; +use crate::tenant::timeline::{self, WaitLsnError, WaitLsnTimeout, WaitLsnWaiter}; use crate::tenant::{GetTimelineError, PageReconstructError, Timeline}; -use crate::{basebackup, timed_after_cancellation}; +use crate::{CancellableTask, PERF_TRACE_TARGET, timed_after_cancellation}; /// How long we may wait for a [`crate::tenant::mgr::TenantSlot::InProgress`]` and/or a [`crate::tenant::TenantShard`] which /// is not yet in state [`TenantState::Active`]. @@ -81,11 +96,293 @@ const ACTIVE_TENANT_TIMEOUT: Duration = Duration::from_millis(30000); /// Threshold at which to log slow GetPage requests. const LOG_SLOW_GETPAGE_THRESHOLD: Duration = Duration::from_secs(30); +/// The idle time before sending TCP keepalive probes for gRPC connections. The +/// interval and timeout between each probe is configured via sysctl. This +/// allows detecting dead connections sooner. +const GRPC_TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(60); + +/// Whether to enable TCP nodelay for gRPC connections. This disables Nagle's +/// algorithm, which can cause latency spikes for small messages. +const GRPC_TCP_NODELAY: bool = true; + +/// The interval between HTTP2 keepalive pings. This allows shutting down server +/// tasks when clients are unresponsive. +const GRPC_HTTP2_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30); + +/// The timeout for HTTP2 keepalive pings. Should be <= GRPC_KEEPALIVE_INTERVAL. +const GRPC_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20); + +/// Number of concurrent gRPC streams per TCP connection. We expect something +/// like 8 GetPage streams per connections, plus any unary requests. +const GRPC_MAX_CONCURRENT_STREAMS: u32 = 256; + +/////////////////////////////////////////////////////////////////////////////// + +pub struct Listener { + cancel: CancellationToken, + /// Cancel the listener task through `listen_cancel` to shut down the listener + /// and get a handle on the existing connections. + task: JoinHandle, +} + +pub struct Connections { + cancel: CancellationToken, + tasks: tokio::task::JoinSet, + gate: Gate, +} + +pub fn spawn( + conf: &'static PageServerConf, + tenant_manager: Arc, + pg_auth: Option>, + perf_trace_dispatch: Option, + tcp_listener: tokio::net::TcpListener, + tls_config: Option>, + basebackup_cache: Arc, +) -> Listener { + let cancel = CancellationToken::new(); + let libpq_ctx = RequestContext::todo_child( + TaskKind::LibpqEndpointListener, + // listener task shouldn't need to download anything. (We will + // create a separate sub-contexts for each connection, with their + // own download behavior. This context is used only to listen and + // accept connections.) + DownloadBehavior::Error, + ); + let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error( + "libpq listener", + libpq_listener_main( + conf, + tenant_manager, + pg_auth, + perf_trace_dispatch, + tcp_listener, + conf.pg_auth_type, + tls_config, + conf.page_service_pipelining.clone(), + basebackup_cache, + libpq_ctx, + cancel.clone(), + ) + .map(anyhow::Ok), + )); + + Listener { cancel, task } +} + +/// Spawns a gRPC server for the page service. +/// +/// TODO: move this onto GrpcPageServiceHandler::spawn(). +/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we +/// need to reimplement the TCP+TLS accept loop ourselves. +pub fn spawn_grpc( + tenant_manager: Arc, + auth: Option>, + perf_trace_dispatch: Option, + listener: std::net::TcpListener, +) -> anyhow::Result { + let cancel = CancellationToken::new(); + let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) + .download_behavior(DownloadBehavior::Download) + .perf_span_dispatch(perf_trace_dispatch) + .detached_child(); + let gate = Gate::default(); + + // Set up the TCP socket. We take a preconfigured TcpListener to bind the + // port early during startup. + let incoming = { + let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std + listener.set_nonblocking(true)?; + tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?) + .with_nodelay(Some(GRPC_TCP_NODELAY)) + .with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME)) + }; + + // Set up the gRPC server. + // + // TODO: consider tuning window sizes. + let mut server = tonic::transport::Server::builder() + .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) + .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) + .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); + + // Main page service stack. Uses a mix of Tonic interceptors and Tower layers: + // + // * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service. + // + // * Layers: allow async code, can run code after the service response. However, only has access + // to the raw HTTP request/response, not the gRPC types. + let page_service_handler = GrpcPageServiceHandler { + tenant_manager, + ctx, + }; + + let observability_layer = ObservabilityLayer; + let mut tenant_interceptor = TenantMetadataInterceptor; + let mut auth_interceptor = TenantAuthInterceptor::new(auth); + + let page_service = tower::ServiceBuilder::new() + // Create tracing span and record request start time. + .layer(observability_layer) + // Intercept gRPC requests. + .layer(tonic::service::InterceptorLayer::new(move |mut req| { + // Extract tenant metadata. + req = tenant_interceptor.call(req)?; + // Authenticate tenant JWT token. + req = auth_interceptor.call(req)?; + Ok(req) + })) + .service(proto::PageServiceServer::new(page_service_handler)); + let server = server.add_service(page_service); + + // Reflection service for use with e.g. grpcurl. + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) + .build_v1()?; + let server = server.add_service(reflection_service); + + // Spawn server task. + let task_cancel = cancel.clone(); + let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error( + "grpc listener", + async move { + let result = server + .serve_with_incoming_shutdown(incoming, task_cancel.cancelled()) + .await; + if result.is_ok() { + // TODO: revisit shutdown logic once page service is implemented. + gate.close().await; + } + result + }, + )); + + Ok(CancellableTask { task, cancel }) +} + +impl Listener { + pub async fn stop_accepting(self) -> Connections { + self.cancel.cancel(); + self.task + .await + .expect("unreachable: we wrap the listener task in task_mgr::exit_on_panic_or_error") + } +} +impl Connections { + pub(crate) async fn shutdown(self) { + let Self { + cancel, + mut tasks, + gate, + } = self; + cancel.cancel(); + while let Some(res) = tasks.join_next().await { + Self::handle_connection_completion(res); + } + gate.close().await; + } + + fn handle_connection_completion(res: Result, tokio::task::JoinError>) { + match res { + Ok(Ok(())) => {} + Ok(Err(e)) => error!("error in page_service connection task: {:?}", e), + Err(e) => error!("page_service connection task panicked: {:?}", e), + } + } +} + +/// +/// Main loop of the page service. +/// +/// Listens for connections, and launches a new handler task for each. +/// +/// Returns Ok(()) upon cancellation via `cancel`, returning the set of +/// open connections. +/// +#[allow(clippy::too_many_arguments)] +pub async fn libpq_listener_main( + conf: &'static PageServerConf, + tenant_manager: Arc, + auth: Option>, + perf_trace_dispatch: Option, + listener: tokio::net::TcpListener, + auth_type: AuthType, + tls_config: Option>, + pipelining_config: PageServicePipeliningConfig, + basebackup_cache: Arc, + listener_ctx: RequestContext, + listener_cancel: CancellationToken, +) -> Connections { + let connections_cancel = CancellationToken::new(); + let connections_gate = Gate::default(); + let mut connection_handler_tasks = tokio::task::JoinSet::default(); + + loop { + let gate_guard = match connections_gate.enter() { + Ok(guard) => guard, + Err(_) => break, + }; + + let accepted = tokio::select! { + biased; + _ = listener_cancel.cancelled() => break, + next = connection_handler_tasks.join_next(), if !connection_handler_tasks.is_empty() => { + let res = next.expect("we dont poll while empty"); + Connections::handle_connection_completion(res); + continue; + } + accepted = listener.accept() => accepted, + }; + + match accepted { + Ok((socket, peer_addr)) => { + // Connection established. Spawn a new task to handle it. + debug!("accepted connection from {}", peer_addr); + let local_auth = auth.clone(); + let connection_ctx = RequestContextBuilder::from(&listener_ctx) + .task_kind(TaskKind::PageRequestHandler) + .download_behavior(DownloadBehavior::Download) + .perf_span_dispatch(perf_trace_dispatch.clone()) + .detached_child(); + + let (dummy_tx, _) = tokio::sync::mpsc::channel(1); + + connection_handler_tasks.spawn(page_service_conn_main( + conf, + tenant_manager.clone(), + local_auth, + socket, + auth_type, + tls_config.clone(), + pipelining_config.clone(), + basebackup_cache.clone(), + connection_ctx, + connections_cancel.child_token(), + gate_guard, + dummy_tx, + )); + } + Err(err) => { + // accept() failed. Log the error, and loop back to retry on next connection. + error!("accept() failed: {:?}", err); + } + } + } + + debug!("page_service listener loop terminated"); + + Connections { + cancel: connections_cancel, + tasks: connection_handler_tasks, + gate: connections_gate, + } +} + type ConnectionHandlerResult = anyhow::Result<()>; /// Perf root spans start at the per-request level, after shard routing. /// This struct carries connection-level information to the root perf span definition. -#[derive(Clone)] +#[derive(Clone, Default)] struct ConnectionPerfSpanFields { peer_addr: String, application_name: Option, @@ -200,6 +497,11 @@ pub async fn libpq_page_service_conn_main( } } +/// Page service connection handler. +/// +/// TODO: for gRPC, this will be shared by all requests from all connections. +/// Decompose it into global state and per-connection/request state, and make +/// libpq-specific options (e.g. pipelining) separate. struct PageServerHandler { auth: Option>, claims: Option, @@ -249,7 +551,7 @@ impl TimelineHandles { tenant_id: TenantId, timeline_id: TimelineId, shard_selector: ShardSelector, - ) -> Result, GetActiveTimelineError> { + ) -> Result, GetActiveTimelineError> { if *self.wrapper.tenant_id.get_or_init(|| tenant_id) != tenant_id { return Err(GetActiveTimelineError::Tenant( GetActiveTenantError::SwitchedTenant, @@ -416,6 +718,82 @@ enum PageStreamError { BadRequest(Cow<'static, str>), } +impl PageStreamError { + /// Converts a PageStreamError into a proto::GetPageResponse with the appropriate status + /// code, or a gRPC status if it should terminate the stream (e.g. shutdown). This is a + /// convenience method for use from a get_pages gRPC stream. + #[allow(clippy::result_large_err)] + fn into_get_page_response( + self, + request_id: page_api::RequestID, + ) -> Result { + use page_api::GetPageStatusCode; + use tonic::Code; + + // We dispatch to Into first, and then map it to a GetPageResponse. + let status: tonic::Status = self.into(); + let status_code = match status.code() { + // We shouldn't see an OK status here, because we're emitting an error. + Code::Ok => { + debug_assert_ne!(status.code(), Code::Ok); + return Err(tonic::Status::internal(format!( + "unexpected OK status: {status:?}", + ))); + } + + // These are per-request errors, returned as GetPageResponses. + Code::AlreadyExists => GetPageStatusCode::InvalidRequest, + Code::DataLoss => GetPageStatusCode::InternalError, + Code::FailedPrecondition => GetPageStatusCode::InvalidRequest, + Code::InvalidArgument => GetPageStatusCode::InvalidRequest, + Code::Internal => GetPageStatusCode::InternalError, + Code::NotFound => GetPageStatusCode::NotFound, + Code::OutOfRange => GetPageStatusCode::InvalidRequest, + Code::ResourceExhausted => GetPageStatusCode::SlowDown, + + // These should terminate the stream. + Code::Aborted => return Err(status), + Code::Cancelled => return Err(status), + Code::DeadlineExceeded => return Err(status), + Code::PermissionDenied => return Err(status), + Code::Unauthenticated => return Err(status), + Code::Unavailable => return Err(status), + Code::Unimplemented => return Err(status), + Code::Unknown => return Err(status), + }; + + Ok(page_api::GetPageResponse { + request_id, + status_code, + reason: Some(status.message().to_string()), + page_images: SmallVec::new(), + } + .into()) + } +} + +impl From for tonic::Status { + fn from(err: PageStreamError) -> Self { + use tonic::Code; + let message = err.to_string(); + let code = match err { + PageStreamError::Reconnect(_) => Code::Unavailable, + PageStreamError::Shutdown => Code::Unavailable, + PageStreamError::Read(err) => match err { + PageReconstructError::Cancelled => Code::Unavailable, + PageReconstructError::MissingKey(_) => Code::NotFound, + PageReconstructError::AncestorLsnTimeout(err) => tonic::Status::from(err).code(), + PageReconstructError::Other(_) => Code::Internal, + PageReconstructError::WalRedo(_) => Code::Internal, + }, + PageStreamError::LsnTimeout(err) => tonic::Status::from(err).code(), + PageStreamError::NotFound(_) => Code::NotFound, + PageStreamError::BadRequest(_) => Code::InvalidArgument, + }; + tonic::Status::new(code, message) + } +} + impl From for PageStreamError { fn from(value: PageReconstructError) -> Self { match value { @@ -476,6 +854,9 @@ struct BatchedGetPageRequest { timer: SmgrOpTimer, lsn_range: LsnRange, ctx: RequestContext, + // If the request is perf enabled, this contains a context + // with a perf span tracking the time spent waiting for the executor. + batch_wait_ctx: Option, } #[cfg(feature = "testing")] @@ -488,41 +869,42 @@ struct BatchedTestRequest { /// so that we don't keep the [`Timeline::gate`] open while the batch /// is being built up inside the [`spsc_fold`] (pagestream pipelining). #[derive(IntoStaticStr)] +#[allow(clippy::large_enum_variant)] enum BatchedFeMessage { Exists { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamExistsRequest, }, Nblocks { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamNblocksRequest, }, GetPage { span: Span, - shard: timeline::handle::WeakHandle, - pages: smallvec::SmallVec<[BatchedGetPageRequest; 1]>, + shard: WeakHandle, + pages: SmallVec<[BatchedGetPageRequest; 1]>, batch_break_reason: GetPageBatchBreakReason, }, DbSize { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamDbSizeRequest, }, GetSlruSegment { span: Span, timer: SmgrOpTimer, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, req: models::PagestreamGetSlruSegmentRequest, }, #[cfg(feature = "testing")] Test { span: Span, - shard: timeline::handle::WeakHandle, + shard: WeakHandle, requests: Vec, }, RespondError { @@ -771,26 +1153,6 @@ impl PageServerHandler { let neon_fe_msg = PagestreamFeMessage::parse(&mut copy_data_bytes.reader(), protocol_version)?; - // TODO: turn in to async closure once available to avoid repeating received_at - async fn record_op_start_and_throttle( - shard: &timeline::handle::Handle, - op: metrics::SmgrQueryType, - received_at: Instant, - ) -> Result { - // It's important to start the smgr op metric recorder as early as possible - // so that the _started counters are incremented before we do - // any serious waiting, e.g., for throttle, batching, or actual request handling. - let mut timer = shard.query_metrics.start_smgr_op(op, received_at); - let now = Instant::now(); - timer.observe_throttle_start(now); - let throttled = tokio::select! { - res = shard.pagestream_throttle.throttle(1, now) => res, - _ = shard.cancel.cancelled() => return Err(QueryError::Shutdown), - }; - timer.observe_throttle_done(throttled); - Ok(timer) - } - let batched_msg = match neon_fe_msg { PagestreamFeMessage::Exists(req) => { let shard = timeline_handles @@ -798,7 +1160,7 @@ impl PageServerHandler { .await?; debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id(); let span = tracing::info_span!(parent: &parent_span, "handle_get_rel_exists_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetRelExists, received_at, @@ -816,7 +1178,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_get_nblocks_request", rel = %req.rel, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetRelSize, received_at, @@ -834,7 +1196,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_db_size_request", dbnode = %req.dbnode, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetDbSize, received_at, @@ -852,7 +1214,7 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_get_slru_segment_request", kind = %req.kind, segno = %req.segno, req_lsn = %req.hdr.request_lsn, shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetSlruSegment, received_at, @@ -977,7 +1339,7 @@ impl PageServerHandler { // request handler log messages contain the request-specific fields. let span = mkspan!(shard.tenant_shard_id.shard_slug()); - let timer = record_op_start_and_throttle( + let timer = Self::record_op_start_and_throttle( &shard, metrics::SmgrQueryType::GetPageAtLsn, received_at, @@ -1005,10 +1367,26 @@ impl PageServerHandler { } }; + let batch_wait_ctx = if ctx.has_perf_span() { + Some( + RequestContextBuilder::from(&ctx) + .perf_span(|crnt_perf_span| { + info_span!( + target: PERF_TRACE_TARGET, + parent: crnt_perf_span, + "WAIT_EXECUTOR", + ) + }) + .attached_child(), + ) + } else { + None + }; + BatchedFeMessage::GetPage { span, shard: shard.downgrade(), - pages: smallvec::smallvec![BatchedGetPageRequest { + pages: smallvec![BatchedGetPageRequest { req, timer, lsn_range: LsnRange { @@ -1016,6 +1394,7 @@ impl PageServerHandler { request_lsn: req.hdr.request_lsn }, ctx, + batch_wait_ctx, }], // The executor grabs the batch when it becomes idle. // Hence, [`GetPageBatchBreakReason::ExecutorSteal`] is the @@ -1029,9 +1408,12 @@ impl PageServerHandler { .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; let span = tracing::info_span!(parent: &parent_span, "handle_test_request", shard_id = %shard.tenant_shard_id.shard_slug()); - let timer = - record_op_start_and_throttle(&shard, metrics::SmgrQueryType::Test, received_at) - .await?; + let timer = Self::record_op_start_and_throttle( + &shard, + metrics::SmgrQueryType::Test, + received_at, + ) + .await?; BatchedFeMessage::Test { span, shard: shard.downgrade(), @@ -1042,6 +1424,26 @@ impl PageServerHandler { Ok(Some(batched_msg)) } + /// Starts a SmgrOpTimer at received_at and throttles the request. + async fn record_op_start_and_throttle( + shard: &Handle, + op: metrics::SmgrQueryType, + received_at: Instant, + ) -> Result { + // It's important to start the smgr op metric recorder as early as possible + // so that the _started counters are incremented before we do + // any serious waiting, e.g., for throttle, batching, or actual request handling. + let mut timer = shard.query_metrics.start_smgr_op(op, received_at); + let now = Instant::now(); + timer.observe_throttle_start(now); + let throttled = tokio::select! { + res = shard.pagestream_throttle.throttle(1, now) => res, + _ = shard.cancel.cancelled() => return Err(QueryError::Shutdown), + }; + timer.observe_throttle_done(throttled); + Ok(timer) + } + /// Post-condition: `batch` is Some() #[instrument(skip_all, level = tracing::Level::TRACE)] #[allow(clippy::boxed_local)] @@ -1139,8 +1541,11 @@ impl PageServerHandler { let (mut handler_results, span) = { // TODO: we unfortunately have to pin the future on the heap, since GetPage futures are huge and // won't fit on the stack. - let mut boxpinned = - Box::pin(self.pagestream_dispatch_batched_message(batch, io_concurrency, ctx)); + let mut boxpinned = Box::pin(Self::pagestream_dispatch_batched_message( + batch, + io_concurrency, + ctx, + )); log_slow( log_slow_name, LOG_SLOW_GETPAGE_THRESHOLD, @@ -1171,7 +1576,7 @@ impl PageServerHandler { let mut flush_timers = Vec::with_capacity(handler_results.len()); for handler_result in &mut handler_results { let flush_timer = match handler_result { - Ok((_, timer)) => Some( + Ok((_response, timer, _ctx)) => Some( timer .observe_execution_end(flushing_start_time) .expect("we are the first caller"), @@ -1191,7 +1596,7 @@ impl PageServerHandler { // Some handler errors cause exit from pagestream protocol. // Other handler errors are sent back as an error message and we stay in pagestream protocol. for (handler_result, flushing_timer) in handler_results.into_iter().zip(flush_timers) { - let response_msg = match handler_result { + let (response_msg, ctx) = match handler_result { Err(e) => match &e.err { PageStreamError::Shutdown => { // If we fail to fulfil a request during shutdown, which may be _because_ of @@ -1216,15 +1621,30 @@ impl PageServerHandler { error!("error reading relation or page version: {full:#}") }); - PagestreamBeMessage::Error(PagestreamErrorResponse { - req: e.req, - message: e.err.to_string(), - }) + ( + PagestreamBeMessage::Error(PagestreamErrorResponse { + req: e.req, + message: e.err.to_string(), + }), + None, + ) } }, - Ok((response_msg, _op_timer_already_observed)) => response_msg, + Ok((response_msg, _op_timer_already_observed, ctx)) => (response_msg, Some(ctx)), }; + let ctx = ctx.map(|req_ctx| { + RequestContextBuilder::from(&req_ctx) + .perf_span(|crnt_perf_span| { + info_span!( + target: PERF_TRACE_TARGET, + parent: crnt_perf_span, + "FLUSH_RESPONSE", + ) + }) + .attached_child() + }); + // // marshal & transmit response message // @@ -1247,6 +1667,17 @@ impl PageServerHandler { )), None => futures::future::Either::Right(flush_fut), }; + + let flush_fut = if let Some(req_ctx) = ctx.as_ref() { + futures::future::Either::Left( + flush_fut.maybe_perf_instrument(req_ctx, |current_perf_span| { + current_perf_span.clone() + }), + ) + } else { + futures::future::Either::Right(flush_fut) + }; + // do it while respecting cancellation let _: () = async move { tokio::select! { @@ -1270,13 +1701,12 @@ impl PageServerHandler { /// Helper which dispatches a batched message to the appropriate handler. /// Returns a vec of results, along with the extracted trace span. async fn pagestream_dispatch_batched_message( - &mut self, batch: BatchedFeMessage, io_concurrency: IoConcurrency, ctx: &RequestContext, ) -> Result< ( - Vec>, + Vec>, Span, ), QueryError, @@ -1300,10 +1730,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_rel_exists_request(&shard, &req, &ctx) + Self::handle_get_rel_exists_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (PagestreamBeMessage::Exists(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1319,10 +1749,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_nblocks_request(&shard, &req, &ctx) + Self::handle_get_nblocks_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1340,16 +1770,15 @@ impl PageServerHandler { { let npages = pages.len(); trace!(npages, "handling getpage request"); - let res = self - .handle_get_page_at_lsn_request_batched( - &shard, - pages, - io_concurrency, - batch_break_reason, - &ctx, - ) - .instrument(span.clone()) - .await; + let res = Self::handle_get_page_at_lsn_request_batched( + &shard, + pages, + io_concurrency, + batch_break_reason, + &ctx, + ) + .instrument(span.clone()) + .await; assert_eq!(res.len(), npages); res }, @@ -1366,10 +1795,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_db_size_request(&shard, &req, &ctx) + Self::handle_db_size_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (PagestreamBeMessage::DbSize(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1385,10 +1814,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - self.handle_get_slru_segment_request(&shard, &req, &ctx) + Self::handle_get_slru_segment_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (PagestreamBeMessage::GetSlruSegment(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1406,8 +1835,7 @@ impl PageServerHandler { { let npages = requests.len(); trace!(npages, "handling getpage request"); - let res = self - .handle_test_request_batch(&shard, requests, &ctx) + let res = Self::handle_test_request_batch(&shard, requests, &ctx) .instrument(span.clone()) .await; assert_eq!(res.len(), npages); @@ -1740,12 +2168,25 @@ impl PageServerHandler { return Ok(()); } }; - let batch = match batch { + let mut batch = match batch { Ok(batch) => batch, Err(e) => { return Err(e); } }; + + if let BatchedFeMessage::GetPage { + pages, + span: _, + shard: _, + batch_break_reason: _, + } = &mut batch + { + for req in pages { + req.batch_wait_ctx.take(); + } + } + self.pagestream_handle_batched_message( pgb_writer, batch, @@ -1948,11 +2389,10 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_rel_exists_request( - &mut self, timeline: &Timeline, req: &PagestreamExistsRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -1974,19 +2414,15 @@ impl PageServerHandler { ) .await?; - Ok(PagestreamBeMessage::Exists(PagestreamExistsResponse { - req: *req, - exists, - })) + Ok(PagestreamExistsResponse { req: *req, exists }) } #[instrument(skip_all, fields(shard_id))] async fn handle_get_nblocks_request( - &mut self, timeline: &Timeline, req: &PagestreamNblocksRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2008,19 +2444,18 @@ impl PageServerHandler { ) .await?; - Ok(PagestreamBeMessage::Nblocks(PagestreamNblocksResponse { + Ok(PagestreamNblocksResponse { req: *req, n_blocks, - })) + }) } #[instrument(skip_all, fields(shard_id))] async fn handle_db_size_request( - &mut self, timeline: &Timeline, req: &PagestreamDbSizeRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2044,21 +2479,18 @@ impl PageServerHandler { .await?; let db_size = total_blocks as i64 * BLCKSZ as i64; - Ok(PagestreamBeMessage::DbSize(PagestreamDbSizeResponse { - req: *req, - db_size, - })) + Ok(PagestreamDbSizeResponse { req: *req, db_size }) } #[instrument(skip_all)] async fn handle_get_page_at_lsn_request_batched( - &mut self, timeline: &Timeline, - requests: smallvec::SmallVec<[BatchedGetPageRequest; 1]>, + requests: SmallVec<[BatchedGetPageRequest; 1]>, io_concurrency: IoConcurrency, batch_break_reason: GetPageBatchBreakReason, ctx: &RequestContext, - ) -> Vec> { + ) -> Vec> + { debug_assert_current_span_has_tenant_and_timeline_id(); timeline @@ -2165,6 +2597,7 @@ impl PageServerHandler { page, }), req.timer, + req.ctx, ) }) .map_err(|e| BatchedPageStreamError { @@ -2177,11 +2610,10 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_slru_segment_request( - &mut self, timeline: &Timeline, req: &PagestreamGetSlruSegmentRequest, ctx: &RequestContext, - ) -> Result { + ) -> Result { let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( timeline, @@ -2196,20 +2628,18 @@ impl PageServerHandler { .ok_or(PageStreamError::BadRequest("invalid SLRU kind".into()))?; let segment = timeline.get_slru_segment(kind, req.segno, lsn, ctx).await?; - Ok(PagestreamBeMessage::GetSlruSegment( - PagestreamGetSlruSegmentResponse { req: *req, segment }, - )) + Ok(PagestreamGetSlruSegmentResponse { req: *req, segment }) } // NB: this impl mimics what we do for batched getpage requests. #[cfg(feature = "testing")] #[instrument(skip_all, fields(shard_id))] async fn handle_test_request_batch( - &mut self, timeline: &Timeline, requests: Vec, _ctx: &RequestContext, - ) -> Vec> { + ) -> Vec> + { // real requests would do something with the timeline let mut results = Vec::with_capacity(requests.len()); for _req in requests.iter() { @@ -2236,6 +2666,10 @@ impl PageServerHandler { req: req.req.clone(), }), req.timer, + RequestContext::new( + TaskKind::PageRequestHandler, + DownloadBehavior::Warn, + ), ) }) .map_err(|e| BatchedPageStreamError { @@ -2276,15 +2710,6 @@ impl PageServerHandler { where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin, { - fn map_basebackup_error(err: BasebackupError) -> QueryError { - match err { - // TODO: passthrough the error site to the final error message? - BasebackupError::Client(e, _) => QueryError::Disconnected(ConnectionError::Io(e)), - BasebackupError::Server(e) => QueryError::Other(e), - BasebackupError::Shutdown => QueryError::Shutdown, - } - } - let started = std::time::Instant::now(); let timeline = self @@ -2342,8 +2767,7 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; } else { let mut writer = BufWriter::new(pgb.copyout_writer()); @@ -2366,11 +2790,8 @@ impl PageServerHandler { from_cache = true; tokio::io::copy(&mut cached, &mut writer) .await - .map_err(|e| { - map_basebackup_error(BasebackupError::Client( - e, - "handle_basebackup_request,cached,copy", - )) + .map_err(|err| { + BasebackupError::Client(err, "handle_basebackup_request,cached,copy") })?; } else if gzip { let mut encoder = GzipEncoder::with_quality( @@ -2391,8 +2812,7 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; // shutdown the encoder to ensure the gzip footer is written encoder .shutdown() @@ -2408,15 +2828,12 @@ impl PageServerHandler { replica, &ctx, ) - .await - .map_err(map_basebackup_error)?; + .await?; } - writer.flush().await.map_err(|e| { - map_basebackup_error(BasebackupError::Client( - e, - "handle_basebackup_request,flush", - )) - })?; + writer + .flush() + .await + .map_err(|err| BasebackupError::Client(err, "handle_basebackup_request,flush"))?; } pgb.write_message_noflush(&BeMessage::CopyDone) @@ -2940,20 +3357,641 @@ where } } -impl From for QueryError { - fn from(e: GetActiveTenantError) -> Self { - match e { - GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( - ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), - ), - GetActiveTenantError::Cancelled - | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => { - QueryError::Shutdown - } - e @ GetActiveTenantError::NotFound(_) => QueryError::NotFound(format!("{e}").into()), - e => QueryError::Other(anyhow::anyhow!(e)), +/// Serves the page service over gRPC. Dispatches to PageServerHandler for request processing. +/// +/// TODO: rename to PageServiceHandler when libpq impl is removed. +pub struct GrpcPageServiceHandler { + tenant_manager: Arc, + ctx: RequestContext, +} + +impl GrpcPageServiceHandler { + /// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of + /// relations and their sizes, as well as SLRU segments and similar data. + #[allow(clippy::result_large_err)] + fn ensure_shard_zero(timeline: &Handle) -> Result<(), tonic::Status> { + match timeline.get_shard_index().shard_number.0 { + 0 => Ok(()), + shard => Err(tonic::Status::invalid_argument(format!( + "request must execute on shard zero (is shard {shard})", + ))), } } + + /// Generates a PagestreamRequest header from a ReadLsn and request ID. + fn make_hdr(read_lsn: page_api::ReadLsn, req_id: u64) -> PagestreamRequest { + PagestreamRequest { + reqid: req_id, + request_lsn: read_lsn.request_lsn, + not_modified_since: read_lsn + .not_modified_since_lsn + .unwrap_or(read_lsn.request_lsn), + } + } + + /// Acquires a timeline handle for the given request. + /// + /// TODO: during shard splits, the compute may still be sending requests to the parent shard + /// until the entire split is committed and the compute is notified. Consider installing a + /// temporary shard router from the parent to the children while the split is in progress. + /// + /// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage + /// the TimelineHandles lifecycle. + /// + /// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid + /// the unnecessary overhead. + async fn get_request_timeline( + &self, + req: &tonic::Request, + ) -> Result, GetActiveTimelineError> { + let ttid = *extract::(req); + let shard_index = *extract::(req); + let shard_selector = ShardSelector::Known(shard_index); + + TimelineHandles::new(self.tenant_manager.clone()) + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await + } + + /// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start. + /// Only errors if the timeline is shutting down. + /// + /// TODO: move timer construction to ObservabilityLayer (see TODO there). + /// TODO: decouple rate limiting (middleware?), and return SlowDown errors instead. + async fn record_op_start_and_throttle( + timeline: &Handle, + op: metrics::SmgrQueryType, + received_at: Instant, + ) -> Result { + let mut timer = PageServerHandler::record_op_start_and_throttle(timeline, op, received_at) + .await + .map_err(|err| match err { + // record_op_start_and_throttle() only returns Shutdown. + QueryError::Shutdown => tonic::Status::unavailable(format!("{err}")), + err => tonic::Status::internal(format!("unexpected error: {err}")), + })?; + timer.observe_execution_start(Instant::now()); + Ok(timer) + } + + /// Processes a GetPage batch request, via the GetPages bidirectional streaming RPC. + /// + /// NB: errors will terminate the stream. Per-request errors should return a GetPageResponse + /// with an appropriate status code instead. + /// + /// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send + /// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or + /// split them up in the client or server. + #[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))] + async fn get_page( + ctx: &RequestContext, + timeline: &WeakHandle, + req: proto::GetPageRequest, + io_concurrency: IoConcurrency, + ) -> Result { + let received_at = Instant::now(); + let timeline = timeline.upgrade()?; + let ctx = ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + let req: page_api::GetPageRequest = req.try_into()?; + + span_record!( + req_id = %req.request_id, + rel = %req.rel, + blkno = %req.block_numbers[0], + blks = %req.block_numbers.len(), + lsn = %req.read_lsn, + ); + + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard + let effective_lsn = match PageServerHandler::effective_request_lsn( + &timeline, + timeline.get_last_record_lsn(), + req.read_lsn.request_lsn, + req.read_lsn + .not_modified_since_lsn + .unwrap_or(req.read_lsn.request_lsn), + &latest_gc_cutoff_lsn, + ) { + Ok(lsn) => lsn, + Err(err) => return err.into_get_page_response(req.request_id), + }; + + let mut batch = SmallVec::with_capacity(req.block_numbers.len()); + for blkno in req.block_numbers { + // TODO: this creates one timer per page and throttles it. We should have a timer for + // the entire batch, and throttle only the batch, but this is equivalent to what + // PageServerHandler does already so we keep it for now. + let timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetPageAtLsn, + received_at, + ) + .await?; + + batch.push(BatchedGetPageRequest { + req: PagestreamGetPageRequest { + hdr: Self::make_hdr(req.read_lsn, req.request_id), + rel: req.rel, + blkno, + }, + lsn_range: LsnRange { + effective_lsn, + request_lsn: req.read_lsn.request_lsn, + }, + timer, + ctx: ctx.attached_child(), + batch_wait_ctx: None, // TODO: add tracing + }); + } + + // TODO: this does a relation size query for every page in the batch. Since this batch is + // all for one relation, we could do this only once. However, this is not the case for the + // libpq implementation. + let results = PageServerHandler::handle_get_page_at_lsn_request_batched( + &timeline, + batch, + io_concurrency, + GetPageBatchBreakReason::BatchFull, // TODO: not relevant for gRPC batches + &ctx, + ) + .await; + + let mut resp = page_api::GetPageResponse { + request_id: req.request_id, + status_code: page_api::GetPageStatusCode::Ok, + reason: None, + page_images: SmallVec::with_capacity(results.len()), + }; + + for result in results { + match result { + Ok((PagestreamBeMessage::GetPage(r), _, _)) => resp.page_images.push(r.page), + Ok((resp, _, _)) => { + return Err(tonic::Status::internal(format!( + "unexpected response: {resp:?}" + ))); + } + Err(err) => return err.err.into_get_page_response(req.request_id), + }; + } + + Ok(resp.into()) + } +} + +/// Implements the gRPC page service. +/// +/// TODO: cancellation. +/// TODO: when the libpq impl is removed, remove the Pagestream types and inline the handler code. +#[tonic::async_trait] +impl proto::PageService for GrpcPageServiceHandler { + type GetBaseBackupStream = Pin< + Box> + Send>, + >; + + type GetPagesStream = + Pin> + Send>>; + + #[instrument(skip_all, fields(rel, lsn))] + async fn check_rel_exists( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?; + + span_record!(rel=%req.rel, lsn=%req.read_lsn); + + let req = PagestreamExistsRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + rel: req.rel, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetRelExists, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?; + let resp: page_api::CheckRelExistsResponse = resp.exists; + Ok(tonic::Response::new(resp.into())) + } + + // TODO: ensure clients use gzip compression for the stream. + #[instrument(skip_all, fields(lsn))] + async fn get_base_backup( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + // Send 64 KB chunks to avoid large memory allocations. + const CHUNK_SIZE: usize = 64 * 1024; + + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_timeline(&timeline); + + // Validate the request, decorate the span, and wait for the LSN to arrive. + // + // TODO: this requires a read LSN, is that ok? + Self::ensure_shard_zero(&timeline)?; + if timeline.is_archived() == Some(true) { + return Err(tonic::Status::failed_precondition("timeline is archived")); + } + let req: page_api::GetBaseBackupRequest = req.into_inner().try_into()?; + + span_record!(lsn=%req.read_lsn); + + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); + timeline + .wait_lsn( + req.read_lsn.request_lsn, + WaitLsnWaiter::PageService, + WaitLsnTimeout::Default, + &ctx, + ) + .await?; + timeline + .check_lsn_is_in_scope(req.read_lsn.request_lsn, &latest_gc_cutoff_lsn) + .map_err(|err| { + tonic::Status::invalid_argument(format!("invalid basebackup LSN: {err}")) + })?; + + // Spawn a task to run the basebackup. + // + // TODO: do we need to support full base backups, for debugging? + let span = Span::current(); + let (mut simplex_read, mut simplex_write) = tokio::io::simplex(CHUNK_SIZE); + let jh = tokio::spawn(async move { + let result = basebackup::send_basebackup_tarball( + &mut simplex_write, + &timeline, + Some(req.read_lsn.request_lsn), + None, + false, + req.replica, + &ctx, + ) + .instrument(span) // propagate request span + .await; + simplex_write.shutdown().await.map_err(|err| { + BasebackupError::Server(anyhow!("simplex shutdown failed: {err}")) + })?; + result + }); + + // Emit chunks of size CHUNK_SIZE. + let chunks = async_stream::try_stream! { + let mut chunk = BytesMut::with_capacity(CHUNK_SIZE); + loop { + let n = simplex_read.read_buf(&mut chunk).await.map_err(|err| { + tonic::Status::internal(format!("failed to read basebackup chunk: {err}")) + })?; + + // If we read 0 bytes, either the chunk is full or the stream is closed. + if n == 0 { + if chunk.is_empty() { + break; + } + yield proto::GetBaseBackupResponseChunk::try_from(chunk.clone().freeze())?; + chunk.clear(); + } + } + // Wait for the basebackup task to exit and check for errors. + jh.await.map_err(|err| { + tonic::Status::internal(format!("basebackup failed: {err}")) + })??; + }; + + Ok(tonic::Response::new(Box::pin(chunks))) + } + + #[instrument(skip_all, fields(db_oid, lsn))] + async fn get_db_size( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?; + + span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn); + + let req = PagestreamDbSizeRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + dbnode: req.db_oid, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetDbSize, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_db_size_request(&timeline, &req, &ctx).await?; + let resp = resp.db_size as page_api::GetDbSizeResponse; + Ok(tonic::Response::new(resp.into())) + } + + // NB: don't instrument this, instrument each streamed request. + async fn get_pages( + &self, + req: tonic::Request>, + ) -> Result, tonic::Status> { + // Extract the timeline from the request and check that it exists. + let ttid = *extract::(&req); + let shard_index = *extract::(&req); + let shard_selector = ShardSelector::Known(shard_index); + + let mut handles = TimelineHandles::new(self.tenant_manager.clone()); + handles + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await?; + + let span = Span::current(); + let ctx = self.ctx.attached_child(); + let mut reqs = req.into_inner(); + + let resps = async_stream::try_stream! { + let timeline = handles + .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .await? + .downgrade(); + while let Some(req) = reqs.message().await? { + // TODO: implement IoConcurrency sidecar. + yield Self::get_page(&ctx, &timeline, req, IoConcurrency::Sequential) + .instrument(span.clone()) // propagate request span + .await? + } + }; + + Ok(tonic::Response::new(Box::pin(resps))) + } + + #[instrument(skip_all, fields(rel, lsn))] + async fn get_rel_size( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?; + + span_record!(rel=%req.rel, lsn=%req.read_lsn); + + let req = PagestreamNblocksRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + rel: req.rel, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetRelSize, + received_at, + ) + .await?; + + let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?; + let resp: page_api::GetRelSizeResponse = resp.n_blocks; + Ok(tonic::Response::new(resp.into())) + } + + #[instrument(skip_all, fields(kind, segno, lsn))] + async fn get_slru_segment( + &self, + req: tonic::Request, + ) -> Result, tonic::Status> { + let received_at = extract::(&req).0; + let timeline = self.get_request_timeline(&req).await?; + let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); + + // Validate the request, decorate the span, and convert it to a Pagestream request. + Self::ensure_shard_zero(&timeline)?; + let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?; + + span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn); + + let req = PagestreamGetSlruSegmentRequest { + hdr: Self::make_hdr(req.read_lsn, 0), + kind: req.kind as u8, + segno: req.segno, + }; + + // Execute the request and convert the response. + let _timer = Self::record_op_start_and_throttle( + &timeline, + metrics::SmgrQueryType::GetSlruSegment, + received_at, + ) + .await?; + + let resp = + PageServerHandler::handle_get_slru_segment_request(&timeline, &req, &ctx).await?; + let resp: page_api::GetSlruSegmentResponse = resp.segment; + Ok(tonic::Response::new(resp.try_into()?)) + } +} + +/// gRPC middleware layer that handles observability concerns: +/// +/// * Creates and enters a tracing span. +/// * Records the request start time as a ReceivedAt request extension. +/// +/// TODO: add perf tracing. +/// TODO: add timing and metrics. +/// TODO: add logging. +#[derive(Clone)] +struct ObservabilityLayer; + +impl tower::Layer for ObservabilityLayer { + type Service = ObservabilityLayerService; + + fn layer(&self, inner: S) -> Self::Service { + Self::Service { inner } + } +} + +#[derive(Clone)] +struct ObservabilityLayerService { + inner: S, +} + +#[derive(Clone, Copy)] +struct ReceivedAt(Instant); + +impl tonic::server::NamedService for ObservabilityLayerService { + const NAME: &'static str = S::NAME; // propagate inner service name +} + +impl tower::Service> for ObservabilityLayerService +where + S: tower::Service>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn call(&mut self, mut req: http::Request) -> Self::Future { + // Record the request start time as a request extension. + // + // TODO: we should start a timer here instead, but it currently requires a timeline handle + // and SmgrQueryType, which we don't have yet. Refactor it to provide it later. + req.extensions_mut().insert(ReceivedAt(Instant::now())); + + // Create a basic tracing span. Enter the span for the current thread (to use it for inner + // sync code like interceptors), and instrument the future (to use it for inner async code + // like the page service itself). + // + // The instrument() call below is not sufficient. It only affects the returned future, and + // only takes effect when the caller polls it. Any sync code executed when we call + // self.inner.call() below (such as interceptors) runs outside of the returned future, and + // is not affected by it. We therefore have to enter the span on the current thread too. + let span = info_span!( + "grpc:pageservice", + // Set by TenantMetadataInterceptor. + tenant_id = field::Empty, + timeline_id = field::Empty, + shard_id = field::Empty, + ); + let _guard = span.enter(); + + Box::pin(self.inner.call(req).instrument(span.clone())) + } + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } +} + +/// gRPC interceptor that decodes tenant metadata and stores it as request extensions of type +/// TenantTimelineId and ShardIndex. +#[derive(Clone)] +struct TenantMetadataInterceptor; + +impl tonic::service::Interceptor for TenantMetadataInterceptor { + fn call(&mut self, mut req: tonic::Request<()>) -> Result, tonic::Status> { + // Decode the tenant ID. + let tenant_id = req + .metadata() + .get("neon-tenant-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-tenant-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?; + let tenant_id = TenantId::from_str(tenant_id) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?; + + // Decode the timeline ID. + let timeline_id = req + .metadata() + .get("neon-timeline-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-timeline-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; + let timeline_id = TimelineId::from_str(timeline_id) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; + + // Decode the shard ID. + let shard_id = req + .metadata() + .get("neon-shard-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-shard-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; + let shard_id = ShardIndex::from_str(shard_id) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; + + // Stash them in the request. + let extensions = req.extensions_mut(); + extensions.insert(TenantTimelineId::new(tenant_id, timeline_id)); + extensions.insert(shard_id); + + // Decorate the tracing span. + span_record!(%tenant_id, %timeline_id, %shard_id); + + Ok(req) + } +} + +/// Authenticates gRPC page service requests. +#[derive(Clone)] +struct TenantAuthInterceptor { + auth: Option>, +} + +impl TenantAuthInterceptor { + fn new(auth: Option>) -> Self { + Self { auth } + } +} + +impl tonic::service::Interceptor for TenantAuthInterceptor { + fn call(&mut self, req: tonic::Request<()>) -> Result, tonic::Status> { + // Do nothing if auth is disabled. + let Some(auth) = self.auth.as_ref() else { + return Ok(req); + }; + + // Fetch the tenant ID from the request extensions (set by TenantMetadataInterceptor). + let TenantTimelineId { tenant_id, .. } = *extract::(&req); + + // Fetch and decode the JWT token. + let jwt = req + .metadata() + .get("authorization") + .ok_or_else(|| tonic::Status::unauthenticated("no authorization header"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid authorization header"))? + .strip_prefix("Bearer ") + .ok_or_else(|| tonic::Status::invalid_argument("invalid authorization header"))? + .trim(); + let jwtdata: TokenData = auth + .decode(jwt) + .map_err(|err| tonic::Status::invalid_argument(format!("invalid JWT token: {err}")))?; + let claims = jwtdata.claims; + + // Check if the token is valid for this tenant. + check_permission(&claims, Some(tenant_id)) + .map_err(|err| tonic::Status::permission_denied(err.to_string()))?; + + // TODO: consider stashing the claims in the request extensions, if needed. + + Ok(req) + } +} + +/// Extracts the given type from the request extensions, or panics if it is missing. +fn extract(req: &tonic::Request) -> &T { + extract_from(req.extensions()) +} + +/// Extract the given type from the request extensions, or panics if it is missing. This variant +/// can extract both from a tonic::Request and http::Request. +fn extract_from(ext: &http::Extensions) -> &T { + let Some(value) = ext.get::() else { + let name = std::any::type_name::(); + panic!("extension {name} should be set by middleware"); + }; + value } #[derive(Debug, thiserror::Error)] @@ -2974,10 +4012,72 @@ impl From for QueryError { } } -impl From for QueryError { - fn from(e: crate::tenant::timeline::handle::HandleUpgradeError) -> Self { +impl From for tonic::Status { + fn from(err: GetActiveTimelineError) -> Self { + let message = err.to_string(); + let code = match err { + GetActiveTimelineError::Tenant(err) => tonic::Status::from(err).code(), + GetActiveTimelineError::Timeline(err) => tonic::Status::from(err).code(), + }; + tonic::Status::new(code, message) + } +} + +impl From for tonic::Status { + fn from(err: GetTimelineError) -> Self { + use tonic::Code; + let code = match &err { + GetTimelineError::NotFound { .. } => Code::NotFound, + GetTimelineError::NotActive { .. } => Code::Unavailable, + GetTimelineError::ShuttingDown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + +impl From for QueryError { + fn from(e: GetActiveTenantError) -> Self { match e { - crate::tenant::timeline::handle::HandleUpgradeError::ShutDown => QueryError::Shutdown, + GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected( + ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())), + ), + GetActiveTenantError::Cancelled + | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => { + QueryError::Shutdown + } + e @ GetActiveTenantError::NotFound(_) => QueryError::NotFound(format!("{e}").into()), + e => QueryError::Other(anyhow::anyhow!(e)), + } + } +} + +impl From for tonic::Status { + fn from(err: GetActiveTenantError) -> Self { + use tonic::Code; + let code = match &err { + GetActiveTenantError::Broken(_) => Code::Internal, + GetActiveTenantError::Cancelled => Code::Unavailable, + GetActiveTenantError::NotFound(_) => Code::NotFound, + GetActiveTenantError::SwitchedTenant => Code::Unavailable, + GetActiveTenantError::WaitForActiveTimeout { .. } => Code::Unavailable, + GetActiveTenantError::WillNotBecomeActive(_) => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + +impl From for QueryError { + fn from(e: HandleUpgradeError) -> Self { + match e { + HandleUpgradeError::ShutDown => QueryError::Shutdown, + } + } +} + +impl From for tonic::Status { + fn from(err: HandleUpgradeError) -> Self { + match err { + HandleUpgradeError::ShutDown => tonic::Status::unavailable("timeline is shutting down"), } } } diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index c6f3929257..b6f11b744b 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -471,8 +471,19 @@ impl Timeline { let rels = self.list_rels(spcnode, dbnode, version, ctx).await?; + if rels.is_empty() { + return Ok(0); + } + + // Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`. + let reldir_key = rel_dir_to_key(spcnode, dbnode); + let buf = version.get(self, reldir_key, ctx).await?; + let reldir = RelDirectory::des(&buf)?; + for rel in rels { - let n_blocks = self.get_rel_size(rel, version, ctx).await?; + let n_blocks = self + .get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx) + .await?; total_blocks += n_blocks as usize; } Ok(total_blocks) @@ -487,6 +498,19 @@ impl Timeline { tag: RelTag, version: Version<'_>, ctx: &RequestContext, + ) -> Result { + self.get_rel_size_in_reldir(tag, version, None, ctx).await + } + + /// Get size of a relation file. The relation must exist, otherwise an error is returned. + /// + /// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`. + pub(crate) async fn get_rel_size_in_reldir( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, ) -> Result { if tag.relnode == 0 { return Err(PageReconstructError::Other( @@ -499,7 +523,9 @@ impl Timeline { } if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM) - && !self.get_rel_exists(tag, version, ctx).await? + && !self + .get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx) + .await? { // FIXME: Postgres sometimes calls smgrcreate() to create // FSM, and smgrnblocks() on it immediately afterwards, @@ -521,11 +547,28 @@ impl Timeline { /// /// Only shard 0 has a full view of the relations. Other shards only know about relations that /// the shard stores pages for. + /// pub(crate) async fn get_rel_exists( &self, tag: RelTag, version: Version<'_>, ctx: &RequestContext, + ) -> Result { + self.get_rel_exists_in_reldir(tag, version, None, ctx).await + } + + /// Does the relation exist? With a cached deserialized `RelDirectory`. + /// + /// There are some cases where the caller loops across all relations. In that specific case, + /// the caller should obtain the deserialized `RelDirectory` first and then call this function + /// to avoid duplicated work of deserliazation. This is a hack and should be removed by introducing + /// a new API (e.g., `get_rel_exists_batched`). + pub(crate) async fn get_rel_exists_in_reldir( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, ) -> Result { if tag.relnode == 0 { return Err(PageReconstructError::Other( @@ -568,6 +611,17 @@ impl Timeline { // fetch directory listing (old) let key = rel_dir_to_key(tag.spcnode, tag.dbnode); + + if let Some((cached_key, dir)) = deserialized_reldir_v1 { + if cached_key == key { + return Ok(dir.rels.contains(&(tag.relnode, tag.forknum))); + } else if cfg!(test) || cfg!(feature = "testing") { + panic!("cached reldir key mismatch: {cached_key} != {key}"); + } else { + warn!("cached reldir key mismatch: {cached_key} != {key}"); + } + // Fallback to reading the directory from the datadir. + } let buf = version.get(self, key, ctx).await?; let dir = RelDirectory::des(&buf)?; diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index 55272b2125..29897af642 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -276,9 +276,10 @@ pub enum TaskKind { // HTTP endpoint listener. HttpEndpointListener, - // Task that handles a single connection. A PageRequestHandler task - // starts detached from any particular tenant or timeline, but it can be - // associated with one later, after receiving a command from the client. + /// Task that handles a single page service connection. A PageRequestHandler + /// task starts detached from any particular tenant or timeline, but it can + /// be associated with one later, after receiving a command from the client. + /// Also used for the gRPC page service API, including the main server task. PageRequestHandler, /// Manages the WAL receiver connection for one timeline. diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index bf3f71e35a..308ada3fa1 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -84,6 +84,7 @@ use crate::context; use crate::context::RequestContextBuilder; use crate::context::{DownloadBehavior, RequestContext}; use crate::deletion_queue::{DeletionQueueClient, DeletionQueueError}; +use crate::feature_resolver::FeatureResolver; use crate::l0_flush::L0FlushGlobalState; use crate::metrics::{ BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS, @@ -159,6 +160,7 @@ pub struct TenantSharedResources { pub deletion_queue_client: DeletionQueueClient, pub l0_flush_global_state: L0FlushGlobalState, pub basebackup_prepare_sender: BasebackupPrepareSender, + pub feature_resolver: FeatureResolver, } /// A [`TenantShard`] is really an _attached_ tenant. The configuration @@ -298,7 +300,7 @@ pub struct TenantShard { /// as in progress. /// * Imported timelines are removed when the storage controller calls the post timeline /// import activation endpoint. - timelines_importing: std::sync::Mutex>, + timelines_importing: std::sync::Mutex>>, /// The last tenant manifest known to be in remote storage. None if the manifest has not yet /// been either downloaded or uploaded. Always Some after tenant attach. @@ -380,6 +382,8 @@ pub struct TenantShard { pub(crate) gc_block: gc_block::GcBlock, l0_flush_global_state: L0FlushGlobalState, + + pub(crate) feature_resolver: FeatureResolver, } impl std::fmt::Debug for TenantShard { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -668,6 +672,7 @@ pub enum MaybeOffloaded { pub enum TimelineOrOffloaded { Timeline(Arc), Offloaded(Arc), + Importing(Arc), } impl TimelineOrOffloaded { @@ -679,6 +684,9 @@ impl TimelineOrOffloaded { TimelineOrOffloaded::Offloaded(offloaded) => { TimelineOrOffloadedArcRef::Offloaded(offloaded) } + TimelineOrOffloaded::Importing(importing) => { + TimelineOrOffloadedArcRef::Importing(importing) + } } } pub fn tenant_shard_id(&self) -> TenantShardId { @@ -691,12 +699,16 @@ impl TimelineOrOffloaded { match self { TimelineOrOffloaded::Timeline(timeline) => &timeline.delete_progress, TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.delete_progress, + TimelineOrOffloaded::Importing(importing) => &importing.delete_progress, } } fn maybe_remote_client(&self) -> Option> { match self { TimelineOrOffloaded::Timeline(timeline) => Some(timeline.remote_client.clone()), TimelineOrOffloaded::Offloaded(_offloaded) => None, + TimelineOrOffloaded::Importing(importing) => { + Some(importing.timeline.remote_client.clone()) + } } } } @@ -704,6 +716,7 @@ impl TimelineOrOffloaded { pub enum TimelineOrOffloadedArcRef<'a> { Timeline(&'a Arc), Offloaded(&'a Arc), + Importing(&'a Arc), } impl TimelineOrOffloadedArcRef<'_> { @@ -711,12 +724,14 @@ impl TimelineOrOffloadedArcRef<'_> { match self { TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.tenant_shard_id, TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.tenant_shard_id, + TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.tenant_shard_id, } } pub fn timeline_id(&self) -> TimelineId { match self { TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.timeline_id, TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.timeline_id, + TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.timeline_id, } } } @@ -733,6 +748,12 @@ impl<'a> From<&'a Arc> for TimelineOrOffloadedArcRef<'a> { } } +impl<'a> From<&'a Arc> for TimelineOrOffloadedArcRef<'a> { + fn from(timeline: &'a Arc) -> Self { + Self::Importing(timeline) + } +} + #[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum GetTimelineError { #[error("Timeline is shutting down")] @@ -860,6 +881,14 @@ impl Debug for SetStoppingError { } } +#[derive(thiserror::Error, Debug)] +pub(crate) enum FinalizeTimelineImportError { + #[error("Import task not done yet")] + ImportTaskStillRunning, + #[error("Shutting down")] + ShuttingDown, +} + /// Arguments to [`TenantShard::create_timeline`]. /// /// Not usable as an idempotency key for timeline creation because if [`CreateTimelineParamsBranch::ancestor_start_lsn`] @@ -1146,10 +1175,20 @@ impl TenantShard { ctx, )?; let disk_consistent_lsn = timeline.get_disk_consistent_lsn(); - anyhow::ensure!( - disk_consistent_lsn.is_valid(), - "Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn" - ); + + if !disk_consistent_lsn.is_valid() { + // As opposed to normal timelines which get initialised with a disk consitent LSN + // via initdb, imported timelines start from 0. If the import task stops before + // it advances disk consitent LSN, allow it to resume. + let in_progress_import = import_pgdata + .as_ref() + .map(|import| !import.is_done()) + .unwrap_or(false); + if !in_progress_import { + anyhow::bail!("Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn"); + } + } + assert_eq!( disk_consistent_lsn, metadata.disk_consistent_lsn(), @@ -1243,20 +1282,25 @@ impl TenantShard { } } - // Sanity check: a timeline should have some content. - anyhow::ensure!( - ancestor.is_some() - || timeline - .layers - .read() - .await - .layer_map() - .expect("currently loading, layer manager cannot be shutdown already") - .iter_historic_layers() - .next() - .is_some(), - "Timeline has no ancestor and no layer files" - ); + if disk_consistent_lsn.is_valid() { + // Sanity check: a timeline should have some content. + // Exception: importing timelines might not yet have any + anyhow::ensure!( + ancestor.is_some() + || timeline + .layers + .read() + .await + .layer_map() + .expect( + "currently loading, layer manager cannot be shutdown already" + ) + .iter_historic_layers() + .next() + .is_some(), + "Timeline has no ancestor and no layer files" + ); + } Ok(TimelineInitAndSyncResult::ReadyToActivate) } @@ -1292,6 +1336,7 @@ impl TenantShard { deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, } = resources; let attach_mode = attached_conf.location.attach_mode; @@ -1308,6 +1353,7 @@ impl TenantShard { deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, )); // The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if @@ -1760,20 +1806,25 @@ impl TenantShard { }, ) => { let timeline_id = timeline.timeline_id; + let import_task_gate = Gate::default(); + let import_task_guard = import_task_gate.enter().unwrap(); let import_task_handle = tokio::task::spawn(self.clone().create_timeline_import_pgdata_task( timeline.clone(), import_pgdata, guard, + import_task_guard, ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn), )); let prev = self.timelines_importing.lock().unwrap().insert( timeline_id, - ImportingTimeline { + Arc::new(ImportingTimeline { timeline: timeline.clone(), import_task_handle, - }, + import_task_gate, + delete_progress: TimelineDeleteProgress::default(), + }), ); assert!(prev.is_none()); @@ -2391,6 +2442,17 @@ impl TenantShard { .collect() } + /// Lists timelines the tenant contains. + /// It's up to callers to omit certain timelines that are not considered ready for use. + pub fn list_importing_timelines(&self) -> Vec> { + self.timelines_importing + .lock() + .unwrap() + .values() + .map(Arc::clone) + .collect() + } + /// Lists timelines the tenant manages, including offloaded ones. /// /// It's up to callers to omit certain timelines that are not considered ready for use. @@ -2824,19 +2886,25 @@ impl TenantShard { let (timeline, timeline_create_guard) = uninit_timeline.finish_creation_myself(); + let import_task_gate = Gate::default(); + let import_task_guard = import_task_gate.enter().unwrap(); + let import_task_handle = tokio::spawn(self.clone().create_timeline_import_pgdata_task( timeline.clone(), index_part, timeline_create_guard, + import_task_guard, timeline_ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn), )); let prev = self.timelines_importing.lock().unwrap().insert( timeline.timeline_id, - ImportingTimeline { + Arc::new(ImportingTimeline { timeline: timeline.clone(), import_task_handle, - }, + import_task_gate, + delete_progress: TimelineDeleteProgress::default(), + }), ); // Idempotency is enforced higher up the stack @@ -2854,13 +2922,13 @@ impl TenantShard { pub(crate) async fn finalize_importing_timeline( &self, timeline_id: TimelineId, - ) -> anyhow::Result<()> { + ) -> Result<(), FinalizeTimelineImportError> { let timeline = { let locked = self.timelines_importing.lock().unwrap(); match locked.get(&timeline_id) { Some(importing_timeline) => { if !importing_timeline.import_task_handle.is_finished() { - return Err(anyhow::anyhow!("Import task not done yet")); + return Err(FinalizeTimelineImportError::ImportTaskStillRunning); } importing_timeline.timeline.clone() @@ -2873,8 +2941,13 @@ impl TenantShard { timeline .remote_client - .schedule_index_upload_for_import_pgdata_finalize()?; - timeline.remote_client.wait_completion().await?; + .schedule_index_upload_for_import_pgdata_finalize() + .map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?; + timeline + .remote_client + .wait_completion() + .await + .map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?; self.timelines_importing .lock() @@ -2890,6 +2963,7 @@ impl TenantShard { timeline: Arc, index_part: import_pgdata::index_part_format::Root, timeline_create_guard: TimelineCreateGuard, + _import_task_guard: GateGuard, ctx: RequestContext, ) { debug_assert_current_span_has_tenant_and_timeline_id(); @@ -3135,11 +3209,18 @@ impl TenantShard { .or_insert_with(|| Arc::new(GcCompactionQueue::new())) .clone() }; + let gc_compaction_strategy = self + .feature_resolver + .evaluate_multivariate("gc-comapction-strategy", self.tenant_shard_id.tenant_id) + .ok(); + let span = if let Some(gc_compaction_strategy) = gc_compaction_strategy { + info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id, strategy = %gc_compaction_strategy) + } else { + info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id) + }; outcome = queue .iteration(cancel, ctx, &self.gc_block, &timeline) - .instrument( - info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id), - ) + .instrument(span) .await?; } @@ -3471,8 +3552,9 @@ impl TenantShard { let mut timelines_importing = self.timelines_importing.lock().unwrap(); timelines_importing .drain() - .for_each(|(_timeline_id, importing_timeline)| { - importing_timeline.shutdown(); + .for_each(|(timeline_id, importing_timeline)| { + let span = tracing::info_span!("importing_timeline_shutdown", %timeline_id); + js.spawn(async move { importing_timeline.shutdown().instrument(span).await }); }); } // test_long_timeline_create_then_tenant_delete is leaning on this message @@ -3793,6 +3875,9 @@ impl TenantShard { .build_timeline_client(offloaded.timeline_id, self.remote_storage.clone()); Arc::new(remote_client) } + TimelineOrOffloadedArcRef::Importing(_) => { + unreachable!("Importing timelines are not included in the iterator") + } }; // Shut down the timeline's remote client: this means that the indices we write @@ -4247,6 +4332,7 @@ impl TenantShard { deletion_queue_client: DeletionQueueClient, l0_flush_global_state: L0FlushGlobalState, basebackup_prepare_sender: BasebackupPrepareSender, + feature_resolver: FeatureResolver, ) -> TenantShard { assert!(!attached_conf.location.generation.is_none()); @@ -4351,6 +4437,7 @@ impl TenantShard { gc_block: Default::default(), l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, } } @@ -5000,6 +5087,14 @@ impl TenantShard { info!("timeline already exists but is offloaded"); Err(CreateTimelineError::Conflict) } + Err(TimelineExclusionError::AlreadyExists { + existing: TimelineOrOffloaded::Importing(_existing), + .. + }) => { + // If there's a timeline already importing, then we would hit + // the [`TimelineExclusionError::AlreadyCreating`] branch above. + unreachable!("Importing timelines hold the creation guard") + } Err(TimelineExclusionError::AlreadyExists { existing: TimelineOrOffloaded::Timeline(existing), arg, @@ -5271,6 +5366,7 @@ impl TenantShard { l0_compaction_trigger: self.l0_compaction_trigger.clone(), l0_flush_global_state: self.l0_flush_global_state.clone(), basebackup_prepare_sender: self.basebackup_prepare_sender.clone(), + feature_resolver: self.feature_resolver.clone(), } } @@ -5736,6 +5832,7 @@ pub(crate) mod harness { pub conf: &'static PageServerConf, pub tenant_conf: pageserver_api::models::TenantConfig, pub tenant_shard_id: TenantShardId, + pub shard_identity: ShardIdentity, pub generation: Generation, pub shard: ShardIndex, pub remote_storage: GenericRemoteStorage, @@ -5803,6 +5900,7 @@ pub(crate) mod harness { conf, tenant_conf, tenant_shard_id, + shard_identity, generation, shard, remote_storage, @@ -5864,8 +5962,7 @@ pub(crate) mod harness { &ShardParameters::default(), )) .unwrap(), - // This is a legacy/test code path: sharding isn't supported here. - ShardIdentity::unsharded(), + self.shard_identity, Some(walredo_mgr), self.tenant_shard_id, self.remote_storage.clone(), @@ -5873,6 +5970,7 @@ pub(crate) mod harness { // TODO: ideally we should run all unit tests with both configs L0FlushGlobalState::new(L0FlushConfig::default()), basebackup_requst_sender, + FeatureResolver::new_disabled(), )); let preload = tenant @@ -5986,6 +6084,7 @@ mod tests { use timeline::compaction::{KeyHistoryRetention, KeyLogAtLsn}; use timeline::{CompactOptions, DeltaLayerTestDesc, VersionedKeySpaceQuery}; use utils::id::TenantId; + use utils::shard::{ShardCount, ShardNumber}; use super::*; use crate::DEFAULT_PG_VERSION; @@ -8314,10 +8413,24 @@ mod tests { } tline.freeze_and_flush().await?; + // Force layers to L1 + tline + .compact( + &cancel, + { + let mut flags = EnumSet::new(); + flags.insert(CompactFlags::ForceL0Compaction); + flags + }, + &ctx, + ) + .await?; if iter % 5 == 0 { + let scan_lsn = Lsn(lsn.0 + 1); + info!("scanning at {}", scan_lsn); let (_, before_delta_file_accessed) = - scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone()) + scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone()) .await?; tline .compact( @@ -8326,13 +8439,14 @@ mod tests { let mut flags = EnumSet::new(); flags.insert(CompactFlags::ForceImageLayerCreation); flags.insert(CompactFlags::ForceRepartition); + flags.insert(CompactFlags::ForceL0Compaction); flags }, &ctx, ) .await?; let (_, after_delta_file_accessed) = - scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone()) + scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone()) .await?; assert!( after_delta_file_accessed < before_delta_file_accessed, @@ -8773,6 +8887,8 @@ mod tests { let cancel = CancellationToken::new(); + // Image layer creation happens on the disk_consistent_lsn so we need to force set it now. + tline.force_set_disk_consistent_lsn(Lsn(0x40)); tline .compact( &cancel, @@ -8786,8 +8902,7 @@ mod tests { ) .await .unwrap(); - - // Image layers are created at last_record_lsn + // Image layers are created at repartition LSN let images = tline .inspect_image_layers(Lsn(0x40), &ctx, io_concurrency.clone()) .await @@ -9305,6 +9420,77 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_failed_flush_should_not_update_disk_consistent_lsn() -> anyhow::Result<()> { + // + // Setup + // + let harness = TenantHarness::create_custom( + "test_failed_flush_should_not_upload_disk_consistent_lsn", + pageserver_api::models::TenantConfig::default(), + TenantId::generate(), + ShardIdentity::new(ShardNumber(0), ShardCount(4), ShardStripeSize(128)).unwrap(), + Generation::new(1), + ) + .await?; + let (tenant, ctx) = harness.load().await; + + let timeline = tenant + .create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx) + .await?; + assert_eq!(timeline.get_shard_identity().count, ShardCount(4)); + let mut writer = timeline.writer().await; + writer + .put( + *TEST_KEY, + Lsn(0x20), + &Value::Image(test_img("foo at 0x20")), + &ctx, + ) + .await?; + writer.finish_write(Lsn(0x20)); + drop(writer); + timeline.freeze_and_flush().await.unwrap(); + + timeline.remote_client.wait_completion().await.unwrap(); + let disk_consistent_lsn = timeline.get_disk_consistent_lsn(); + let remote_consistent_lsn = timeline.get_remote_consistent_lsn_projected(); + assert_eq!(Some(disk_consistent_lsn), remote_consistent_lsn); + + // + // Test + // + + let mut writer = timeline.writer().await; + writer + .put( + *TEST_KEY, + Lsn(0x30), + &Value::Image(test_img("foo at 0x30")), + &ctx, + ) + .await?; + writer.finish_write(Lsn(0x30)); + drop(writer); + + fail::cfg( + "flush-layer-before-update-remote-consistent-lsn", + "return()", + ) + .unwrap(); + + let flush_res = timeline.freeze_and_flush().await; + // if flush failed, the disk/remote consistent LSN should not be updated + assert!(flush_res.is_err()); + assert_eq!(disk_consistent_lsn, timeline.get_disk_consistent_lsn()); + assert_eq!( + remote_consistent_lsn, + timeline.get_remote_consistent_lsn_projected() + ); + + Ok(()) + } + #[cfg(feature = "testing")] #[tokio::test] async fn test_simple_bottom_most_compaction_deltas_1() -> anyhow::Result<()> { diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index 21d68495f7..fd65000379 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -1348,6 +1348,21 @@ impl RemoteTimelineClient { Ok(()) } + pub(crate) fn schedule_unlinking_of_layers_from_index_part( + self: &Arc, + names: I, + ) -> Result<(), NotInitialized> + where + I: IntoIterator, + { + let mut guard = self.upload_queue.lock().unwrap(); + let upload_queue = guard.initialized_mut()?; + + self.schedule_unlinking_of_layers_from_index_part0(upload_queue, names); + + Ok(()) + } + /// Update the remote index file, removing the to-be-deleted files from the index, /// allowing scheduling of actual deletions later. fn schedule_unlinking_of_layers_from_index_part0( diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 54dc3b2d0b..9ddbe404d2 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -103,6 +103,7 @@ use crate::context::{ DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, }; use crate::disk_usage_eviction_task::{DiskUsageEvictionInfo, EvictionCandidate, finite_f32}; +use crate::feature_resolver::FeatureResolver; use crate::keyspace::{KeyPartitioning, KeySpace}; use crate::l0_flush::{self, L0FlushGlobalState}; use crate::metrics::{ @@ -198,6 +199,7 @@ pub struct TimelineResources { pub l0_compaction_trigger: Arc, pub l0_flush_global_state: l0_flush::L0FlushGlobalState, pub basebackup_prepare_sender: BasebackupPrepareSender, + pub feature_resolver: FeatureResolver, } pub struct Timeline { @@ -444,6 +446,8 @@ pub struct Timeline { /// A channel to send async requests to prepare a basebackup for the basebackup cache. basebackup_prepare_sender: BasebackupPrepareSender, + + feature_resolver: FeatureResolver, } pub(crate) enum PreviousHeatmap { @@ -946,6 +950,18 @@ pub(crate) enum WaitLsnError { Timeout(String), } +impl From for tonic::Status { + fn from(err: WaitLsnError) -> Self { + use tonic::Code; + let code = match &err { + WaitLsnError::Timeout(_) => Code::Internal, + WaitLsnError::BadState(_) => Code::Internal, + WaitLsnError::Shutdown => Code::Unavailable, + }; + tonic::Status::new(code, err.to_string()) + } +} + // The impls below achieve cancellation mapping for errors. // Perhaps there's a way of achieving this with less cruft. @@ -3072,6 +3088,8 @@ impl Timeline { wait_lsn_log_slow: tokio::sync::Semaphore::new(1), basebackup_prepare_sender: resources.basebackup_prepare_sender, + + feature_resolver: resources.feature_resolver, }; result.repartition_threshold = @@ -4761,7 +4779,10 @@ impl Timeline { || !flushed_to_lsn.is_valid() ); - if flushed_to_lsn < frozen_to_lsn && self.shard_identity.count.count() > 1 { + if flushed_to_lsn < frozen_to_lsn + && self.shard_identity.count.count() > 1 + && result.is_ok() + { // If our layer flushes didn't carry disk_consistent_lsn up to the `to_lsn` advertised // to us via layer_flush_start_rx, then advance it here. // @@ -4906,6 +4927,7 @@ impl Timeline { LastImageLayerCreationStatus::Initial, false, // don't yield for L0, we're flushing L0 ) + .instrument(info_span!("create_image_layers", mode = %ImageLayerCreationMode::Initial, partition_mode = "initial", lsn = %self.initdb_lsn)) .await?; debug_assert!( matches!(is_complete, LastImageLayerCreationStatus::Complete), @@ -4939,6 +4961,10 @@ impl Timeline { return Err(FlushLayerError::Cancelled); } + fail_point!("flush-layer-before-update-remote-consistent-lsn", |_| { + Err(FlushLayerError::Other(anyhow!("failpoint").into())) + }); + let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1); // The new on-disk layers are now in the layer map. We can remove the @@ -5462,7 +5488,8 @@ impl Timeline { /// Returns the image layers generated and an enum indicating whether the process is fully completed. /// true = we have generate all image layers, false = we preempt the process for L0 compaction. - #[tracing::instrument(skip_all, fields(%lsn, %mode))] + /// + /// `partition_mode` is only for logging purpose and is not used anywhere in this function. async fn create_image_layers( self: &Arc, partitioning: &KeyPartitioning, diff --git a/pageserver/src/tenant/timeline/compaction.rs b/pageserver/src/tenant/timeline/compaction.rs index 0e4b14c3e4..72ca0f9cc1 100644 --- a/pageserver/src/tenant/timeline/compaction.rs +++ b/pageserver/src/tenant/timeline/compaction.rs @@ -206,8 +206,8 @@ pub struct GcCompactionQueue { } static CONCURRENT_GC_COMPACTION_TASKS: Lazy> = Lazy::new(|| { - // Only allow two timelines on one pageserver to run gc compaction at a time. - Arc::new(Semaphore::new(2)) + // Only allow one timeline on one pageserver to run gc compaction at a time. + Arc::new(Semaphore::new(1)) }); impl GcCompactionQueue { @@ -1278,11 +1278,55 @@ impl Timeline { } let gc_cutoff = *self.applied_gc_cutoff_lsn.read(); + let l0_l1_boundary_lsn = { + // We do the repartition on the L0-L1 boundary. All data below the boundary + // are compacted by L0 with low read amplification, thus making the `repartition` + // function run fast. + let guard = self.layers.read().await; + guard + .all_persistent_layers() + .iter() + .map(|x| { + // Use the end LSN of delta layers OR the start LSN of image layers. + if x.is_delta { + x.lsn_range.end + } else { + x.lsn_range.start + } + }) + .max() + }; + + let (partition_mode, partition_lsn) = if cfg!(test) + || cfg!(feature = "testing") + || self + .feature_resolver + .evaluate_boolean("image-compaction-boundary", self.tenant_shard_id.tenant_id) + .is_ok() + { + let last_repartition_lsn = self.partitioning.read().1; + let lsn = match l0_l1_boundary_lsn { + Some(boundary) => gc_cutoff + .max(boundary) + .max(last_repartition_lsn) + .max(self.initdb_lsn) + .max(self.ancestor_lsn), + None => self.get_last_record_lsn(), + }; + if lsn <= self.initdb_lsn || lsn <= self.ancestor_lsn { + // Do not attempt to create image layers below the initdb or ancestor LSN -- no data below it + ("l0_l1_boundary", self.get_last_record_lsn()) + } else { + ("l0_l1_boundary", lsn) + } + } else { + ("latest_record", self.get_last_record_lsn()) + }; // 2. Repartition and create image layers if necessary match self .repartition( - self.get_last_record_lsn(), + partition_lsn, self.get_compaction_target_size(), options.flags, ctx, @@ -1301,18 +1345,19 @@ impl Timeline { .extend(sparse_partitioning.into_dense().parts); // 3. Create new image layers for partitions that have been modified "enough". + let mode = if options + .flags + .contains(CompactFlags::ForceImageLayerCreation) + { + ImageLayerCreationMode::Force + } else { + ImageLayerCreationMode::Try + }; let (image_layers, outcome) = self .create_image_layers( &partitioning, lsn, - if options - .flags - .contains(CompactFlags::ForceImageLayerCreation) - { - ImageLayerCreationMode::Force - } else { - ImageLayerCreationMode::Try - }, + mode, &image_ctx, self.last_image_layer_creation_status .load() @@ -1320,6 +1365,7 @@ impl Timeline { .clone(), options.flags.contains(CompactFlags::YieldForL0), ) + .instrument(info_span!("create_image_layers", mode = %mode, partition_mode = %partition_mode, lsn = %lsn)) .await .inspect_err(|err| { if let CreateImageLayersError::GetVectoredError( @@ -1344,7 +1390,8 @@ impl Timeline { } Ok(_) => { - info!("skipping repartitioning due to image compaction LSN being below GC cutoff"); + // This happens very frequently so we don't want to log it. + debug!("skipping repartitioning due to image compaction LSN being below GC cutoff"); } // Suppress errors when cancelled. diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index 1d4dd05e34..51bdd59f4f 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -121,6 +121,7 @@ async fn remove_maybe_offloaded_timeline_from_tenant( // This observes the locking order between timelines and timelines_offloaded let mut timelines = tenant.timelines.lock().unwrap(); let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap(); + let mut timelines_importing = tenant.timelines_importing.lock().unwrap(); let offloaded_children_exist = timelines_offloaded .iter() .any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id())); @@ -150,8 +151,12 @@ async fn remove_maybe_offloaded_timeline_from_tenant( .expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map"); offloaded_timeline.delete_from_ancestor_with_timelines(&timelines); } + TimelineOrOffloaded::Importing(importing) => { + timelines_importing.remove(&importing.timeline.timeline_id); + } } + drop(timelines_importing); drop(timelines_offloaded); drop(timelines); @@ -203,8 +208,17 @@ impl DeleteTimelineFlow { guard.mark_in_progress()?; // Now that the Timeline is in Stopping state, request all the related tasks to shut down. - if let TimelineOrOffloaded::Timeline(timeline) = &timeline { - timeline.shutdown(super::ShutdownMode::Hard).await; + // TODO(vlad): shut down imported timeline here + match &timeline { + TimelineOrOffloaded::Timeline(timeline) => { + timeline.shutdown(super::ShutdownMode::Hard).await; + } + TimelineOrOffloaded::Importing(importing) => { + importing.shutdown().await; + } + TimelineOrOffloaded::Offloaded(_offloaded) => { + // Nothing to shut down in this case + } } tenant.gc_block.before_delete(&timeline.timeline_id()); @@ -389,10 +403,18 @@ impl DeleteTimelineFlow { Err(anyhow::anyhow!("failpoint: timeline-delete-before-rm"))? }); - // Offloaded timelines have no local state - // TODO: once we persist offloaded information, delete the timeline from there, too - if let TimelineOrOffloaded::Timeline(timeline) = timeline { - delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await; + match timeline { + TimelineOrOffloaded::Timeline(timeline) => { + delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await; + } + TimelineOrOffloaded::Importing(importing) => { + delete_local_timeline_directory(conf, tenant.tenant_shard_id, &importing.timeline) + .await; + } + TimelineOrOffloaded::Offloaded(_offloaded) => { + // Offloaded timelines have no local state + // TODO: once we persist offloaded information, delete the timeline from there, too + } } fail::fail_point!("timeline-delete-after-rm", |_| { @@ -451,12 +473,16 @@ pub(super) fn make_timeline_delete_guard( // For more context see this discussion: `https://github.com/neondatabase/neon/pull/4552#discussion_r1253437346` let timelines = tenant.timelines.lock().unwrap(); let timelines_offloaded = tenant.timelines_offloaded.lock().unwrap(); + let timelines_importing = tenant.timelines_importing.lock().unwrap(); let timeline = match timelines.get(&timeline_id) { Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)), None => match timelines_offloaded.get(&timeline_id) { Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)), - None => return Err(DeleteTimelineError::NotFound), + None => match timelines_importing.get(&timeline_id) { + Some(t) => TimelineOrOffloaded::Importing(Arc::clone(t)), + None => return Err(DeleteTimelineError::NotFound), + }, }, }; diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index 658d867c18..606ad09ef1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -8,8 +8,10 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::info; use utils::lsn::Lsn; +use utils::pausable_failpoint; +use utils::sync::gate::Gate; -use super::Timeline; +use super::{Timeline, TimelineDeleteProgress}; use crate::context::RequestContext; use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient}; use crate::tenant::metadata::TimelineMetadata; @@ -19,14 +21,25 @@ mod importbucket_client; mod importbucket_format; pub(crate) mod index_part_format; -pub(crate) struct ImportingTimeline { +pub struct ImportingTimeline { pub import_task_handle: JoinHandle<()>, + pub import_task_gate: Gate, pub timeline: Arc, + pub delete_progress: TimelineDeleteProgress, +} + +impl std::fmt::Debug for ImportingTimeline { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ImportingTimeline<{}>", self.timeline.timeline_id) + } } impl ImportingTimeline { - pub(crate) fn shutdown(self) { + pub async fn shutdown(&self) { self.import_task_handle.abort(); + self.import_task_gate.close().await; + + self.timeline.remote_client.shutdown().await; } } @@ -93,6 +106,15 @@ pub async fn doit( ); } + tracing::info!("Import plan executed. Flushing remote changes and notifying storcon"); + + timeline + .remote_client + .schedule_index_upload_for_file_changes()?; + timeline.remote_client.wait_completion().await?; + + pausable_failpoint!("import-timeline-pre-success-notify-pausable"); + // Communicate that shard is done. // Ensure at-least-once delivery of the upcall to storage controller // before we mark the task as done and never come here again. @@ -179,8 +201,8 @@ async fn prepare_import( .await; match res { Ok(_) => break, - Err(err) => { - info!(?err, "indefinitely waiting for pgdata to finish"); + Err(_err) => { + info!("indefinitely waiting for pgdata to finish"); if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled()) .await .is_ok() diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 3e10a4e6d6..2ec9d86720 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -11,25 +11,14 @@ //! - => S3 as the source for the PGDATA instead of local filesystem //! //! TODOs before productionization: -//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding. -//! => produced image layers likely too small. //! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size. -//! - asserts / unwraps need to be replaced with errors -//! - don't trust remote objects will be small (=prevent OOMs in those cases) -//! - limit all in-memory buffers in size, or download to disk and read from there -//! - limit task concurrency -//! - generally play nice with other tenants in the system -//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits -//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc -//! - integrate with layer eviction system -//! - audit for Tenant::cancel nor Timeline::cancel responsivity -//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!) //! //! An incomplete set of TODOs from the Hackathon: //! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest) use std::collections::HashSet; use std::hash::{Hash, Hasher}; +use std::num::NonZeroUsize; use std::ops::Range; use std::sync::Arc; @@ -43,7 +32,7 @@ use pageserver_api::key::{ rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key, slru_segment_size_to_key, }; -use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range}; +use pageserver_api::keyspace::{ShardedRange, singleton_range}; use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus}; use pageserver_api::reltag::{RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; @@ -100,8 +89,24 @@ async fn run_v1( tasks: Vec::default(), }; - let import_config = &timeline.conf.timeline_import_config; - let plan = planner.plan(import_config).await?; + // Use the job size limit encoded in the progress if we are resuming an import. + // This ensures that imports have stable plans even if the pageserver config changes. + let import_config = { + match &import_progress { + Some(progress) => { + let base = &timeline.conf.timeline_import_config; + TimelineImportConfig { + import_job_soft_size_limit: NonZeroUsize::new(progress.job_soft_size_limit) + .unwrap(), + import_job_concurrency: base.import_job_concurrency, + import_job_checkpoint_threshold: base.import_job_checkpoint_threshold, + } + } + None => timeline.conf.timeline_import_config.clone(), + } + }; + + let plan = planner.plan(&import_config).await?; // Hash the plan and compare with the hash of the plan we got back from the storage controller. // If the two match, it means that the planning stage had the same output. @@ -113,20 +118,28 @@ async fn run_v1( let plan_hash = hasher.finish(); if let Some(progress) = &import_progress { - if plan_hash != progress.import_plan_hash { - anyhow::bail!("Import plan does not match storcon metadata"); - } - // Handle collisions on jobs of unequal length if progress.jobs != plan.jobs.len() { anyhow::bail!("Import plan job length does not match storcon metadata") } + + if plan_hash != progress.import_plan_hash { + anyhow::bail!("Import plan does not match storcon metadata"); + } } pausable_failpoint!("import-timeline-pre-execute-pausable"); + let jobs_count = import_progress.as_ref().map(|p| p.jobs); let start_from_job_idx = import_progress.map(|progress| progress.completed); - plan.execute(timeline, start_from_job_idx, plan_hash, import_config, ctx) + + tracing::info!( + start_from_job_idx=?start_from_job_idx, + jobs=?jobs_count, + "Executing import plan" + ); + + plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx) .await } @@ -150,6 +163,7 @@ impl Planner { /// This function is and must remain pure: given the same input, it will generate the same import plan. async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result { let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align(); + anyhow::ensure!(pgdata_lsn.is_valid()); let datadir = PgDataDir::new(&self.storage).await?; @@ -218,15 +232,36 @@ impl Planner { checkpoint_buf, ))); + // Sort the tasks by the key ranges they handle. + // The plan being generated here needs to be stable across invocations + // of this method. + self.tasks.sort_by_key(|task| match task { + AnyImportTask::SingleKey(key) => (key.key, key.key.next()), + AnyImportTask::RelBlocks(rel_blocks) => { + (rel_blocks.key_range.start, rel_blocks.key_range.end) + } + AnyImportTask::SlruBlocks(slru_blocks) => { + (slru_blocks.key_range.start, slru_blocks.key_range.end) + } + }); + // Assigns parts of key space to later parallel jobs + // Note: The image layers produced here may have gaps, meaning, + // there is not an image for each key in the layer's key range. + // The read path stops traversal at the first image layer, regardless + // of whether a base image has been found for a key or not. + // (Concept of sparse image layers doesn't exist.) + // This behavior is exactly right for the base image layers we're producing here. + // But, since no other place in the code currently produces image layers with gaps, + // it seems noteworthy. let mut last_end_key = Key::MIN; let mut current_chunk = Vec::new(); let mut current_chunk_size: usize = 0; let mut jobs = Vec::new(); for task in std::mem::take(&mut self.tasks).into_iter() { - if current_chunk_size + task.total_size() - > import_config.import_job_soft_size_limit.into() - { + let task_size = task.total_size(&self.shard); + let projected_chunk_size = current_chunk_size.saturating_add(task_size); + if projected_chunk_size > import_config.import_job_soft_size_limit.into() { let key_range = last_end_key..task.key_range().start; jobs.push(ChunkProcessingJob::new( key_range.clone(), @@ -236,7 +271,7 @@ impl Planner { last_end_key = key_range.end; current_chunk_size = 0; } - current_chunk_size += task.total_size(); + current_chunk_size = current_chunk_size.saturating_add(task_size); current_chunk.push(task); } jobs.push(ChunkProcessingJob::new( @@ -426,6 +461,8 @@ impl Plan { })); }, maybe_complete_job_idx = work.next() => { + pausable_failpoint!("import-task-complete-pausable"); + match maybe_complete_job_idx { Some(Ok((job_idx, res))) => { assert!(last_completed_job_idx.checked_add(1).unwrap() == job_idx); @@ -434,12 +471,18 @@ impl Plan { last_completed_job_idx = job_idx; if last_completed_job_idx % checkpoint_every == 0 { + tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status"); + let progress = ShardImportProgressV1 { jobs: jobs_in_plan, completed: last_completed_job_idx, import_plan_hash, + job_soft_size_limit: import_config.import_job_soft_size_limit.into(), }; + timeline.remote_client.schedule_index_upload_for_file_changes()?; + timeline.remote_client.wait_completion().await?; + storcon_client.put_timeline_import_status( timeline.tenant_shard_id, timeline.timeline_id, @@ -568,18 +611,18 @@ impl PgDataDirDb { }; let path = datadir_path.join(rel_tag.to_segfile_name(segno)); - assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error + anyhow::ensure!(filesize % BLCKSZ as usize == 0); let nblocks = filesize / BLCKSZ as usize; - PgDataDirDbFile { + Ok(PgDataDirDbFile { path, filesize, rel_tag, segno, nblocks: Some(nblocks), // first non-cummulative sizes - } + }) }) - .collect(); + .collect::>()?; // Set cummulative sizes. Do all of that math here, so that later we could easier // parallelize over segments and know with which segments we need to write relsize @@ -614,12 +657,22 @@ impl PgDataDirDb { trait ImportTask { fn key_range(&self) -> Range; - fn total_size(&self) -> usize { - // TODO: revisit this - if is_contiguous_range(&self.key_range()) { - contiguous_range_len(&self.key_range()) as usize * 8192 + fn total_size(&self, shard_identity: &ShardIdentity) -> usize { + let range = ShardedRange::new(self.key_range(), shard_identity); + let page_count = range.page_count(); + if page_count == u32::MAX { + tracing::warn!( + "Import task has non contiguous key range: {}..{}", + self.key_range().start, + self.key_range().end + ); + + // Tasks should operate on contiguous ranges. It is unexpected for + // ranges to violate this assumption. Calling code handles this by mapping + // any task on a non contiguous range to its own image layer. + usize::MAX } else { - u32::MAX as usize + page_count as usize * 8192 } } @@ -640,7 +693,11 @@ impl Hash for ImportSingleKeyTask { let ImportSingleKeyTask { key, buf } = self; key.hash(state); - buf.hash(state); + // The key value might not have a stable binary representation. + // For instance, the db directory uses an unstable hash-map. + // To work around this we are a bit lax here and only hash the + // size of the buffer which must be consistent. + buf.len().hash(state); } } @@ -713,6 +770,8 @@ impl ImportTask for ImportRelBlocksTask { layer_writer: &mut ImageLayerWriter, ctx: &RequestContext, ) -> anyhow::Result { + const MAX_BYTE_RANGE_SIZE: usize = 4 * 1024 * 1024; + debug!("Importing relation file"); let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?; @@ -737,7 +796,7 @@ impl ImportTask for ImportRelBlocksTask { assert_eq!(key.len(), 1); assert!(!acc.is_empty()); assert!(acc_end > acc_start); - if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ { + if acc_end == start && end - acc_start <= MAX_BYTE_RANGE_SIZE { acc.push(key.pop().unwrap()); Ok((acc, acc_start, end)) } else { @@ -752,8 +811,8 @@ impl ImportTask for ImportRelBlocksTask { .get_range(&self.path, range_start.into_u64(), range_end.into_u64()) .await?; let mut buf = Bytes::from(range_buf); - // TODO: batched writes for key in keys { + // The writer buffers writes internally let image = buf.split_to(8192); layer_writer.put_image(key, image, ctx).await?; nimages += 1; @@ -806,6 +865,9 @@ impl ImportTask for ImportSlruBlocksTask { debug!("Importing SLRU segment file {}", self.path); let buf = self.storage.get(&self.path).await?; + // TODO(vlad): Does timestamp to LSN work for imported timelines? + // Probably not since we don't append the `xact_time` to it as in + // [`WalIngest::ingest_xact_record`]. let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?; let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?; let mut blknum = start_blk; @@ -915,7 +977,7 @@ impl ChunkProcessingJob { let guard = timeline.layers.read().await; let existing_layer = guard.try_get_from_key(&desc.key()); if let Some(layer) = existing_layer { - if layer.metadata().generation != timeline.generation { + if layer.metadata().generation == timeline.generation { return Err(anyhow::anyhow!( "Import attempted to rewrite layer file in the same generation: {}", layer.local_path() @@ -942,6 +1004,15 @@ impl ChunkProcessingJob { .cloned(); match existing_layer { Some(existing) => { + // Unlink the remote layer from the index without scheduling its deletion. + // When `existing_layer` drops [`LayerInner::drop`] will schedule its deletion from + // remote storage, but that assumes that the layer was unlinked from the index first. + timeline + .remote_client + .schedule_unlinking_of_layers_from_index_part(std::iter::once( + existing.layer_desc().layer_name(), + ))?; + guard.open_mut()?.rewrite_layers( &[(existing.clone(), resident_layer.clone())], &[], diff --git a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs index 34313748b7..bf2d9875c1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs @@ -6,7 +6,7 @@ use bytes::Bytes; use postgres_ffi::ControlFileData; use remote_storage::{ Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing, - ListingObject, RemotePath, + ListingObject, RemotePath, RemoteStorageConfig, }; use serde::de::DeserializeOwned; use tokio_util::sync::CancellationToken; @@ -22,11 +22,9 @@ pub async fn new( location: &index_part_format::Location, cancel: CancellationToken, ) -> Result { - // FIXME: we probably want some timeout, and we might be able to assume the max file - // size on S3 is 1GiB (postgres segment size). But the problem is that the individual - // downloaders don't know enough about concurrent downloads to make a guess on the - // expected bandwidth and resulting best timeout. - let timeout = std::time::Duration::from_secs(24 * 60 * 60); + // Downloads should be reasonably sized. We do ranged reads for relblock raw data + // and full reads for SLRU segments which are bounded by Postgres. + let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT; let location_storage = match location { #[cfg(feature = "testing")] index_part_format::Location::LocalFs { path } => { @@ -50,9 +48,12 @@ pub async fn new( .import_pgdata_aws_endpoint_url .clone() .map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env - concurrency_limit: 100.try_into().unwrap(), // TODO: think about this - max_keys_per_list_response: Some(1000), // TODO: think about this - upload_storage_class: None, // irrelevant + // This matches the default import job concurrency. This is managed + // separately from the usual S3 client, but the concern here is bandwidth + // usage. + concurrency_limit: 128.try_into().unwrap(), + max_keys_per_list_response: Some(1000), + upload_storage_class: None, // irrelevant }, timeout, ) diff --git a/pageserver/src/tenant/timeline/walreceiver.rs b/pageserver/src/tenant/timeline/walreceiver.rs index 0f73eb839b..633c94a010 100644 --- a/pageserver/src/tenant/timeline/walreceiver.rs +++ b/pageserver/src/tenant/timeline/walreceiver.rs @@ -113,7 +113,7 @@ impl WalReceiver { } connection_manager_state.shutdown().await; *loop_status.write().unwrap() = None; - debug!("task exits"); + info!("task exits"); } .instrument(info_span!(parent: None, "wal_connection_manager", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), timeline_id = %timeline_id)) }); diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 52259f205b..249849ac4b 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -297,6 +297,7 @@ pub(super) async fn handle_walreceiver_connection( let mut expected_wal_start = startpoint; while let Some(replication_message) = { select! { + biased; _ = cancellation.cancelled() => { debug!("walreceiver interrupted"); None diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index 3befb42030..f42103c7cd 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -155,8 +155,9 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) int written = 0; written = snprintf((char *) &sk->conninfo, MAXCONNINFO, - "host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'", - sk->host, sk->port, wp->config->neon_timeline, wp->config->neon_tenant); + "%s host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'", + wp->config->safekeeper_conninfo_options, sk->host, sk->port, + wp->config->neon_timeline, wp->config->neon_tenant); if (written > MAXCONNINFO || written < 0) wp_log(FATAL, "could not create connection string for safekeeper %s:%s", sk->host, sk->port); } diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 83ef72d3d7..cca20e746b 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -714,6 +714,9 @@ typedef struct WalProposerConfig */ char *safekeepers_list; + /* libpq connection info options. */ + char *safekeeper_conninfo_options; + /* * WalProposer reconnects to offline safekeepers once in this interval. * Time is in milliseconds. diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 17582405db..d15bf91d24 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -64,6 +64,7 @@ char *wal_acceptors_list = ""; int wal_acceptor_reconnect_timeout = 1000; int wal_acceptor_connection_timeout = 10000; int safekeeper_proto_version = 3; +char *safekeeper_conninfo_options = ""; /* Set to true in the walproposer bgw. */ static bool am_walproposer; @@ -119,6 +120,7 @@ init_walprop_config(bool syncSafekeepers) walprop_config.neon_timeline = neon_timeline; /* WalProposerCreate scribbles directly on it, so pstrdup */ walprop_config.safekeepers_list = pstrdup(wal_acceptors_list); + walprop_config.safekeeper_conninfo_options = pstrdup(safekeeper_conninfo_options); walprop_config.safekeeper_reconnect_timeout = wal_acceptor_reconnect_timeout; walprop_config.safekeeper_connection_timeout = wal_acceptor_connection_timeout; walprop_config.wal_segment_size = wal_segment_size; @@ -203,6 +205,16 @@ nwp_register_gucs(void) * GUC_LIST_QUOTE */ NULL, assign_neon_safekeepers, NULL); + DefineCustomStringVariable( + "neon.safekeeper_conninfo_options", + "libpq keyword parameters and values to apply to safekeeper connections", + NULL, + &safekeeper_conninfo_options, + "", + PGC_POSTMASTER, + 0, + NULL, NULL, NULL); + DefineCustomIntVariable( "neon.safekeeper_reconnect_timeout", "Walproposer reconnects to offline safekeepers once in this interval.", diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 5e494dfdd6..8445368740 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,35 +17,23 @@ pub(super) async fn authenticate( config: &'static AuthenticationConfig, secret: AuthSecret, ) -> auth::Result { - let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] AuthSecret::Md5(_) => { debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::bad_auth_method("MD5")); + return Err(auth::AuthError::MalformedPassword("MD5 not supported")); } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret, ctx); let auth_outcome = tokio::time::timeout( config.scram_protocol_timeout, - async { - - flow.begin(scram).await.map_err(|error| { - warn!(?error, "error sending scram acknowledgement"); - error - })?.authenticate().await.map_err(|error| { - warn!(?error, "error processing scram messages"); - error - }) - } + AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(), ) .await - .map_err(|e| { - warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()); - auth::AuthError::user_timeout(e) - })??; + .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) + .map_err(auth::AuthError::user_timeout)? + .inspect_err(|error| warn!(?error, "error processing scram messages"))?; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index dd48384c03..a50c30257f 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -2,7 +2,6 @@ use std::fmt; use async_trait::async_trait; use postgres_client::config::SslMode; -use pq_proto::BeMessage as Be; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; @@ -16,6 +15,7 @@ use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::stream::PqStream; @@ -154,11 +154,13 @@ async fn authenticate( // Give user a URL to spawn a new database. info!(parent: &span, "sending the auth URL to the user"); - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::NoticeResponse(&greeting)) - .await?; + client.write_message(BeMessage::AuthenticationOk); + client.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + client.write_message(BeMessage::NoticeResponse(&greeting)); + client.flush().await?; // Wait for console response via control plane (see `mgmt`). info!(parent: &span, "waiting for console's reply..."); @@ -188,7 +190,7 @@ async fn authenticate( } } - client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; + client.write_message(BeMessage::NoticeResponse("Connecting to database.")); // This config should be self-contained, because we won't // take username or dbname from client's startup message. diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 3316543022..1e5c076fb9 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext( debug!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); - // pause the timer while we communicate with the client - let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let ep = EndpointIdInt::from(&info.endpoint); - let auth_flow = AuthFlow::new(client) - .begin(auth::CleartextPassword { + let auth_flow = AuthFlow::new( + client, + auth::CleartextPassword { secret, endpoint: ep, pool: config.thread_pool.clone(), - }) - .await?; - drop(paused); - // cleartext auth is only allowed to the ws/http protocol. - // If we're here, we already received the password in the first message. - // Scram protocol will be executed on the proxy side. - let auth_outcome = auth_flow.authenticate().await?; + }, + ); + let auth_outcome = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // cleartext auth is only allowed to the ws/http protocol. + // If we're here, we already received the password in the first message. + // Scram protocol will be executed on the proxy side. + auth_flow.authenticate().await? + }; let keys = match auth_outcome { sasl::Outcome::Success(key) => key, @@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication( // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let payload = AuthFlow::new(client) - .begin(auth::PasswordHack) - .await? + let payload = AuthFlow::new(client, auth::PasswordHack) .get_password() .await?; diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 6e5c0a3954..735cb52f47 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -4,37 +4,31 @@ mod hacks; pub mod jwt; pub mod local; -use std::net::IpAddr; use std::sync::Arc; pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::ConsoleRedirectError; -use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; use postgres_client::config::AuthKeys; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; -use crate::auth::credentials::check_peer_addr_is_in_list; -use crate::auth::{ - self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange, -}; +use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::{ - self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi, + self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, + RoleAccessControl, }; use crate::intern::EndpointIdInt; -use crate::metrics::Metrics; -use crate::protocol2::ConnectionInfoExtra; +use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::connect_compute::ComputeConnectBackend; -use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter}; +use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{scram, stream}; @@ -200,78 +194,6 @@ impl TryFrom for ComputeUserInfo { } } -#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)] -pub struct MaskedIp(IpAddr); - -impl MaskedIp { - fn new(value: IpAddr, prefix: u8) -> Self { - match value { - IpAddr::V4(v4) => Self(IpAddr::V4( - Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()), - )), - IpAddr::V6(v6) => Self(IpAddr::V6( - Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()), - )), - } - } -} - -// This can't be just per IP because that would limit some PaaS that share IP addresses -pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>; - -impl AuthenticationConfig { - pub(crate) fn check_rate_limit( - &self, - ctx: &RequestContext, - secret: AuthSecret, - endpoint: &EndpointId, - is_cleartext: bool, - ) -> auth::Result { - // we have validated the endpoint exists, so let's intern it. - let endpoint_int = EndpointIdInt::from(endpoint.normalize()); - - // only count the full hash count if password hack or websocket flow. - // in other words, if proxy needs to run the hashing - let password_weight = if is_cleartext { - match &secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => 1, - AuthSecret::Scram(s) => s.iterations + 1, - } - } else { - // validating scram takes just 1 hmac_sha_256 operation. - 1 - }; - - let limit_not_exceeded = self.rate_limiter.check( - ( - endpoint_int, - MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), - ), - password_weight, - ); - - if !limit_not_exceeded { - warn!( - enabled = self.rate_limiter_enabled, - "rate limiting authentication" - ); - Metrics::get().proxy.requests_auth_rate_limits_total.inc(); - Metrics::get() - .proxy - .endpoints_auth_rate_limits - .get_metric() - .measure(endpoint); - - if self.rate_limiter_enabled { - return Err(auth::AuthError::too_many_connections()); - } - } - - Ok(secret) - } -} - /// True to its name, this function encapsulates our current auth trade-offs. /// Here, we choose the appropriate auth flow based on circumstances. /// @@ -284,7 +206,7 @@ async fn auth_quirks( allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, -) -> auth::Result<(ComputeCredentials, Option>)> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. @@ -300,55 +222,27 @@ async fn auth_quirks( debug!("fetching authentication info and allowlists"); - // check allowed list - let allowed_ips = if config.ip_allowlist_check_enabled { - let allowed_ips = api.get_allowed_ips(ctx, &info).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - allowed_ips - } else { - Cached::new_uncached(Arc::new(vec![])) - }; + let access_controls = api + .get_endpoint_access_control(ctx, &info.endpoint, &info.user) + .await?; - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?; - if config.is_vpc_acccess_proxy { - if access_blocks.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } + access_controls.check( + ctx, + config.ip_allowlist_check_enabled, + config.is_vpc_acccess_proxy, + )?; - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(AuthError::MissingEndpointName), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed( - incoming_vpc_endpoint_id, - )); - } - } else if access_blocks.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) { + let endpoint = EndpointIdInt::from(&info.endpoint); + let rate_limit_config = None; + if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = api.get_role_secret(ctx, &info).await?; - let (cached_entry, secret) = cached_secret.take_value(); + let role_access = api + .get_role_access_control(ctx, &info.endpoint, &info.user) + .await?; - let secret = if let Some(secret) = secret { - config.check_rate_limit( - ctx, - secret, - &info.endpoint, - unauthenticated_password.is_some() || allow_cleartext, - )? + let secret = if let Some(secret) = role_access.secret { + secret } else { // If we don't have an authentication secret, we mock one to // prevent malicious probing (possible due to missing protocol steps). @@ -368,14 +262,8 @@ async fn auth_quirks( ) .await { - Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))), - Err(e) => { - if e.is_password_failed() { - // The password could have been changed, so we invalidate the cache. - cached_entry.invalidate(); - } - Err(e) - } + Ok(keys) => Ok(keys), + Err(e) => Err(e), } } @@ -402,7 +290,7 @@ async fn authenticate_with_secret( }; // we have authenticated the password - client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?; + client.write_message(BeMessage::AuthenticationOk); return Ok(ComputeCredentials { info, keys }); } @@ -438,7 +326,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, - ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option>)> { + ) -> auth::Result> { let res = match self { Self::ControlPlane(api, user_info) => { debug!( @@ -447,17 +335,35 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let (credentials, ip_allowlist) = auth_quirks( + let auth_res = auth_quirks( ctx, &*api, - user_info, + user_info.clone(), client, allow_cleartext, config, endpoint_rate_limiter, ) - .await?; - Ok((Backend::ControlPlane(api, credentials), ip_allowlist)) + .await; + match auth_res { + Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)), + Err(e) => { + // The password could have been changed, so we invalidate the cache. + // We should only invalidate the cache if the TTL might have expired. + if e.is_password_failed() { + #[allow(irrefutable_let_patterns)] + if let ControlPlaneClient::ProxyV1(api) = &*api { + if let Some(ep) = &user_info.endpoint_id { + api.caches + .project_info + .maybe_invalidate_role_secret(ep, &user_info.user); + } + } + } + + Err(e) + } + } } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")); @@ -474,44 +380,30 @@ impl Backend<'_, ComputeUserInfo> { pub(crate) async fn get_role_secret( &self, ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(None)), - } - } - - pub(crate) async fn get_allowed_ips( - &self, - ctx: &RequestContext, - ) -> Result { - match self { - Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await, - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), - } - } - - pub(crate) async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_allowed_vpc_endpoint_ids(ctx, user_info).await + api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), + Self::Local(_) => Ok(RoleAccessControl { secret: None }), } } - pub(crate) async fn get_block_public_or_vpc_access( + pub(crate) async fn get_endpoint_access_control( &self, ctx: &RequestContext, - ) -> Result { + ) -> Result { match self { Self::ControlPlane(api, user_info) => { - api.get_block_public_or_vpc_access(ctx, user_info).await + api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user) + .await } - Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())), + Self::Local(_) => Ok(EndpointAccessControl { + allowed_ips: Arc::new(vec![]), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }), } } } @@ -540,9 +432,7 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { mod tests { #![allow(clippy::unimplemented, clippy::unwrap_used)] - use std::net::IpAddr; use std::sync::Arc; - use std::time::Duration; use bytes::BytesMut; use control_plane::AuthSecret; @@ -553,18 +443,16 @@ mod tests { use postgres_protocol::message::frontend; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use super::auth_quirks; use super::jwt::JwkCache; - use super::{AuthRateLimiter, auth_quirks}; - use crate::auth::backend::MaskedIp; use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::{ - self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, + self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, }; use crate::proxy::NeonOptions; - use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo}; + use crate::rate_limiter::EndpointRateLimiter; use crate::scram::ServerSecret; use crate::scram::threadpool::ThreadPool; use crate::stream::{PqStream, Stream}; @@ -577,46 +465,34 @@ mod tests { } impl control_plane::ControlPlaneApi for Auth { - async fn get_role_secret( + async fn get_role_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(RoleAccessControl { + secret: Some(self.secret.clone()), + }) } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone()))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new( - self.vpc_endpoint_ids.clone(), - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - _user_info: &super::ComputeUserInfo, - ) -> Result { - Ok(CachedAccessBlockerFlags::new_uncached( - self.access_blocker_flags.clone(), - )) + _endpoint: &crate::types::EndpointId, + _role: &crate::types::RoleName, + ) -> Result { + Ok(EndpointAccessControl { + allowed_ips: Arc::new(self.ips.clone()), + allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()), + flags: self.access_blocker_flags, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - _endpoint: crate::types::EndpointId, + _endpoint: &crate::types::EndpointId, ) -> Result, control_plane::errors::GetEndpointJwksError> { unimplemented!() @@ -635,9 +511,6 @@ mod tests { jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), - rate_limiter_enabled: true, - rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, @@ -654,55 +527,10 @@ mod tests { } } - #[test] - fn masked_ip() { - let ip_a = IpAddr::V4([127, 0, 0, 1].into()); - let ip_b = IpAddr::V4([127, 0, 0, 2].into()); - let ip_c = IpAddr::V4([192, 168, 1, 101].into()); - let ip_d = IpAddr::V4([192, 168, 1, 102].into()); - let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap()); - let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap()); - - assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64)); - assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32)); - assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30)); - assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30)); - - assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128)); - assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64)); - } - - #[test] - fn test_default_auth_rate_limit_set() { - // these values used to exceed u32::MAX - assert_eq!( - RateBucketInfo::DEFAULT_AUTH_SET, - [ - RateBucketInfo { - interval: Duration::from_secs(1), - max_rpi: 1000 * 4096, - }, - RateBucketInfo { - interval: Duration::from_secs(60), - max_rpi: 600 * 4096 * 60, - }, - RateBucketInfo { - interval: Duration::from_secs(600), - max_rpi: 300 * 4096 * 600, - } - ] - ); - - for x in RateBucketInfo::DEFAULT_AUTH_SET { - let y = x.to_string().parse().unwrap(); - assert_eq!(x, y); - } - } - #[tokio::test] async fn auth_quirks_scram() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -784,7 +612,7 @@ mod tests { #[tokio::test] async fn auth_quirks_cleartext() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -838,7 +666,7 @@ mod tests { #[tokio::test] async fn auth_quirks_password_hack() { let (mut client, server) = tokio::io::duplex(1024); - let mut stream = PqStream::new(Stream::from_raw(server)); + let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server)); let ctx = RequestContext::test(); let api = Auth { @@ -887,7 +715,7 @@ mod tests { .await .unwrap(); - assert_eq!(creds.0.info.endpoint, "my-endpoint"); + assert_eq!(creds.info.endpoint, "my-endpoint"); handle.await.unwrap(); } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 526d0df7f2..b51da48862 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -5,7 +5,6 @@ use std::net::IpAddr; use std::str::FromStr; use itertools::Itertools; -use pq_proto::StartupMessageParams; use thiserror::Error; use tracing::{debug, warn}; @@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::NeonOptions; use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI}; use crate::types::{EndpointId, RoleName}; diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 0992c6d875..8fbc4577e9 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,10 +1,8 @@ //! Main authentication flow. -use std::io; use std::sync::Arc; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; -use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; use crate::intern::EndpointIdInt; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; use crate::scram::{self}; use crate::stream::{PqStream, Stream}; use crate::tls::TlsServerEndPoint; -/// Every authentication selector is supposed to implement this trait. -pub(crate) trait AuthMethod { - /// Any authentication selector should provide initial backend message - /// containing auth method name and parameters, e.g. md5 salt. - fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; -} - -/// Initial state of [`AuthFlow`]. -pub(crate) struct Begin; - /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. pub(crate) struct Scram<'a>( pub(crate) &'a scram::ServerSecret, pub(crate) &'a RequestContext, ); -impl AuthMethod for Scram<'_> { +impl Scram<'_> { #[inline(always)] fn first_message(&self, channel_binding: bool) -> BeMessage<'_> { if channel_binding { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS)) } else { - Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods( scram::METHODS_WITHOUT_PLUS, )) } @@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> { /// . pub(crate) struct PasswordHack; -impl AuthMethod for PasswordHack { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// Use clear-text password auth called `password` in docs /// pub(crate) struct CleartextPassword { @@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword { pub(crate) secret: AuthSecret, } -impl AuthMethod for CleartextPassword { - #[inline(always)] - fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> { - Be::AuthenticationCleartextPassword - } -} - /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub(crate) struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, - /// State might contain ancillary data (see [`Self::begin`]). + /// State might contain ancillary data. state: State, tls_server_end_point: TlsServerEndPoint, } /// Initial state of the stream wrapper. -impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { +impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> { /// Create a new wrapper for client authentication. - pub(crate) fn new(stream: &'a mut PqStream>) -> Self { + pub(crate) fn new(stream: &'a mut PqStream>, method: M) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); Self { stream, - state: Begin, + state: method, tls_server_end_point, } } - - /// Move to the next step by sending auth method's name & params to client. - pub(crate) async fn begin(self, method: M) -> io::Result> { - self.stream - .write_message(&method.first_message(self.tls_server_end_point.supported())) - .await?; - - Ok(AuthFlow { - stream: self.stream, - state: method, - tls_server_end_point: self.tls_server_end_point, - }) - } } impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn get_password(self) -> super::Result { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -133,6 +99,10 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { + self.stream + .write_message(BeMessage::AuthenticationCleartextPassword); + self.stream.flush().await?; + let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -147,7 +117,7 @@ impl AuthFlow<'_, S, CleartextPassword> { .await?; if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; + self.stream.write_message(BeMessage::AuthenticationOk); } Ok(outcome) @@ -159,42 +129,36 @@ impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub(crate) async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; + let channel_binding = self.tls_server_end_point; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + // send sasl message. + { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // Initial client message contains the chosen auth method's name. - let msg = self.stream.read_password_message().await?; - let sasl = sasl::FirstMessage::parse(&msg) - .ok_or(AuthError::MalformedPassword("bad sasl message"))?; - - // Currently, the only supported SASL method is SCRAM. - if !scram::METHODS.contains(&sasl.method) { - return Err(super::AuthError::bad_auth_method(sasl.method)); + let sasl = self.state.first_message(channel_binding.supported()); + self.stream.write_message(sasl); + self.stream.flush().await?; } - match sasl.method { - SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), - SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus), - _ => {} - } + // complete sasl handshake. + sasl::authenticate(ctx, self.stream, |method| { + // Currently, the only supported SASL method is SCRAM. + match method { + SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus); + } + method => return Err(sasl::Error::BadAuthMethod(method.into())), + } - // TODO: make this a metric instead - info!("client chooses {}", sasl.method); + // TODO: make this a metric instead + info!("client chooses {}", method); - let outcome = sasl::SaslStream::new(self.stream, sasl.message) - .authenticate(scram::Exchange::new( - secret, - rand::random, - self.tls_server_end_point, - )) - .await?; - - if let sasl::Outcome::Success(_) = &outcome { - self.stream.write_message_noflush(&Be::AuthenticationOk)?; - } - - Ok(outcome) + Ok(scram::Exchange::new(secret, rand::random, channel_binding)) + }) + .await + .map_err(AuthError::Sasl) } } diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index a566383390..ba10fce7b4 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -32,9 +32,7 @@ use crate::ext::TaskExt; use crate::http::health_server::AppMetrics; use crate::intern::RoleNameInt; use crate::metrics::{Metrics, ThreadPoolMetrics}; -use crate::rate_limiter::{ - BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, -}; +use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; use crate::serverless::{self, GlobalConnPoolOptions}; @@ -69,15 +67,6 @@ struct LocalProxyCliArgs { /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] user_rps_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Whether to retry the connection to the compute node #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)] connect_to_compute_retry: String, @@ -282,9 +271,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig jwks_cache: JwkCache::default(), thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), - rate_limiter_enabled: false, - rate_limiter: BucketRateLimiter::new(vec![]), - rate_limit_ip_subnet: 64, ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, is_auth_broker: false, diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 3e87538ae7..a4f517fead 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -4,8 +4,9 @@ //! This allows connecting to pods/services running in the same Kubernetes cluster from //! the outside. Similar to an ingress controller for HTTPS. +use std::net::SocketAddr; use std::path::Path; -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; use anyhow::{Context, anyhow, bail, ensure}; use clap::Arg; @@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpListener; use tokio_rustls::TlsConnector; +use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; use tracing::{Instrument, error, info}; use utils::project_git_version; @@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry; use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled}; +use crate::proxy::{ + ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled, +}; use crate::stream::{PqStream, Stream}; -use crate::tls::TlsServerEndPoint; project_git_version!(GIT_VERSION); @@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> { .parse()?; // Configure TLS - let (tls_config, tls_server_end_point): (Arc, TlsServerEndPoint) = match ( + let tls_config = match ( args.get_one::("tls-key"), args.get_one::("tls-cert"), ) { @@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, proxy_listener, cancellation_token.clone(), )) @@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(compute_tls_config), - tls_server_end_point, proxy_listener_compute_tls, cancellation_token.clone(), )) @@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> { pub(super) fn parse_tls( key_path: &Path, cert_path: &Path, -) -> anyhow::Result<(Arc, TlsServerEndPoint)> { +) -> anyhow::Result> { let key = { let key_bytes = std::fs::read(key_path).context("TLS key file")?; @@ -187,10 +189,6 @@ pub(super) fn parse_tls( })? }; - // needed for channel bindings - let first_cert = cert_chain.first().context("missing certificate")?; - let tls_server_end_point = TlsServerEndPoint::new(first_cert)?; - let tls_config = rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider())) .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) @@ -199,14 +197,13 @@ pub(super) fn parse_tls( .with_single_cert(cert_chain, key)? .into(); - Ok((tls_config, tls_server_end_point)) + Ok(tls_config) } pub(super) async fn task_main( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -242,15 +239,7 @@ pub(super) async fn task_main( crate::metrics::Protocol::SniRouter, "sni", ); - handle_client( - ctx, - dest_suffix, - tls_config, - compute_tls_config, - tls_server_end_point, - socket, - ) - .await + handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await } .unwrap_or_else(|e| { // Acknowledge that the task has finished with an error. @@ -269,55 +258,26 @@ pub(super) async fn task_main( Ok(()) } -const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; - async fn ssl_handshake( ctx: &RequestContext, raw_stream: S, tls_config: Arc, - tls_server_end_point: TlsServerEndPoint, -) -> anyhow::Result> { - let mut stream = PqStream::new(Stream::from_raw(raw_stream)); - - let msg = stream.read_startup_packet().await?; - use pq_proto::FeStartupPacket::SslRequest; - +) -> anyhow::Result> { + let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?; match msg { - SslRequest { direct: false } => { - stream - .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) - .await?; + FeStartupPacket::SslRequest { direct: None } => { + let raw = stream.accept_tls().await?; - // Upgrade raw stream into a secure TLS-backed stream. - // NOTE: We've consumed `tls`; this fact will be used later. - - let (raw, read_buf) = stream.into_inner(); - // TODO: Normally, client doesn't send any data before - // server says TLS handshake is ok and read_buf is empty. - // However, you could imagine pipelining of postgres - // SSLRequest + TLS ClientHello in one hunk similar to - // pipelining in our node js driver. We should probably - // support that by chaining read_buf with the stream. - if !read_buf.is_empty() { - bail!("data is sent before server replied with EncryptionResponse"); - } - - Ok(Stream::Tls { - tls: Box::new( - raw.upgrade(tls_config, !ctx.has_private_peer_addr()) - .await?, - ), - tls_server_end_point, - }) + Ok(raw + .upgrade(tls_config, !ctx.has_private_peer_addr()) + .await?) } unexpected => { info!( ?unexpected, "unexpected startup packet, rejecting connection" ); - stream - .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None) - .await? + Err(stream.throw_error(TlsRequired, None).await)? } } } @@ -327,15 +287,18 @@ async fn handle_client( dest_suffix: Arc, tls_config: Arc, compute_tls_config: Option>, - tls_server_end_point: TlsServerEndPoint, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { - let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?; + let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?; // Cut off first part of the SNI domain // We receive required destination details in the format of // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain` - let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?; + let sni = tls_stream + .get_ref() + .1 + .server_name() + .ok_or(anyhow!("SNI missing"))?; let dest: Vec<&str> = sni .split_once('.') .context("invalid SNI")? diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 5f24940985..dcae263647 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; use crate::auth::backend::jwt::JwkCache; -use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; +use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::cancellation::{CancellationHandler, handle_cancel_messages}; use crate::config::{ self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, @@ -29,9 +29,7 @@ use crate::config::{ use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; use crate::metrics::Metrics; -use crate::rate_limiter::{ - EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter, -}; +use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; use crate::redis::{elasticache, notifications}; @@ -154,15 +152,6 @@ struct ProxyCliArgs { /// Wake compute rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] wake_compute_limit: Vec, - /// Whether the auth rate limiter actually takes effect (for testing) - #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - auth_rate_limit_enabled: bool, - /// Authentication rate limiter max number of hashes per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] - auth_rate_limit: Vec, - /// The IP subnet to use when considering whether two IP addresses are considered the same. - #[clap(long, default_value_t = 64)] - auth_rate_limit_ip_subnet: u8, /// Redis rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)] redis_rps_limit: Vec, @@ -410,22 +399,9 @@ pub async fn run() -> anyhow::Result<()> { Some(tx_cancel), )); - // bit of a hack - find the min rps and max rps supported and turn it into - // leaky bucket config instead - let max = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .max_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.max); - let rps = args - .endpoint_rps_limit - .iter() - .map(|x| x.rps()) - .min_by(f64::total_cmp) - .unwrap_or(EndpointRateLimiter::DEFAULT.rps); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( - LeakyBucketConfig { rps, max }, + RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit) + .unwrap_or(EndpointRateLimiter::DEFAULT), 64, )); @@ -476,8 +452,7 @@ pub async fn run() -> anyhow::Result<()> { let key_path = args.tls_key.expect("already asserted it is set"); let cert_path = args.tls_cert.expect("already asserted it is set"); - let (tls_config, tls_server_end_point) = - super::pg_sni_router::parse_tls(&key_path, &cert_path)?; + let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?; let dest = Arc::new(dest); @@ -485,7 +460,6 @@ pub async fn run() -> anyhow::Result<()> { dest.clone(), tls_config.clone(), None, - tls_server_end_point, listen, cancellation_token.clone(), )); @@ -494,7 +468,6 @@ pub async fn run() -> anyhow::Result<()> { dest, tls_config, Some(config.connect_to_compute.tls.clone()), - tls_server_end_point, listen_tls, cancellation_token.clone(), )); @@ -681,9 +654,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { jwks_cache: JwkCache::default(), thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, - rate_limiter_enabled: args.auth_rate_limit_enabled, - rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), - rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, is_auth_broker: args.is_auth_broker, diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 60678b034d..81c88e3ddd 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,30 +1,25 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet, hash_map}; use std::convert::Infallible; -use std::sync::Arc; use std::sync::atomic::AtomicU64; use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; +use clashmap::mapref::one::Ref; use rand::{Rng, thread_rng}; -use smol_str::SmolStr; use tokio::sync::Mutex; use tokio::time::Instant; use tracing::{debug, info}; -use super::{Cache, Cached}; -use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; -use crate::control_plane::{AccessBlockerFlags, AuthSecret}; +use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec); - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt); - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); async fn increment_active_listeners(&self); @@ -42,6 +37,10 @@ impl Entry { value, } } + + pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> { + (valid_since < self.created_at).then_some(&self.value) + } } impl From for Entry { @@ -50,101 +49,32 @@ impl From for Entry { } } -#[derive(Default)] struct EndpointInfo { - secret: std::collections::HashMap>>, - allowed_ips: Option>>>, - block_public_or_vpc_access: Option>, - allowed_vpc_endpoint_ids: Option>>>, + role_controls: HashMap>, + controls: Option>, } impl EndpointInfo { - fn check_ignore_cache(ignore_cache_since: Option, created_at: Instant) -> bool { - match ignore_cache_since { - None => false, - Some(t) => t < created_at, - } - } pub(crate) fn get_role_secret( &self, role_name: RoleNameInt, valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Option, bool)> { - if let Some(secret) = self.secret.get(&role_name) { - if valid_since < secret.created_at { - return Some(( - secret.value.clone(), - Self::check_ignore_cache(ignore_cache_since, secret.created_at), - )); - } - } - None + ) -> Option { + let controls = self.role_controls.get(&role_name)?; + controls.get(valid_since).cloned() } - pub(crate) fn get_allowed_ips( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_ips) = &self.allowed_ips { - if valid_since < allowed_ips.created_at { - return Some(( - allowed_ips.value.clone(), - Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at), - )); - } - } - None - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(Arc>, bool)> { - if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids { - if valid_since < allowed_vpc_endpoint_ids.created_at { - return Some(( - allowed_vpc_endpoint_ids.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - allowed_vpc_endpoint_ids.created_at, - ), - )); - } - } - None - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - valid_since: Instant, - ignore_cache_since: Option, - ) -> Option<(AccessBlockerFlags, bool)> { - if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access { - if valid_since < block_public_or_vpc_access.created_at { - return Some(( - block_public_or_vpc_access.value.clone(), - Self::check_ignore_cache( - ignore_cache_since, - block_public_or_vpc_access.created_at, - ), - )); - } - } - None + pub(crate) fn get_controls(&self, valid_since: Instant) -> Option { + let controls = self.controls.as_ref()?; + controls.get(valid_since).cloned() } - pub(crate) fn invalidate_allowed_ips(&mut self) { - self.allowed_ips = None; - } - pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) { - self.allowed_vpc_endpoint_ids = None; - } - pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) { - self.block_public_or_vpc_access = None; + pub(crate) fn invalidate_endpoint(&mut self) { + self.controls = None; } + pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { - self.secret.remove(&role_name); + self.role_controls.remove(&role_name); } } @@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec) { - info!( - "invalidating allowed vpc endpoint ids for projects `{}`", - project_ids - .iter() - .map(|id| id.to_string()) - .collect::>() - .join(", ") - ); - for project_id in project_ids { - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { + info!("invalidating endpoint access for project `{project_id}`"); + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) { - info!( - "invalidating allowed vpc endpoint ids for org `{}`", - account_id - ); + fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { + info!("invalidating endpoint access for org `{account_id}`"); let endpoints = self .account2ep .get(&account_id) @@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .unwrap_or_default(); for endpoint_id in endpoints { if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + endpoint_info.invalidate_endpoint(); } } } - fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) { - info!( - "invalidating block public or vpc access for project `{}`", - project_id - ); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - - fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) { - info!("invalidating allowed ips for project `{}`", project_id); - let endpoints = self - .project2ep - .get(&project_id) - .map(|kv| kv.value().clone()) - .unwrap_or_default(); - for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - } fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", @@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } } } + async fn decrement_active_listeners(&self) { let mut listeners_guard = self.active_listeners_lock.lock().await; if *listeners_guard == 0 { @@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl { } } + fn get_endpoint_cache( + &self, + endpoint_id: &EndpointId, + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + self.cache.get(&endpoint_id) + } + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; + ) -> Option { + let valid_since = self.get_cache_times(); let role_name = RoleNameInt::get(role_name)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let (value, ignore_cache) = - endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_role_secret(endpoint_id, role_name), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_ips( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_allowed_vpc_endpoint_ids( - &self, - endpoint_id: &EndpointId, - ) -> Option>>> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) - } - pub(crate) fn get_block_public_or_vpc_access( - &self, - endpoint_id: &EndpointId, - ) -> Option> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(&endpoint_id)?; - let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since); - let (value, ignore_cache) = value?; - if !ignore_cache { - let cached = Cached { - token: Some(( - self, - CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id), - )), - value, - }; - return Some(cached); - } - Some(Cached::new_uncached(value)) + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_role_secret(role_name, valid_since) } - pub(crate) fn insert_role_secret( + pub(crate) fn get_endpoint_access( &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - role_name: RoleNameInt, - secret: Option, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - self.insert_project2endpoint(project_id, endpoint_id); - let mut entry = self.cache.entry(endpoint_id).or_default(); - if entry.secret.len() < self.config.max_roles { - entry.secret.insert(role_name, secret.into()); - } + endpoint_id: &EndpointId, + ) -> Option { + let valid_since = self.get_cache_times(); + let endpoint_info = self.get_endpoint_cache(endpoint_id)?; + endpoint_info.get_controls(valid_since) } - pub(crate) fn insert_allowed_ips( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - allowed_ips: Arc>, - ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into()); - } - pub(crate) fn insert_allowed_vpc_endpoint_ids( + + pub(crate) fn insert_endpoint_access( &self, account_id: Option, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, - allowed_vpc_endpoint_ids: Arc>, + role_name: RoleNameInt, + controls: EndpointAccessControl, + role_controls: RoleAccessControl, ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } if let Some(account_id) = account_id { self.insert_account2endpoint(account_id, endpoint_id); } self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into()); - } - pub(crate) fn insert_block_public_or_vpc_access( - &self, - project_id: ProjectIdInt, - endpoint_id: EndpointIdInt, - access_blockers: AccessBlockerFlags, - ) { + if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. return; } - self.insert_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id) - .or_default() - .block_public_or_vpc_access = Some(access_blockers.into()); + + let controls = Entry::from(controls); + let role_controls = Entry::from(role_controls); + + match self.cache.entry(endpoint_id) { + clashmap::Entry::Vacant(e) => { + e.insert(EndpointInfo { + role_controls: HashMap::from_iter([(role_name, role_controls)]), + controls: Some(controls), + }); + } + clashmap::Entry::Occupied(mut e) => { + let ep = e.get_mut(); + ep.controls = Some(controls); + if ep.role_controls.len() < self.config.max_roles { + ep.role_controls.insert(role_name, role_controls); + } + } + } } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { @@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl { .insert(project_id, HashSet::from([endpoint_id])); } } + fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) { endpoints.insert(endpoint_id); @@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl { .insert(account_id, HashSet::from([endpoint_id])); } } - fn get_cache_times(&self) -> (Instant, Option) { - let mut valid_since = Instant::now() - self.config.ttl; - // Only ignore cache if ttl is disabled. + + fn ignore_ttl_since(&self) -> Option { let ttl_disabled_since_us = self .ttl_disabled_since_us .load(std::sync::atomic::Ordering::Relaxed); - let ignore_cache_since = if ttl_disabled_since_us == u64::MAX { - None - } else { - let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us); + + if ttl_disabled_since_us == u64::MAX { + return None; + } + + Some(self.start_time + Duration::from_micros(ttl_disabled_since_us)) + } + + fn get_cache_times(&self) -> Instant { + let mut valid_since = Instant::now() - self.config.ttl; + if let Some(ignore_ttl_since) = self.ignore_ttl_since() { // We are fine if entry is not older than ttl or was added before we are getting notifications. - valid_since = valid_since.min(ignore_cache_since); - Some(ignore_cache_since) + valid_since = valid_since.min(ignore_ttl_since); + } + valid_since + } + + pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) { + let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else { + return; }; - (valid_since, ignore_cache_since) + let Some(role_name) = RoleNameInt::get(role_name) else { + return; + }; + + let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else { + return; + }; + + let entry = endpoint_info.role_controls.entry(role_name); + let hash_map::Entry::Occupied(role_controls) = entry else { + return; + }; + + let created_at = role_controls.get().created_at; + let expire = match self.ignore_ttl_since() { + // if ignoring TTL, we should still try and roll the password if it's old + // and we the client gave an incorrect password. There could be some lag on the redis channel. + Some(_) => created_at + self.config.ttl < Instant::now(), + // edge case: redis is down, let's be generous and invalidate the cache immediately. + None => true, + }; + + if expire { + role_controls.remove(); + } } pub async fn gc_worker(&self) -> anyhow::Result { @@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl { } } -/// Lookup info for project info cache. -/// This is used to invalidate cache entries. -pub(crate) struct CachedLookupInfo { - /// Search by this key. - endpoint_id: EndpointIdInt, - lookup_type: LookupType, -} - -impl CachedLookupInfo { - pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::RoleSecret(role_name), - } - } - pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::AllowedIps, - } - } - pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::AllowedVpcEndpointIds, - } - } - pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self { - Self { - endpoint_id, - lookup_type: LookupType::BlockPublicOrVpcAccess, - } - } -} - -enum LookupType { - RoleSecret(RoleNameInt), - AllowedIps, - AllowedVpcEndpointIds, - BlockPublicOrVpcAccess, -} - -impl Cache for ProjectInfoCacheImpl { - type Key = SmolStr; - // Value is not really used here, but we need to specify it. - type Value = SmolStr; - - type LookupInfo = CachedLookupInfo; - - fn invalidate(&self, key: &Self::LookupInfo) { - match &key.lookup_type { - LookupType::RoleSecret(role_name) => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_role_secret(*role_name); - } - } - LookupType::AllowedIps => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_ips(); - } - } - LookupType::AllowedVpcEndpointIds => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_allowed_vpc_endpoint_ids(); - } - } - LookupType::BlockPublicOrVpcAccess => { - if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_block_public_or_vpc_access(); - } - } - } - } -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; + use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; use crate::types::ProjectId; @@ -601,6 +371,8 @@ mod tests { }); let project_id: ProjectId = "project".into(); let endpoint_id: EndpointId = "endpoint".into(); + let account_id: Option = None; + let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); @@ -609,183 +381,73 @@ mod tests { "127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap(), ]); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user1).into(), - secret1.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret1.clone(), + }, ); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret2.clone(), + }, ); let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret1); + assert_eq!(cached.secret, secret1); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, secret2); + assert_eq!(cached.secret, secret2); // Shouldn't add more than 2 roles. let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); - cache.insert_role_secret( + + cache.insert_endpoint_access( + account_id, (&project_id).into(), (&endpoint_id).into(), (&user3).into(), - secret3.clone(), + EndpointAccessControl { + allowed_ips: allowed_ips.clone(), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + }, + RoleAccessControl { + secret: secret3.clone(), + }, ); + assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(cached.cached()); - assert_eq!(cached.value, allowed_ips); + let cached = cache.get_endpoint_access(&endpoint_id).unwrap(); + assert_eq!(cached.allowed_ips, allowed_ips); tokio::time::advance(Duration::from_secs(2)).await; let cached = cache.get_role_secret(&endpoint_id, &user1); assert!(cached.is_none()); let cached = cache.get_role_secret(&endpoint_id, &user2); assert!(cached.is_none()); - let cached = cache.get_allowed_ips(&endpoint_id); + let cached = cache.get_endpoint_access(&endpoint_id); assert!(cached.is_none()); } - - #[tokio::test] - async fn test_project_info_cache_invalidations() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_secs(2)).await; - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - - tokio::time::advance(Duration::from_secs(2)).await; - // Nothing should be invalidated. - - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - // TTL is disabled, so it should be impossible to invalidate this value. - assert!(!cached.cached()); - assert_eq!(cached.value, secret1); - - cached.invalidate(); // Shouldn't do anything. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert_eq!(cached.value, secret1); - - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, secret2); - - // The only way to invalidate this value is to invalidate via the api. - cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } - - #[tokio::test] - async fn test_increment_active_listeners_invalidate_added_before() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_millis(100)).await; - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - - // Added before ttl was disabled + ttl should be still cached. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - // Added after ttl was disabled + ttl should not be cached. - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl still should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - // Shouldn't be invalidated. - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a6e7bf85a0..d26641db46 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -5,7 +5,6 @@ use anyhow::{Context, anyhow}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::CancelToken; use postgres_client::tls::MakeTlsConnect; -use pq_proto::CancelKeyData; use redis::{Cmd, FromRedisValue, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -13,15 +12,15 @@ use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error, info, warn}; +use crate::auth::AuthError; use crate::auth::backend::ComputeUserInfo; -use crate::auth::{AuthError, check_peer_addr_is_in_list}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::ControlPlaneApi; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; -use crate::protocol2::ConnectionInfoExtra; +use crate::pqproto::CancelKeyData; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; @@ -272,13 +271,7 @@ pub(crate) enum CancelError { #[error("rate limit exceeded")] RateLimit, - #[error("IP is not allowed")] - IpNotAllowed, - - #[error("VPC endpoint id is not allowed to connect")] - VpcEndpointIdNotAllowed, - - #[error("Authentication backend error")] + #[error("Authentication error")] AuthError(#[from] AuthError), #[error("key not found")] @@ -297,10 +290,7 @@ impl ReportableError for CancelError { } CancelError::Postgres(_) => crate::error::ErrorKind::Compute, CancelError::RateLimit => crate::error::ErrorKind::RateLimit, - CancelError::IpNotAllowed - | CancelError::VpcEndpointIdNotAllowed - | CancelError::NotFound => crate::error::ErrorKind::User, - CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane, + CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User, CancelError::InternalError => crate::error::ErrorKind::Service, } } @@ -422,7 +412,13 @@ impl CancellationHandler { IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), }; - if !self.limiter.lock_propagate_poison().check(subnet_key, 1) { + + let allowed = { + let rate_limit_config = None; + let limiter = self.limiter.lock_propagate_poison(); + limiter.check(subnet_key, rate_limit_config, 1) + }; + if !allowed { // log only the subnet part of the IP address to know which subnet is rate limited tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}"); Metrics::get() @@ -450,52 +446,13 @@ impl CancellationHandler { return Err(CancelError::NotFound); }; - if check_ip_allowed { - let ip_allowlist = auth_backend - .get_allowed_ips(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - - if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) { - // log it here since cancel_session could be spawned in a task - tracing::warn!( - "IP is not allowed to cancel the query: {key}, address: {}", - ctx.peer_addr() - ); - return Err(CancelError::IpNotAllowed); - } - } - - // check if a VPC endpoint ID is coming in and if yes, if it's allowed - let access_blocks = auth_backend - .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info) + let info = &cancel_closure.user_info; + let access_controls = auth_backend + .get_endpoint_access_control(&ctx, &info.endpoint, &info.user) .await .map_err(|e| CancelError::AuthError(e.into()))?; - if check_vpc_allowed { - if access_blocks.vpc_access_blocked { - return Err(CancelError::AuthError(AuthError::NetworkNotAllowed)); - } - - let incoming_vpc_endpoint_id = match ctx.extra() { - None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - let allowed_vpc_endpoint_ids = auth_backend - .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info) - .await - .map_err(|e| CancelError::AuthError(e.into()))?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id) - { - return Err(CancelError::VpcEndpointIdNotAllowed); - } - } else if access_blocks.public_access_blocked { - return Err(CancelError::VpcEndpointIdNotAllowed); - } + access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?; Metrics::get() .proxy diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 26254beecf..2899f25129 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -8,7 +8,6 @@ use itertools::Itertools; use postgres_client::tls::MakeTlsConnect; use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use pq_proto::StartupMessageParams; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; diff --git a/proxy/src/config.rs b/proxy/src/config.rs index ad398c122c..a97339df9a 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -7,7 +7,6 @@ use arc_swap::ArcSwapOption; use clap::ValueEnum; use remote_storage::RemoteStorageConfig; -use crate::auth::backend::AuthRateLimiter; use crate::auth::backend::jwt::JwkCache; use crate::control_plane::locks::ApiLocks; use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}; @@ -65,9 +64,6 @@ pub struct HttpConfig { pub struct AuthenticationConfig { pub thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, - pub rate_limiter_enabled: bool, - pub rate_limiter: AuthRateLimiter, - pub rate_limit_ip_subnet: u8, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, pub jwks_cache: JwkCache, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index e3184e20d1..7fb84b5ee5 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; @@ -159,7 +159,7 @@ pub async fn task_main( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, ctx: &RequestContext, @@ -221,12 +221,10 @@ pub(crate) async fn handle_client( .await { Ok(auth_result) => auth_result, - Err(e) => { - return stream.throw_error(e, Some(ctx)).await?; - } + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; - let mut node = connect_to_compute( + let node = connect_to_compute( ctx, &TcpMechanism { user_info, @@ -238,7 +236,7 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) + .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; let cancellation_handler_clone = Arc::clone(&cancellation_handler); @@ -246,14 +244,8 @@ pub(crate) async fn handle_client( session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; Ok(Some(ProxyPassthrough { client: stream, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 79aaf22990..24268997ba 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -4,7 +4,6 @@ use std::net::IpAddr; use chrono::Utc; use once_cell::sync::OnceCell; -use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio::sync::mpsc; use tracing::field::display; @@ -20,6 +19,7 @@ use crate::metrics::{ ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting, }; +use crate::pqproto::StartupMessageParams; use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra}; use crate::types::{DbName, EndpointId, RoleName}; @@ -370,6 +370,18 @@ impl RequestContext { } } + pub(crate) fn latency_timer_pause_at( + &self, + at: tokio::time::Instant, + waiting_for: Waiting, + ) -> LatencyTimerPause<'_> { + LatencyTimerPause { + ctx: self, + start: at, + waiting_for, + } + } + pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated { self.0 .try_lock() diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index f6250bcd17..c9d3905abd 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr; use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr}; use parquet::file::writer::SerializedFileWriter; use parquet::record::RecordWriter; -use pq_proto::StartupMessageParams; use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel}; use serde::ser::SerializeMap; use tokio::sync::mpsc; @@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner}; use crate::config::remote_storage_from_toml; use crate::context::LOG_CHAN_DISCONNECT; use crate::ext::TaskExt; +use crate::pqproto::StartupMessageParams; #[derive(clap::Args, Clone, Debug)] pub struct ParquetUploadArgs { diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 2765aaa462..da548d6b2c 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -7,7 +7,9 @@ use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; +use bytes::Bytes; use futures::TryFutureExt; +use hyper::StatusCode; use postgres_client::config::SslMode; use tokio::time::Instant; use tracing::{Instrument, debug, info, info_span, warn}; @@ -15,7 +17,6 @@ use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; -use crate::cache::Cached; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ @@ -24,12 +25,12 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, - CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo, + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, }; -use crate::metrics::{CacheOutcome, Metrics}; +use crate::metrics::Metrics; use crate::rate_limiter::WakeComputeRateLimiter; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, http, scram}; pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -66,66 +67,41 @@ impl NeonControlPlaneClient { self.endpoint.url().as_str() } - async fn do_get_auth_info( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - if !self - .caches - .endpoints_cache - .is_valid(ctx, &user_info.endpoint.normalize()) - { - // TODO: refactor this because it's weird - // this is a failure to authenticate but we return Ok. - info!("endpoint is not valid, skipping the request"); - return Ok(AuthInfo::default()); - } - self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx)) - .await - } - async fn do_get_auth_req( &self, - user_info: &ComputeUserInfo, - session_id: &uuid::Uuid, - ctx: Option<&RequestContext>, + ctx: &RequestContext, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { - let request_id: String = session_id.to_string(); - let application_name = if let Some(ctx) = ctx { - ctx.console_application_name() - } else { - "auth_cancellation".to_string() - }; - async { - let request = self - .endpoint - .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, &request_id) - .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", session_id)]) - .query(&[ - ("application_name", application_name.as_str()), - ("endpointish", user_info.endpoint.as_str()), - ("role", user_info.user.as_str()), - ]) - .build()?; + let response = { + let request = self + .endpoint + .get_path("get_endpoint_access_control") + .header(X_REQUEST_ID, ctx.session_id().to_string()) + .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) + .query(&[("session_id", ctx.session_id())]) + .query(&[ + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), + ]) + .build()?; - debug!(url = request.url().as_str(), "sending http request"); - let start = Instant::now(); - let response = match ctx { - Some(ctx) => { - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane); - let rsp = self.endpoint.execute(request).await; - drop(pause); - rsp? - } - None => self.endpoint.execute(request).await?, + debug!(url = request.url().as_str(), "sending http request"); + let start = Instant::now(); + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + let response = self.endpoint.execute(request).await?; + + info!(duration = ?start.elapsed(), "received http response"); + + response }; - info!(duration = ?start.elapsed(), "received http response"); - let body = match parse_body::(response).await { + let body = match parse_body::( + response.status(), + response.bytes().await?, + ) { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. // TODO(anna): retry @@ -180,7 +156,7 @@ impl NeonControlPlaneClient { async fn do_get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { if !self .caches @@ -216,7 +192,10 @@ impl NeonControlPlaneClient { drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::( + response.status(), + response.bytes().await.map_err(ControlPlaneError::from)?, + )?; let rules = body .jwks @@ -268,7 +247,7 @@ impl NeonControlPlaneClient { let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::(response.status(), response.bytes().await?)?; // Unfortunately, ownership won't let us use `Option::ok_or` here. let (host, port) = match parse_host_port(&body.address) { @@ -313,225 +292,104 @@ impl NeonControlPlaneClient { impl super::ControlPlaneApi for NeonControlPlaneClient { #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - let user = &user_info.user; - if let Some(role_secret) = self + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(secret) = self .caches .project_info - .get_role_secret(normalized_ep, user) + .get_role_secret(normalized_ep, role) { - return Ok(role_secret); + return Ok(secret); } - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_ips), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - Arc::new(auth_info.allowed_vpc_endpoint_ids), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - auth_info.access_blocker_flags, + role.into(), + control, + role_control.clone(), ); ctx.set_project_id(project_id); } - // When we just got a secret, we don't need to invalidate it. - Ok(Cached::new_uncached(auth_info.secret)) + + Ok(role_control) } - async fn get_allowed_ips( + #[tracing::instrument(skip_all)] + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) { - Metrics::get() - .proxy - .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats? - .inc(CacheOutcome::Hit); - return Ok(allowed_ips); + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let normalized_ep = &endpoint.normalize(); + if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) { + return Ok(control); } - Metrics::get() - .proxy - .allowed_ips_cache_misses - .inc(CacheOutcome::Miss); - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; + + if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) { + info!("endpoint is not valid, skipping the request"); + return Err(GetAuthInfoError::UnknownEndpoint); + } + + let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?; + + let control = EndpointAccessControl { + allowed_ips: Arc::new(auth_info.allowed_ips), + allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), + flags: auth_info.access_blocker_flags, + }; + let role_control = RoleAccessControl { + secret: auth_info.secret, + }; + if let Some(project_id) = auth_info.project_id { let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( + + self.caches.project_info.insert_endpoint_access( + auth_info.account_id, project_id, normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, + role.into(), + control.clone(), + role_control, ); ctx.set_project_id(project_id); } - Ok(Cached::new_uncached(allowed_ips)) - } - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(allowed_vpc_endpoint_ids) = self - .caches - .project_info - .get_allowed_vpc_endpoint_ids(normalized_ep) - { - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Hit); - return Ok(allowed_vpc_endpoint_ids); - } - - Metrics::get() - .proxy - .vpc_endpoint_id_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags, - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(allowed_vpc_endpoint_ids)) - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - let normalized_ep = &user_info.endpoint.normalize(); - if let Some(access_blocker_flags) = self - .caches - .project_info - .get_block_public_or_vpc_access(normalized_ep) - { - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Hit); - return Ok(access_blocker_flags); - } - - Metrics::get() - .proxy - .access_blocker_flags_cache_stats - .inc(CacheOutcome::Miss); - - let auth_info = self.do_get_auth_info(ctx, user_info).await?; - let allowed_ips = Arc::new(auth_info.allowed_ips); - let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids); - let access_blocker_flags = auth_info.access_blocker_flags; - let user = &user_info.user; - let account_id = auth_info.account_id; - if let Some(project_id) = auth_info.project_id { - let normalized_ep_int = normalized_ep.into(); - self.caches.project_info.insert_role_secret( - project_id, - normalized_ep_int, - user.into(), - auth_info.secret.clone(), - ); - self.caches.project_info.insert_allowed_ips( - project_id, - normalized_ep_int, - allowed_ips.clone(), - ); - self.caches.project_info.insert_allowed_vpc_endpoint_ids( - account_id, - project_id, - normalized_ep_int, - allowed_vpc_endpoint_ids.clone(), - ); - self.caches.project_info.insert_block_public_or_vpc_access( - project_id, - normalized_ep_int, - access_blocker_flags.clone(), - ); - ctx.set_project_id(project_id); - } - Ok(Cached::new_uncached(access_blocker_flags)) + Ok(control) } #[tracing::instrument(skip_all)] async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(ctx, endpoint).await } @@ -640,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } /// Parse http response body, taking status code into account. -async fn parse_body serde::Deserialize<'a>>( - response: http::Response, +fn parse_body serde::Deserialize<'a>>( + status: StatusCode, + body: Bytes, ) -> Result { - let status = response.status(); if status.is_success() { // We shouldn't log raw body because it may contain secrets. info!("request succeeded, processing the body"); - return Ok(response.json().await?); + return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?); } - let s = response.bytes().await?; + // Log plaintext to be able to detect, whether there are some cases not covered by the error struct. - info!("response_error plaintext: {:?}", s); + info!("response_error plaintext: {:?}", body); // Don't throw an error here because it's not as important // as the fact that the request itself has failed. - let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| { + let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| { warn!("failed to parse error body: {e}"); - ControlPlaneErrorMessage { + Box::new(ControlPlaneErrorMessage { error: "reason unclear (malformed error message)".into(), http_status_code: status, status: None, - } + }) }); body.http_status_code = status; warn!("console responded with an error ({status}): {body:?}"); - Err(ControlPlaneError::Message(Box::new(body))) + Err(ControlPlaneError::Message(body)) } fn parse_host_port(input: &str) -> Option<(&str, u16)> { diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index d3ab4abd0b..ece7153fce 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; use crate::context::RequestContext; -use crate::control_plane::client::{ - CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret, -}; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; use crate::control_plane::messages::MetricsAuxInfo; -use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{ + AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, + RoleAccessControl, +}; use crate::intern::RoleNameInt; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; @@ -66,7 +66,8 @@ impl MockControlPlane { async fn do_get_auth_info( &self, - user_info: &ComputeUserInfo, + endpoint: &EndpointId, + role: &RoleName, ) -> Result { let (secret, allowed_ips) = async { // Perhaps we could persist this connection, but then we'd have to @@ -80,7 +81,7 @@ impl MockControlPlane { let secret = if let Some(entry) = get_execute_postgres_query( &client, "select rolpassword from pg_catalog.pg_authid where rolname = $1", - &[&&*user_info.user], + &[&role.as_str()], "rolpassword", ) .await? @@ -89,7 +90,7 @@ impl MockControlPlane { let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) } else { - warn!("user '{}' does not exist", user_info.user); + warn!("user '{role}' does not exist"); None }; @@ -97,7 +98,7 @@ impl MockControlPlane { match get_execute_postgres_query( &client, "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1", - &[&user_info.endpoint.as_str()], + &[&endpoint.as_str()], "allowed_ips", ) .await? @@ -133,7 +134,7 @@ impl MockControlPlane { async fn do_get_endpoint_jwks( &self, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { let (client, connection) = tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; @@ -222,53 +223,36 @@ async fn get_execute_postgres_query( } impl super::ControlPlaneApi for MockControlPlane { - #[tracing::instrument(skip_all)] - async fn get_role_secret( + async fn get_endpoint_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached( - self.do_get_auth_info(user_info).await?.secret, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(EndpointAccessControl { + allowed_ips: Arc::new(info.allowed_ips), + allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids), + flags: info.access_blocker_flags, + }) } - async fn get_allowed_ips( + async fn get_role_access_control( &self, _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info).await?.allowed_ips, - ))) - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached(Arc::new( - self.do_get_auth_info(user_info) - .await? - .allowed_vpc_endpoint_ids, - ))) - } - - async fn get_block_public_or_vpc_access( - &self, - _ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - Ok(Cached::new_uncached( - self.do_get_auth_info(user_info).await?.access_blocker_flags, - )) + endpoint: &EndpointId, + role: &RoleName, + ) -> Result { + let info = self.do_get_auth_info(endpoint, role).await?; + Ok(RoleAccessControl { + secret: info.secret, + }) } async fn get_endpoint_jwks( &self, _ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { self.do_get_endpoint_jwks(endpoint).await } diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index 746595de38..9b9d1e25ea 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache; use crate::cache::project_info::ProjectInfoCacheImpl; use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}; use crate::context::RequestContext; -use crate::control_plane::{ - CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, - CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors, -}; +use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors}; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; use crate::types::EndpointId; +use super::{EndpointAccessControl, RoleAccessControl}; + #[non_exhaustive] #[derive(Clone)] pub enum ControlPlaneClient { @@ -40,68 +39,42 @@ pub enum ControlPlaneClient { } impl ControlPlaneApi for ControlPlaneClient { - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await, + Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await, + Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(_) => { + Self::Test(_api) => { unreachable!("this function should never be called in the test backend") } } } - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { + endpoint: &EndpointId, + role: &crate::types::RoleName, + ) -> Result { match self { - Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await, + Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await, + Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await, #[cfg(test)] - Self::Test(api) => api.get_allowed_ips(), - } - } - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => api.get_allowed_vpc_endpoint_ids(), - } - } - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result { - match self { - Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, - #[cfg(any(test, feature = "testing"))] - Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await, - #[cfg(test)] - Self::Test(api) => api.get_block_public_or_vpc_access(), + Self::Test(api) => api.get_access_control(), } } async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError> { match self { Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await, @@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient { pub(crate) trait TestControlPlaneClient: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips(&self) -> Result; - - fn get_allowed_vpc_endpoint_ids( - &self, - ) -> Result; - - fn get_block_public_or_vpc_access( - &self, - ) -> Result; + fn get_access_control(&self) -> Result; fn dyn_clone(&self) -> Box; } @@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient { ctx: &RequestContext, endpoint: EndpointId, ) -> Result, FetchAuthRulesError> { - self.get_endpoint_jwks(ctx, endpoint) + self.get_endpoint_jwks(ctx, &endpoint) .await .map_err(FetchAuthRulesError::GetEndpointJwks) } diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs index 850d061333..77312c89c5 100644 --- a/proxy/src/control_plane/errors.rs +++ b/proxy/src/control_plane/errors.rs @@ -99,6 +99,10 @@ pub(crate) enum GetAuthInfoError { #[error(transparent)] ApiError(ControlPlaneError), + + /// Proxy does not know about the endpoint in advanced + #[error("endpoint not found in endpoint cache")] + UnknownEndpoint, } // This allows more useful interactions than `#[from]`. @@ -115,6 +119,8 @@ impl UserFacingError for GetAuthInfoError { Self::BadSecret => REQUEST_FAILED.to_owned(), // However, API might return a meaningful error. Self::ApiError(e) => e.to_string_client(), + // pretend like control plane returned an error. + Self::UnknownEndpoint => REQUEST_FAILED.to_owned(), } } } @@ -124,6 +130,8 @@ impl ReportableError for GetAuthInfoError { match self { Self::BadSecret => crate::error::ErrorKind::ControlPlane, Self::ApiError(_) => crate::error::ErrorKind::ControlPlane, + // we only apply endpoint filtering if control plane is under high load. + Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit, } } } diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index d592223be1..7ff093d9dc 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -11,16 +11,16 @@ pub(crate) mod errors; use std::sync::Arc; -use crate::auth::IpPattern; use crate::auth::backend::jwt::AuthRule; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::cache::project_info::ProjectInfoCacheImpl; +use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; use crate::cache::{Cached, TimedLru}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; use crate::intern::{AccountIdInt, ProjectIdInt}; -use crate::types::{EndpointCacheKey, EndpointId}; +use crate::protocol2::ConnectionInfoExtra; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, scram}; /// Various cache-related types. @@ -101,7 +101,7 @@ impl NodeInfo { } } -#[derive(Clone, Default, Eq, PartialEq, Debug)] +#[derive(Copy, Clone, Default)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, pub vpc_access_blocked: bool, @@ -110,47 +110,78 @@ pub(crate) struct AccessBlockerFlags { pub(crate) type NodeInfoCache = TimedLru>>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; -pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; -pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAllowedVpcEndpointIds = - Cached<&'static ProjectInfoCacheImpl, Arc>>; -pub(crate) type CachedAccessBlockerFlags = - Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>; + +#[derive(Clone)] +pub struct RoleAccessControl { + pub secret: Option, +} + +#[derive(Clone)] +pub struct EndpointAccessControl { + pub allowed_ips: Arc>, + pub allowed_vpce: Arc>, + pub flags: AccessBlockerFlags, +} + +impl EndpointAccessControl { + pub fn check( + &self, + ctx: &RequestContext, + check_ip_allowed: bool, + check_vpc_allowed: bool, + ) -> Result<(), AuthError> { + if check_ip_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &self.allowed_ips) { + return Err(AuthError::IpAddressNotAllowed(ctx.peer_addr())); + } + + // check if a VPC endpoint ID is coming in and if yes, if it's allowed + if check_vpc_allowed { + if self.flags.vpc_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + let incoming_vpc_endpoint_id = match ctx.extra() { + None => return Err(AuthError::MissingVPCEndpointId), + Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), + Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), + }; + + let vpce = &self.allowed_vpce; + // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. + if !vpce.is_empty() && !vpce.contains(&incoming_vpc_endpoint_id) { + return Err(AuthError::vpc_endpoint_id_not_allowed( + incoming_vpc_endpoint_id, + )); + } + } else if self.flags.public_access_blocked { + return Err(AuthError::NetworkNotAllowed); + } + + Ok(()) + } +} /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. pub(crate) trait ControlPlaneApi { - /// Get the client's auth secret for authentication. - /// Returns option because user not found situation is special. - /// We still have to mock the scram to avoid leaking information that user doesn't exist. - async fn get_role_secret( + async fn get_role_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; - async fn get_allowed_ips( + async fn get_endpoint_access_control( &self, ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_allowed_vpc_endpoint_ids( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; - - async fn get_block_public_or_vpc_access( - &self, - ctx: &RequestContext, - user_info: &ComputeUserInfo, - ) -> Result; + endpoint: &EndpointId, + role: &RoleName, + ) -> Result; async fn get_endpoint_jwks( &self, ctx: &RequestContext, - endpoint: EndpointId, + endpoint: &EndpointId, ) -> Result, errors::GetEndpointJwksError>; /// Wake up the compute node and return the corresponding connection info. diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index 96f600d836..36607e7861 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -4,9 +4,10 @@ pub mod health_server; -use std::time::Duration; +use std::time::{Duration, Instant}; use bytes::Bytes; +use futures::FutureExt; use http::Method; use http_body_util::BodyExt; use hyper::body::Body; @@ -109,15 +110,31 @@ impl Endpoint { } /// Execute a [request](reqwest::Request). - pub(crate) async fn execute(&self, request: Request) -> Result { - let _timer = Metrics::get() + pub(crate) fn execute( + &self, + request: Request, + ) -> impl Future> { + let metric = Metrics::get() .proxy .console_request_latency - .start_timer(ConsoleRequest { + .with_labels(ConsoleRequest { request: request.url().path(), }); - self.client.execute(request).await + let req = self.client.execute(request).boxed(); + + async move { + let start = Instant::now(); + scopeguard::defer!({ + Metrics::get() + .proxy + .console_request_latency + .get_metric(metric) + .observe_duration_since(start); + }); + + req.await + } } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index d1f8430b8a..d65d056585 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -92,6 +92,7 @@ mod logging; mod metrics; mod parse; mod pglb; +mod pqproto; mod protocol2; mod proxy; mod rate_limiter; diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs new file mode 100644 index 0000000000..43074bf208 --- /dev/null +++ b/proxy/src/pqproto.rs @@ -0,0 +1,693 @@ +//! Postgres protocol codec +//! +//! + +use std::fmt; +use std::io::{self, Cursor}; + +use bytes::{Buf, BufMut}; +use itertools::Itertools; +use rand::distributions::{Distribution, Standard}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; + +pub type ErrorCode = [u8; 5]; + +pub const FE_PASSWORD_MESSAGE: u8 = b'p'; + +pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000"; + +/// The protocol version number. +/// +/// The most significant 16 bits are the major version number (3 for the protocol described here). +/// The least significant 16 bits are the minor version number (0 for the protocol described here). +/// +#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +pub struct ProtocolVersion { + major: big_endian::U16, + minor: big_endian::U16, +} + +impl ProtocolVersion { + pub const fn new(major: u16, minor: u16) -> Self { + Self { + major: big_endian::U16::new(major), + minor: big_endian::U16::new(minor), + } + } + pub const fn minor(self) -> u16 { + self.minor.get() + } + pub const fn major(self) -> u16 { + self.major.get() + } +} + +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entry(&self.major()) + .entry(&self.minor()) + .finish() + } +} + +/// read the type from the stream using zerocopy. +/// +/// not cancel safe. +macro_rules! read { + ($s:expr => $t:ty) => {{ + // cannot be implemented as a function due to lack of const-generic-expr + let mut buf = [0; size_of::<$t>()]; + $s.read_exact(&mut buf).await?; + let res: $t = zerocopy::transmute!(buf); + res + }}; +} + +pub async fn read_startup(stream: &mut S) -> io::Result +where + S: AsyncRead + Unpin, +{ + /// + const MAX_STARTUP_PACKET_LENGTH: usize = 10000; + const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; + /// + const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); + /// + const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); + /// + const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + + /// This first reads the startup message header, is 8 bytes. + /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. + /// + /// The length value is inclusive of the header. For example, + /// an empty message will always have length 8. + #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] + #[repr(C)] + struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, + } + + let header = read!(stream => StartupHeader); + + // + // First byte indicates standard SSL handshake message + // (It can't be a Postgres startup length because in network byte order + // that would be a startup packet hundreds of megabytes long) + if header.as_bytes()[0] == 0x16 { + return Ok(FeStartupPacket::SslRequest { + // The bytes we read for the header are actually part of a TLS ClientHello. + // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here. + // In practice though, I see no world where a ClientHello is less than 8 bytes + // since it includes ephemeral keys etc. + direct: Some(zerocopy::transmute!(header)), + }); + } + + let Some(len) = (header.len.get() as usize).checked_sub(8) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 8.", + header.len, + ))); + }; + + // TODO: add a histogram for startup packet lengths + if len > MAX_STARTUP_PACKET_LENGTH { + tracing::warn!("large startup message detected: {len} bytes"); + return Err(io::Error::other(format!( + "invalid startup message length {len}" + ))); + } + + match header.version { + // + CANCEL_REQUEST_CODE => { + if len != 8 { + return Err(io::Error::other( + "CancelRequest message is malformed, backend PID / secret key missing", + )); + } + + Ok(FeStartupPacket::CancelRequest( + read!(stream => CancelKeyData), + )) + } + // + NEGOTIATE_SSL_CODE => { + // Requested upgrade to SSL (aka TLS) + Ok(FeStartupPacket::SslRequest { direct: None }) + } + NEGOTIATE_GSS_CODE => { + // Requested upgrade to GSSAPI + Ok(FeStartupPacket::GssEncRequest) + } + version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other( + format!("Unrecognized request code {version:?}"), + )), + // StartupMessage + version => { + // The protocol version number is followed by one or more pairs of parameter name and value strings. + // A zero byte is required as a terminator after the last name/value pair. + // Parameters can appear in any order. user is required, others are optional. + + let mut buf = vec![0; len]; + stream.read_exact(&mut buf).await?; + + if buf.pop() != Some(b'\0') { + return Err(io::Error::other( + "StartupMessage params: missing null terminator", + )); + } + + // TODO: Don't do this. + // There's no guarantee that these messages are utf8, + // but they usually happen to be simple ascii. + let params = String::from_utf8(buf) + .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?; + + Ok(FeStartupPacket::StartupMessage { + version, + params: StartupMessageParams { params }, + }) + } + } +} + +/// Read a raw postgres packet, which will respect the max length requested. +/// +/// This returns the message tag, as well as the message body. The message +/// body is written into `buf`, and it is otherwise completely overwritten. +/// +/// This is not cancel safe. +pub async fn read_message<'a, S>( + stream: &mut S, + buf: &'a mut Vec, + max: u32, +) -> io::Result<(u8, &'a mut [u8])> +where + S: AsyncRead + Unpin, +{ + /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes. + /// The first byte is a message tag, and the next 4 bytes is a big-endian length. + /// + /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example, + /// an empty message will always have length 4. + #[derive(Clone, Copy, FromBytes)] + #[repr(C)] + struct Header { + tag: u8, + len: big_endian::U32, + } + + let header = read!(stream => Header); + + // as described above, the length must be at least 4. + let Some(len) = header.len.get().checked_sub(4) else { + return Err(io::Error::other(format!( + "invalid startup message length {}, must be at least 4.", + header.len, + ))); + }; + + // TODO: add a histogram for message lengths + + // check if the message exceeds our desired max. + if len > max { + tracing::warn!("large postgres message detected: {len} bytes"); + return Err(io::Error::other(format!("invalid message length {len}"))); + } + + // read in our entire message. + buf.resize(len as usize, 0); + stream.read_exact(buf).await?; + + Ok((header.tag, buf)) +} + +pub struct WriteBuf(Cursor>); + +impl Buf for WriteBuf { + #[inline] + fn remaining(&self) -> usize { + self.0.remaining() + } + + #[inline] + fn chunk(&self) -> &[u8] { + self.0.chunk() + } + + #[inline] + fn advance(&mut self, cnt: usize) { + self.0.advance(cnt); + } +} + +impl WriteBuf { + pub const fn new() -> Self { + Self(Cursor::new(Vec::new())) + } + + /// Use a heuristic to determine if we should shrink the write buffer. + #[inline] + fn should_shrink(&self) -> bool { + let n = self.0.position() as usize; + let len = self.0.get_ref().len(); + + // the unused space at the front of our buffer is 2x the size of our filled portion. + n + n > len + } + + /// Shrink the write buffer so that subsequent writes have more spare capacity. + #[cold] + fn shrink(&mut self) { + let n = self.0.position() as usize; + let buf = self.0.get_mut(); + + // buf repr: + // [----unused------|-----filled-----|-----uninit-----] + // ^ n ^ buf.len() ^ buf.capacity() + let filled = n..buf.len(); + let filled_len = filled.len(); + buf.copy_within(filled, 0); + buf.truncate(filled_len); + self.0.set_position(0); + } + + /// clear the write buffer. + pub fn reset(&mut self) { + let buf = self.0.get_mut(); + buf.clear(); + self.0.set_position(0); + } + + /// Write a raw message to the internal buffer. + /// + /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since + /// we calculate the length after the fact. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + if self.should_shrink() { + self.shrink(); + } + + let buf = self.0.get_mut(); + buf.reserve(5 + size_hint); + + buf.push(tag); + let start = buf.len(); + buf.extend_from_slice(&[0, 0, 0, 0]); + + f(buf); + + let end = buf.len(); + let len = (end - start) as u32; + buf[start..start + 4].copy_from_slice(&len.to_be_bytes()); + } + + /// Write an encryption response message. + pub fn encryption(&mut self, m: u8) { + self.0.get_mut().push(m); + } + + pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) { + self.shrink(); + + // + // + // "SERROR\0CXXXXX\0M\0\0".len() == 17 + self.write_raw(17 + msg.len(), b'E', |buf| { + // Severity: ERROR + buf.put_slice(b"SERROR\0"); + + // Code: error_code + buf.put_u8(b'C'); + buf.put_slice(&error_code); + buf.put_u8(0); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End. + buf.put_u8(0); + }); + } +} + +#[derive(Debug)] +pub enum FeStartupPacket { + CancelRequest(CancelKeyData), + SslRequest { + direct: Option<[u8; 8]>, + }, + GssEncRequest, + StartupMessage { + version: ProtocolVersion, + params: StartupMessageParams, + }, +} + +#[derive(Debug, Clone, Default)] +pub struct StartupMessageParams { + pub params: String, +} + +impl StartupMessageParams { + /// Get parameter's value by its name. + pub fn get(&self, name: &str) -> Option<&str> { + self.iter().find_map(|(k, v)| (k == name).then_some(v)) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + /// [`None`] means that there's no `options` in [`Self`]. + pub fn options_raw(&self) -> Option> { + self.get("options").map(Self::parse_options_raw) + } + + /// Split command-line options according to PostgreSQL's logic, + /// taking into account all escape sequences but leaving them as-is. + pub fn parse_options_raw(input: &str) -> impl Iterator { + // See `postgres: pg_split_opts`. + let mut last_was_escape = false; + input + .split(move |c: char| { + // We split by non-escaped whitespace symbols. + let should_split = c.is_ascii_whitespace() && !last_was_escape; + last_was_escape = c == '\\' && !last_was_escape; + should_split + }) + .filter(|s| !s.is_empty()) + } + + /// Iterate through key-value pairs in an arbitrary order. + pub fn iter(&self) -> impl Iterator { + self.params.split_terminator('\0').tuples() + } + + // This function is mostly useful in tests. + #[cfg(test)] + pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { + let mut b = Self { + params: String::new(), + }; + for (k, v) in pairs { + b.insert(k, v); + } + b + } + + /// Set parameter's value by its name. + /// name and value must not contain a \0 byte + pub fn insert(&mut self, name: &str, value: &str) { + self.params.reserve(name.len() + value.len() + 2); + self.params.push_str(name); + self.params.push('\0'); + self.params.push_str(value); + self.params.push('\0'); + } +} + +/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just +/// opaque bytes. +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)] +pub struct CancelKeyData(pub big_endian::U64); + +pub fn id_to_cancel_key(id: u64) -> CancelKeyData { + CancelKeyData(big_endian::U64::new(id)) +} + +impl fmt::Display for CancelKeyData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let id = self.0; + f.debug_tuple("CancelKeyData") + .field(&format_args!("{id:x}")) + .finish() + } +} +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> CancelKeyData { + id_to_cancel_key(rng.r#gen()) + } +} + +pub enum BeMessage<'a> { + AuthenticationOk, + AuthenticationSasl(BeAuthenticationSaslMessage<'a>), + AuthenticationCleartextPassword, + BackendKeyData(CancelKeyData), + ParameterStatus { + name: &'a [u8], + value: &'a [u8], + }, + ReadyForQuery, + NoticeResponse(&'a str), + NegotiateProtocolVersion { + version: ProtocolVersion, + options: &'a [&'a str], + }, +} + +#[derive(Debug)] +pub enum BeAuthenticationSaslMessage<'a> { + Methods(&'a [&'a str]), + Continue(&'a [u8]), + Final(&'a [u8]), +} + +impl BeMessage<'_> { + /// Write the message into an internal buffer + pub fn write_message(self, buf: &mut WriteBuf) { + match self { + // + BeMessage::AuthenticationOk => { + buf.write_raw(1, b'R', |buf| buf.put_i32(0)); + } + // + BeMessage::AuthenticationCleartextPassword => { + buf.write_raw(1, b'R', |buf| buf.put_i32(3)); + } + + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => { + let len: usize = methods.iter().map(|m| m.len() + 1).sum(); + buf.write_raw(len + 2, b'R', |buf| { + buf.put_i32(10); // Specifies that SASL auth method is used. + for method in methods { + buf.put_slice(method.as_bytes()); + buf.put_u8(0); + } + buf.put_u8(0); // zero terminator for the list + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(11); // Continue SASL auth. + buf.put_slice(extra); + }); + } + // + BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => { + buf.write_raw(extra.len() + 1, b'R', |buf| { + buf.put_i32(12); // Send final SASL message. + buf.put_slice(extra); + }); + } + + // + BeMessage::BackendKeyData(key_data) => { + buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes())); + } + + // + // + BeMessage::NoticeResponse(msg) => { + // 'N' signalizes NoticeResponse messages + buf.write_raw(18 + msg.len(), b'N', |buf| { + // Severity: NOTICE + buf.put_slice(b"SNOTICE\0"); + + // Code: XX000 (ignored for notice, but still required) + buf.put_slice(b"CXX000\0"); + + // Message: msg + buf.put_u8(b'M'); + buf.put_slice(msg.as_bytes()); + buf.put_u8(0); + + // End notice. + buf.put_u8(0); + }); + } + + // + BeMessage::ParameterStatus { name, value } => { + buf.write_raw(name.len() + value.len() + 2, b'S', |buf| { + buf.put_slice(name.as_bytes()); + buf.put_u8(0); + buf.put_slice(value.as_bytes()); + buf.put_u8(0); + }); + } + + // + BeMessage::ReadyForQuery => { + buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I')); + } + + // + BeMessage::NegotiateProtocolVersion { version, options } => { + let len: usize = options.iter().map(|o| o.len() + 1).sum(); + buf.write_raw(8 + len, b'v', |buf| { + buf.put_slice(version.as_bytes()); + buf.put_u32(options.len() as u32); + for option in options { + buf.put_slice(option.as_bytes()); + buf.put_u8(0); + } + }); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use tokio::io::{AsyncWriteExt, duplex}; + use zerocopy::IntoBytes; + + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; + + use super::ProtocolVersion; + + #[tokio::test] + async fn reject_large_startup() { + // we're going to define a v3.0 startup message with far too many parameters. + let mut payload = vec![]; + // 10001 + 8 bytes. + payload.extend_from_slice(&10009_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.resize(10009, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_startup(&mut server).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid startup message length 10001"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn reject_large_password() { + // we're going to define a password message that is far too long. + let mut payload = vec![]; + payload.push(b'p'); + payload.extend_from_slice(&517_u32.to_be_bytes()); + payload.resize(518, b'a'); + + let (mut server, mut client) = duplex(128); + #[rustfmt::skip] + let (server, client) = tokio::join!( + async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() }, + async move { client.write_all(&payload).await.unwrap_err() }, + ); + + assert_eq!(server.to_string(), "invalid message length 513"); + assert_eq!(client.to_string(), "broken pipe"); + } + + #[tokio::test] + async fn read_startup_message() { + let mut payload = vec![]; + payload.extend_from_slice(&17_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes()); + payload.extend_from_slice(b"abc\0def\0\0"); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::StartupMessage { version, params } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + + assert_eq!(version.major(), 3); + assert_eq!(version.minor(), 0); + assert_eq!(params.params, "abc\0def\0"); + } + + #[tokio::test] + async fn read_ssl_message() { + let mut payload = vec![]; + payload.extend_from_slice(&8_u32.to_be_bytes()); + payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes()); + + let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap(); + let FeStartupPacket::SslRequest { direct: None } = startup else { + panic!("unexpected startup message: {startup:?}"); + }; + } + + #[tokio::test] + async fn read_tls_message() { + // sample client hello taken from + let client_hello = [ + 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02, + 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, + 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, + 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, + 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, + 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, + 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, + 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, + 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19, + 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e, + 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09, + 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, + 0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01, + 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72, + 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, + 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, + 0x54, + ]; + + let mut cursor = Cursor::new(&client_hello); + + let startup = read_startup(&mut cursor).await.unwrap(); + let FeStartupPacket::SslRequest { + direct: Some(prefix), + } = startup + else { + panic!("unexpected startup message: {startup:?}"); + }; + + // check that no data is lost. + assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]); + assert_eq!(cursor.position(), 8); + } + + #[tokio::test] + async fn read_message_success() { + let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2"; + let mut cursor = Cursor::new(&query); + + let mut buf = vec![]; + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 1"); + + let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap(); + assert_eq!(tag, b'Q'); + assert_eq!(message, b"SELECT 2"); + } +} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index e013fbbe2e..57785c9ec5 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,5 +1,4 @@ use async_trait::async_trait; -use pq_proto::StartupMessageParams; use tokio::time; use tracing::{debug, info, warn}; @@ -15,6 +14,7 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; +use crate::pqproto::StartupMessageParams; use crate::proxy::retry::{CouldRetry, retry_after, should_retry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 54c02f2c15..6970ab8714 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,8 +1,4 @@ -use bytes::Buf; -use pq_proto::framed::Framed; -use pq_proto::{ - BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, -}; +use futures::{FutureExt, TryFutureExt}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -12,7 +8,10 @@ use crate::config::TlsConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::ERR_INSECURE_CONNECTION; +use crate::pqproto::{ + BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams, +}; +use crate::proxy::TlsRequired; use crate::stream::{PqStream, Stream, StreamUpgradeError}; use crate::tls::PG_ALPN_PROTOCOL; @@ -59,7 +58,7 @@ pub(crate) enum HandshakeData { /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. #[tracing::instrument(skip_all)] -pub(crate) async fn handshake( +pub(crate) async fn handshake( ctx: &RequestContext, stream: S, mut tls: Option<&TlsConfig>, @@ -71,33 +70,25 @@ pub(crate) async fn handshake( const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0); const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); - let mut stream = PqStream::new(Stream::from_raw(stream)); + let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?; loop { - let msg = stream.read_startup_packet().await?; match msg { FeStartupPacket::SslRequest { direct } => match stream.get_ref() { Stream::Raw { .. } if !tried_ssl => { tried_ssl = true; - // We can't perform TLS handshake without a config - let have_tls = tls.is_some(); - if !direct { - stream - .write_message(&Be::EncryptionResponse(have_tls)) - .await?; - } else if !have_tls { - return Err(HandshakeError::ProtocolViolation); - } - if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. - let Framed { - stream: raw, - read_buf, - write_buf, - } = stream.framed; + let mut read_buf; + let raw = if let Some(direct) = &direct { + read_buf = &direct[..]; + stream.accept_direct_tls() + } else { + read_buf = &[]; + stream.accept_tls().await? + }; let Stream::Raw { raw } = raw else { return Err(HandshakeError::StreamUpgradeError( @@ -105,12 +96,11 @@ pub(crate) async fn handshake( )); }; - let mut read_buf = read_buf.reader(); let mut res = Ok(()); let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone()) .accept_with(raw, |session| { // push the early data to the tls session - while !read_buf.get_ref().is_empty() { + while !read_buf.is_empty() { match session.read_tls(&mut read_buf) { Ok(_) => {} Err(e) => { @@ -119,11 +109,12 @@ pub(crate) async fn handshake( } } } - }); + }) + .map_ok(Box::new) + .boxed(); res?; - let read_buf = read_buf.into_inner(); if !read_buf.is_empty() { return Err(HandshakeError::EarlyData); } @@ -157,16 +148,17 @@ pub(crate) async fn handshake( let (_, tls_server_end_point) = tls.cert_resolver.resolve(conn_info.server_name()); - stream = PqStream { - framed: Framed { - stream: Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }, - read_buf, - write_buf, - }, + let tls = Stream::Tls { + tls: tls_stream, + tls_server_end_point, }; + (stream, msg) = PqStream::parse_startup(tls).await?; + } else { + if direct.is_some() { + // client sent us a ClientHello already, we can't do anything with it. + return Err(HandshakeError::ProtocolViolation); + } + msg = stream.reject_encryption().await?; } } _ => return Err(HandshakeError::ProtocolViolation), @@ -176,7 +168,7 @@ pub(crate) async fn handshake( tried_gss = true; // Currently, we don't support GSSAPI - stream.write_message(&Be::EncryptionResponse(false)).await?; + msg = stream.reject_encryption().await?; } _ => return Err(HandshakeError::ProtocolViolation), }, @@ -186,13 +178,7 @@ pub(crate) async fn handshake( // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). if tls.is_some() { - return stream - .throw_error_str( - ERR_INSECURE_CONNECTION, - crate::error::ErrorKind::User, - None, - ) - .await?; + Err(stream.throw_error(TlsRequired, None).await)?; } // This log highlights the start of the connection. @@ -214,20 +200,21 @@ pub(crate) async fn handshake( // no protocol extensions are supported. // let mut unsupported = vec![]; - for (k, _) in params.iter() { + let mut supported = StartupMessageParams::default(); + + for (k, v) in params.iter() { if k.starts_with("_pq_.") { unsupported.push(k); + } else { + supported.insert(k, v); } } - // TODO: remove unsupported options so we don't send them to compute. - - stream - .write_message(&Be::NegotiateProtocolVersion { - version: PG_PROTOCOL_LATEST, - options: &unsupported, - }) - .await?; + stream.write_message(BeMessage::NegotiateProtocolVersion { + version: PG_PROTOCOL_LATEST, + options: &unsupported, + }); + stream.flush().await?; info!( ?version, @@ -235,7 +222,7 @@ pub(crate) async fn handshake( session_type = "normal", "successful handshake; unsupported minor version requested" ); - break Ok(HandshakeData::Startup(stream, params)); + break Ok(HandshakeData::Startup(stream, supported)); } FeStartupPacket::StartupMessage { version, params } => { warn!( diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0a86022e78..0ffc54aa88 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -10,15 +10,14 @@ pub(crate) mod wake_compute; use std::sync::Arc; pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; -use futures::{FutureExt, TryFutureExt}; +use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; -use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, ToSmolStr, format_smolstr}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, warn}; @@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; -use crate::error::ReportableError; +use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; +use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; use crate::proxy::handshake::{HandshakeData, handshake}; use crate::rate_limiter::EndpointRateLimiter; @@ -38,6 +38,18 @@ use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; +#[derive(Error, Debug)] +#[error("{ERR_INSECURE_CONNECTION}")] +pub struct TlsRequired; + +impl ReportableError for TlsRequired { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + +impl UserFacingError for TlsRequired {} + pub async fn run_until_cancelled( f: F, cancellation_token: &CancellationToken, @@ -258,7 +270,7 @@ impl ReportableError for ClientRequestError { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, @@ -329,11 +341,11 @@ pub(crate) async fn handle_client( let user_info = match result { Ok(user_info) => user_info, - Err(e) => stream.throw_error(e, Some(ctx)).await?, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; let user = user_info.get_user().to_owned(); - let (user_info, _ip_allowlist) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, @@ -349,10 +361,10 @@ pub(crate) async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return stream + return Err(stream .throw_error(e, Some(ctx)) .instrument(params_span) - .await?; + .await)?; } }; @@ -365,7 +377,7 @@ pub(crate) async fn handle_client( .get(NeonOptions::PARAMS_COMPAT) .is_some(); - let mut node = connect_to_compute( + let res = connect_to_compute( ctx, &TcpMechanism { user_info: compute_user_info.clone(), @@ -377,22 +389,19 @@ pub(crate) async fn handle_client( config.wake_compute_retry_config, &config.connect_to_compute, ) - .or_else(|e| stream.throw_error(e, Some(ctx))) - .await?; + .await; + + let node = match res { + Ok(node) => node, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + }; let cancellation_handler_clone = Arc::clone(&cancellation_handler); let session = cancellation_handler_clone.get_key(); session.write_cancel_key(node.cancel_closure.clone())?; - - prepare_client_connection(&node, *session.key(), &mut stream).await?; - - // Before proxy passing, forward to compute whatever data is left in the - // PqStream input buffer. Normally there is none, but our serverless npm - // driver in pipeline mode sends startup, password and first query - // immediately after opening the connection. - let (stream, read_buf) = stream.into_inner(); - node.stream.write_all(&read_buf).await?; + prepare_client_connection(&node, *session.key(), &mut stream); + let stream = stream.flush_and_into_inner().await?; let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), @@ -413,31 +422,28 @@ pub(crate) async fn handle_client( } /// Finish client connection initialization: confirm auth success, send params, etc. -#[tracing::instrument(skip_all)] -pub(crate) async fn prepare_client_connection( +pub(crate) fn prepare_client_connection( node: &compute::PostgresConnection, cancel_key_data: CancelKeyData, stream: &mut PqStream, -) -> Result<(), std::io::Error> { +) { // Forward all deferred notices to the client. for notice in &node.delayed_notice { - stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?; + stream.write_raw(notice.as_bytes().len(), b'N', |buf| { + buf.extend_from_slice(notice.as_bytes()); + }); } // Forward all postgres connection params to the client. for (name, value) in &node.params { - stream.write_message_noflush(&Be::ParameterStatus { + stream.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), - })?; + }); } - stream - .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? - .write_message(&Be::ReadyForQuery) - .await?; - - Ok(()) + stream.write_message(BeMessage::BackendKeyData(cancel_key_data)); + stream.write_message(BeMessage::ReadyForQuery); } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 8f9bd2de2d..55ab5f4dba 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,3 +1,4 @@ +use futures::FutureExt; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; @@ -89,6 +90,7 @@ impl ProxyPassthrough { .compute .cancel_closure .try_cancel_query(compute_config) + .boxed() .await { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 0879564ced..01e603ec14 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati #[cfg(test)] mod tests { - use super::ShouldRetryWakeCompute; use postgres_client::error::{DbError, SqlState}; + use super::ShouldRetryWakeCompute; + #[test] fn should_retry_wake_compute_for_db_error() { // These SQLStates should NOT trigger a wake_compute retry. diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 59c9ac27b8..c92ee49b8d 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use postgres_client::tls::TlsConnect; use postgres_protocol::message::frontend; -use tokio::io::{AsyncReadExt, DuplexStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream}; use tokio_util::codec::{Decoder, Encoder}; use super::*; @@ -49,15 +49,14 @@ async fn proxy_mitm( }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); - let (end_client, buf) = end_client.framed.into_inner(); - assert!(buf.is_empty()); + let end_client = end_client.flush_and_into_inner().await.unwrap(); let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame); // give the end_server the startup parameters let mut buf = BytesMut::new(); frontend::startup_message( &postgres_protocol::message::frontend::StartupMessageParams { - params: startup.params.into(), + params: startup.params.as_bytes().into(), }, &mut buf, ) diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index be6426a63c..61e8ee4a10 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -26,9 +26,7 @@ use crate::auth::backend::{ use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; -use crate::control_plane::{ - self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache, -}; +use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::postgres_rustls::MakeRustlsConnect; @@ -128,7 +126,7 @@ trait TestAuth: Sized { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - stream.write_message_noflush(&Be::AuthenticationOk)?; + stream.write_message(BeMessage::AuthenticationOk); Ok(()) } } @@ -157,9 +155,7 @@ impl TestAuth for Scram { self, stream: &mut PqStream>, ) -> anyhow::Result<()> { - let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0, &RequestContext::test())) - .await? + let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test())) .authenticate() .await?; @@ -185,10 +181,12 @@ async fn dummy_proxy( auth.authenticate(&mut stream).await?; - stream - .write_message_noflush(&Be::CLIENT_ENCODING)? - .write_message(&Be::ReadyForQuery) - .await?; + stream.write_message(BeMessage::ParameterStatus { + name: b"client_encoding", + value: b"UTF8", + }); + stream.write_message(BeMessage::ReadyForQuery); + stream.flush().await?; Ok(()) } @@ -547,20 +545,9 @@ impl TestControlPlaneClient for TestConnectMechanism { } } - fn get_allowed_ips(&self) -> Result { - unimplemented!("not used in tests") - } - - fn get_allowed_vpc_endpoint_ids( + fn get_access_control( &self, - ) -> Result { - unimplemented!("not used in tests") - } - - fn get_block_public_or_vpc_access( - &self, - ) -> Result - { + ) -> Result { unimplemented!("not used in tests") } diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 4f27c6faef..0c79b5e92f 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -15,7 +15,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter; pub struct LeakyBucketRateLimiter { map: ClashMap, - config: utils::leaky_bucket::LeakyBucketConfig, + default_config: utils::leaky_bucket::LeakyBucketConfig, access_count: AtomicUsize, } @@ -28,15 +28,17 @@ impl LeakyBucketRateLimiter { pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self { Self { map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards), - config: config.into(), + default_config: config.into(), access_count: AtomicUsize::new(0), } } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub(crate) fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, config: Option, n: u32) -> bool { let now = Instant::now(); + let config = config.map_or(self.default_config, Into::into); + if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { self.do_gc(now); } @@ -46,7 +48,7 @@ impl LeakyBucketRateLimiter { .entry(key) .or_insert_with(|| LeakyBucketState { empty_at: now }); - entry.add_tokens(&self.config, now, n as f64).is_ok() + entry.add_tokens(&config, now, n as f64).is_ok() } fn do_gc(&self, now: Instant) { diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 21eaa6739b..9d700c1b52 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -15,6 +15,8 @@ use tracing::info; use crate::ext::LockExt; use crate::intern::EndpointIdInt; +use super::LeakyBucketConfig; + pub struct GlobalRateLimiter { data: Vec, info: Vec, @@ -144,19 +146,6 @@ impl RateBucketInfo { Self::new(50_000, Duration::from_secs(10)), ]; - /// All of these are per endpoint-maskedip pair. - /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). - /// - /// First bucket: 1000mcpus total per endpoint-ip pair - /// * 4096000 requests per second with 1 hash rounds. - /// * 1000 requests per second with 4096 hash rounds. - /// * 6.8 requests per second with 600000 hash rounds. - pub const DEFAULT_AUTH_SET: [Self; 3] = [ - Self::new(1000 * 4096, Duration::from_secs(1)), - Self::new(600 * 4096, Duration::from_secs(60)), - Self::new(300 * 4096, Duration::from_secs(600)), - ]; - pub fn rps(&self) -> f64 { (self.max_rpi as f64) / self.interval.as_secs_f64() } @@ -184,6 +173,21 @@ impl RateBucketInfo { max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32, } } + + pub fn to_leaky_bucket(this: &[Self]) -> Option { + // bit of a hack - find the min rps and max rps supported and turn it into + // leaky bucket config instead + + let mut iter = this.iter().map(|info| info.rps()); + let first = iter.next()?; + + let (min, max) = (first, first); + let (min, max) = iter.fold((min, max), |(min, max), rps| { + (f64::min(min, rps), f64::max(max, rps)) + }); + + Some(LeakyBucketConfig { rps: min, max }) + } } impl BucketRateLimiter { diff --git a/proxy/src/rate_limiter/mod.rs b/proxy/src/rate_limiter/mod.rs index 5f90102da3..112b95873a 100644 --- a/proxy/src/rate_limiter/mod.rs +++ b/proxy/src/rate_limiter/mod.rs @@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd; pub(crate) use limit_algorithm::{ DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, }; -pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; +pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 186fece4b2..6f56aeea06 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -1,10 +1,11 @@ use core::net::IpAddr; use std::sync::Arc; -use pq_proto::CancelKeyData; use tokio::sync::Mutex; use uuid::Uuid; +use crate::pqproto::CancelKeyData; + pub trait CancellationPublisherMut: Send + Sync + 'static { #[allow(async_fn_in_trait)] async fn try_publish( diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs index 7527bca6d0..3113bad949 100644 --- a/proxy/src/redis/keys.rs +++ b/proxy/src/redis/keys.rs @@ -1,16 +1,15 @@ use std::io::ErrorKind; use anyhow::Ok; -use pq_proto::{CancelKeyData, id_to_cancel_key}; -use serde::{Deserialize, Serialize}; + +use crate::pqproto::{CancelKeyData, id_to_cancel_key}; pub mod keyspace { pub const CANCEL_PREFIX: &str = "cancel"; } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) enum KeyPrefix { - #[serde(untagged)] Cancel(CancelKeyData), } @@ -18,9 +17,7 @@ impl KeyPrefix { pub(crate) fn build_redis_key(&self) -> String { match self { KeyPrefix::Cancel(key) => { - let hi = (key.backend_pid as u64) << 32; - let lo = (key.cancel_key as u64) & 0xffff_ffff; - let id = hi | lo; + let id = key.0.get(); let keyspace = keyspace::CANCEL_PREFIX; format!("{keyspace}:{id:x}") } @@ -63,10 +60,7 @@ mod tests { #[test] fn test_build_redis_key() { - let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }); + let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321)); let redis_key = cancel_key.build_redis_key(); assert_eq!(redis_key, "cancel:30390000d431"); @@ -77,10 +71,7 @@ mod tests { let redis_key = "cancel:30390000d431"; let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key"); - let ref_key = CancelKeyData { - backend_pid: 12345, - cancel_key: 54321, - }; + let ref_key = id_to_cancel_key(12345 << 32 | 54321); assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str()); let KeyPrefix::Cancel(cancel_key) = key; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 5f9f2509e2..a9d6b40603 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -2,11 +2,9 @@ use std::convert::Infallible; use std::sync::Arc; use futures::StreamExt; -use pq_proto::CancelKeyData; use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use tokio_util::sync::CancellationToken; -use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; @@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate { role_name: RoleNameInt, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct CancelSession { - pub(crate) region_id: Option, - pub(crate) cancel_key_data: CancelKeyData, - pub(crate) session_id: Uuid, - pub(crate) peer_addr: Option, -} - fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, @@ -243,29 +233,30 @@ impl MessageHandler { fn invalidate_cache(cache: Arc, msg: Notification) { match msg { - Notification::AllowedIpsUpdate { allowed_ips_update } => { - cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); + Notification::AllowedIpsUpdate { + allowed_ips_update: AllowedIpsUpdate { project_id }, } - Notification::BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated, - } => cache.invalidate_block_public_or_vpc_access_for_project( - block_public_or_vpc_access_updated.project_id, - ), + | Notification::BlockPublicOrVpcAccessUpdated { + block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id }, + } => cache.invalidate_endpoint_access_for_project(project_id), Notification::AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_org( - allowed_vpc_endpoints_updated_for_org.account_id, - ), + allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id }, + } => cache.invalidate_endpoint_access_for_org(account_id), Notification::AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects, - } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects( - allowed_vpc_endpoints_updated_for_projects.project_ids, - ), - Notification::PasswordUpdate { password_update } => cache - .invalidate_role_secret_for_project( - password_update.project_id, - password_update.role_name, - ), + allowed_vpc_endpoints_updated_for_projects: + AllowedVpcEndpointsUpdatedForProjects { project_ids }, + } => { + for project in project_ids { + cache.invalidate_endpoint_access_for_project(project); + } + } + Notification::PasswordUpdate { + password_update: + PasswordUpdate { + project_id, + role_name, + }, + } => cache.invalidate_role_secret_for_project(project_id, role_name), Notification::UnknownTopic => unreachable!(), } } diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index 7f2f3a761c..8d26a3f453 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -1,7 +1,5 @@ //! Definitions for SASL messages. -use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; - use crate::parse::split_cstr; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). @@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> { } } -/// A single SASL message. -/// This struct is deliberately decoupled from lower-level -/// [`BeAuthenticationSaslMessage`]. -#[derive(Debug)] -pub(super) enum ServerMessage { - /// We expect to see more steps. - Continue(T), - /// This is the final step. - Final(T), -} - -impl<'a> ServerMessage<&'a str> { - pub(super) fn to_reply(&self) -> BeMessage<'a> { - BeMessage::AuthenticationSasl(match self { - ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()), - ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()), - }) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/sasl/mod.rs b/proxy/src/sasl/mod.rs index f0181b404f..007b62dfd2 100644 --- a/proxy/src/sasl/mod.rs +++ b/proxy/src/sasl/mod.rs @@ -14,7 +14,7 @@ use std::io; pub(crate) use channel_binding::ChannelBinding; pub(crate) use messages::FirstMessage; -pub(crate) use stream::{Outcome, SaslStream}; +pub(crate) use stream::{Outcome, authenticate}; use thiserror::Error; use crate::error::{ReportableError, UserFacingError}; @@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] pub(crate) enum Error { + #[error("Unsupported authentication method: {0}")] + BadAuthMethod(Box), + #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), @@ -54,6 +57,7 @@ impl UserFacingError for Error { impl ReportableError for Error { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { + Error::BadAuthMethod(_) => crate::error::ErrorKind::User, Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, Error::BadClientMessage(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 46e6a439e5..52ccca58d5 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -3,61 +3,12 @@ use std::io; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::info; -use super::Mechanism; -use super::messages::ServerMessage; +use super::{Mechanism, Step}; +use crate::context::RequestContext; +use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::stream::PqStream; -/// Abstracts away all peculiarities of the libpq's protocol. -pub(crate) struct SaslStream<'a, S> { - /// The underlying stream. - stream: &'a mut PqStream, - /// Current password message we received from client. - current: bytes::Bytes, - /// First SASL message produced by client. - first: Option<&'a str>, -} - -impl<'a, S> SaslStream<'a, S> { - pub(crate) fn new(stream: &'a mut PqStream, first: &'a str) -> Self { - Self { - stream, - current: bytes::Bytes::new(), - first: Some(first), - } - } -} - -impl SaslStream<'_, S> { - // Receive a new SASL message from the client. - async fn recv(&mut self) -> io::Result<&str> { - if let Some(first) = self.first.take() { - return Ok(first); - } - - self.current = self.stream.read_password_message().await?; - let s = std::str::from_utf8(&self.current) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; - - Ok(s) - } -} - -impl SaslStream<'_, S> { - // Send a SASL message to the client. - async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message(&msg.to_reply()).await?; - Ok(()) - } - - // Queue a SASL message for the client. - fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message_noflush(&msg.to_reply())?; - Ok(()) - } -} - /// SASL authentication outcome. /// It's much easier to match on those two variants /// than to peek into a noisy protocol error type. @@ -69,33 +20,63 @@ pub(crate) enum Outcome { Failure(&'static str), } -impl SaslStream<'_, S> { - /// Perform SASL message exchange according to the underlying algorithm - /// until user is either authenticated or denied access. - pub(crate) async fn authenticate( - mut self, - mut mechanism: M, - ) -> super::Result> { - loop { - let input = self.recv().await?; - let step = mechanism.exchange(input).map_err(|error| { - info!(?error, "error during SASL exchange"); - error - })?; +pub async fn authenticate( + ctx: &RequestContext, + stream: &mut PqStream, + mechanism: F, +) -> super::Result> +where + S: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&str) -> super::Result, + M: Mechanism, +{ + let (mut mechanism, mut input) = { + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - use super::Step; - return Ok(match step { - Step::Continue(moved_mechanism, reply) => { - self.send(&ServerMessage::Continue(&reply)).await?; - mechanism = moved_mechanism; - continue; - } - Step::Success(result, reply) => { - self.send_noflush(&ServerMessage::Final(&reply))?; - Outcome::Success(result) - } - Step::Failure(reason) => Outcome::Failure(reason), - }); + // Initial client message contains the chosen auth method's name. + let msg = stream.read_password_message().await?; + + let sasl = super::FirstMessage::parse(msg) + .ok_or(super::Error::BadClientMessage("bad sasl message"))?; + + (mechanism(sasl.method)?, sasl.message) + }; + + loop { + match mechanism.exchange(input) { + Ok(Step::Continue(moved_mechanism, reply)) => { + mechanism = moved_mechanism; + + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + drop(reply); + } + Ok(Step::Success(result, reply)) => { + // write reply + let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); + stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); + stream.write_message(BeMessage::AuthenticationOk); + + // exit with success + break Ok(Outcome::Success(result)); + } + // exit with failure + Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)), + Err(error) => { + tracing::info!(?error, "error during SASL exchange"); + return Err(error); + } } + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 13058f08f1..bf640c05e9 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -22,7 +22,7 @@ use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; -use crate::auth::{self, AuthError, check_peer_addr_is_in_list}; +use crate::auth::{self, AuthError}; use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, @@ -35,7 +35,6 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::protocol2::ConnectionInfoExtra; use crate::proxy::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; @@ -63,63 +62,24 @@ impl PoolingBackend { let user_info = user_info.clone(); let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); - let allowed_ips = backend.get_allowed_ips(ctx).await?; + let access_control = backend.get_endpoint_access_control(ctx).await?; + access_control.check( + ctx, + self.config.authentication_config.ip_allowlist_check_enabled, + self.config.authentication_config.is_vpc_acccess_proxy, + )?; - if self.config.authentication_config.ip_allowlist_check_enabled - && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) - { - return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); - } - - let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?; - if self.config.authentication_config.is_vpc_acccess_proxy { - if access_blocker_flags.vpc_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - let extra = ctx.extra(); - let incoming_endpoint_id = match extra { - None => String::new(), - Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), - Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), - }; - - if incoming_endpoint_id.is_empty() { - return Err(AuthError::MissingVPCEndpointId); - } - - let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?; - // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that. - if !allowed_vpc_endpoint_ids.is_empty() - && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id) - { - return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id)); - } - } else if access_blocker_flags.public_access_blocked { - return Err(AuthError::NetworkNotAllowed); - } - - if !self - .endpoint_rate_limiter - .check(user_info.endpoint.clone().into(), 1) - { + let ep = EndpointIdInt::from(&user_info.endpoint); + let rate_limit_config = None; + if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) { return Err(AuthError::too_many_connections()); } - let cached_secret = backend.get_role_secret(ctx).await?; - let secret = match cached_secret.value.clone() { - Some(secret) => self.config.authentication_config.check_rate_limit( - ctx, - secret, - &user_info.endpoint, - true, - )?, - None => { - // If we don't have an authentication secret, for the http flow we can just return an error. - info!("authentication info not found"); - return Err(AuthError::password_failed(&*user_info.user)); - } + let role_access = backend.get_role_secret(ctx).await?; + let Some(secret) = role_access.secret else { + // If we don't have an authentication secret, for the http flow we can just return an error. + info!("authentication info not found"); + return Err(AuthError::password_failed(&*user_info.user)); }; - let ep = EndpointIdInt::from(&user_info.endpoint); let auth_outcome = crate::auth::validate_password_and_exchange( &self.config.authentication_config.thread_pool, ep, diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 1c5bb64480..eb80ac9ad0 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::{ GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, }; -use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; use serde_json::value::RawValue; @@ -41,6 +40,7 @@ use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{ReadBodyError, read_body_with_limit}; use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; +use crate::pqproto::StartupMessageParams; use crate::proxy::{NeonOptions, run_until_cancelled}; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; @@ -219,7 +219,7 @@ fn get_conn_info( let mut options = Option::None; - let mut params = StartupMessageParamsBuilder::default(); + let mut params = StartupMessageParams::default(); params.insert("user", &username); params.insert("database", &dbname); for (key, value) in pairs { diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 360550b0ac..c49a431c95 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -2,19 +2,17 @@ use std::pin::Pin; use std::sync::Arc; use std::{io, task}; -use bytes::BytesMut; -use pq_proto::framed::{ConnectionError, Framed}; -use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; use rustls::ServerConfig; -use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_rustls::server::TlsStream; -use tracing::debug; -use crate::control_plane::messages::ColdStartInfo; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; +use crate::pqproto::{ + BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf, + read_message, read_startup, +}; use crate::tls::TlsServerEndPoint; /// Stream wrapper which implements libpq's protocol. @@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint; /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying /// to pass random malformed bytes through the connection). pub struct PqStream { - pub(crate) framed: Framed, + stream: S, + read: Vec, + write: WriteBuf, } impl PqStream { - /// Construct a new libpq protocol wrapper. - pub fn new(stream: S) -> Self { + pub fn get_ref(&self) -> &S { + &self.stream + } + + /// Construct a new libpq protocol wrapper over a stream without the first startup message. + #[cfg(test)] + pub fn new_skip_handshake(stream: S) -> Self { Self { - framed: Framed::new(stream), + stream, + read: Vec::new(), + write: WriteBuf::new(), } } - - /// Extract the underlying stream and read buffer. - pub fn into_inner(self) -> (S, BytesMut) { - self.framed.into_inner() - } - - /// Get a shared reference to the underlying stream. - pub(crate) fn get_ref(&self) -> &S { - self.framed.get_ref() - } } -fn err_connection() -> io::Error { - io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") +impl PqStream { + /// Construct a new libpq protocol wrapper and read the first startup message. + /// + /// This is not cancel safe. + pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> { + let startup = read_startup(&mut stream).await?; + Ok(( + Self { + stream, + read: Vec::new(), + write: WriteBuf::new(), + }, + startup, + )) + } + + /// Tell the client that encryption is not supported. + /// + /// This is not cancel safe + pub async fn reject_encryption(&mut self) -> io::Result { + // N for No. + self.write.encryption(b'N'); + self.flush().await?; + read_startup(&mut self.stream).await + } } impl PqStream { - /// Receive [`FeStartupPacket`], which is a first packet sent by a client. - pub async fn read_startup_packet(&mut self) -> io::Result { - self.framed - .read_startup_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - async fn read_message(&mut self) -> io::Result { - self.framed - .read_message() - .await - .map_err(ConnectionError::into_io_error)? - .ok_or_else(err_connection) - } - - pub(crate) async fn read_password_message(&mut self) -> io::Result { - match self.read_message().await? { - FeMessage::PasswordMessage(msg) => Ok(msg), - bad => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("unexpected message type: {bad:?}"), - )), + /// Read a raw postgres packet, which will respect the max length requested. + /// This is not cancel safe. + async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> { + let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; + if actual_tag != tag { + return Err(io::Error::other(format!( + "incorrect message tag, expected {:?}, got {:?}", + tag as char, actual_tag as char, + ))); } + Ok(msg) + } + + /// Read a postgres password message, which will respect the max length requested. + /// This is not cancel safe. + pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> { + // passwords are usually pretty short + // and SASL SCRAM messages are no longer than 256 bytes in my testing + // (a few hashes and random bytes, encoded into base64). + const MAX_PASSWORD_LENGTH: u32 = 512; + self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) + .await } } @@ -84,6 +101,16 @@ pub struct ReportedError { error_kind: ErrorKind, } +impl ReportedError { + pub fn new(e: (impl UserFacingError + Into)) -> Self { + let error_kind = e.get_error_kind(); + Self { + source: e.into(), + error_kind, + } + } +} + impl std::fmt::Display for ReportedError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.source.fmt(f) @@ -102,109 +129,65 @@ impl ReportableError for ReportedError { } } -#[derive(Serialize, Deserialize, Debug)] -enum ErrorTag { - #[serde(rename = "proxy")] - Proxy, - #[serde(rename = "compute")] - Compute, - #[serde(rename = "client")] - Client, - #[serde(rename = "controlplane")] - ControlPlane, - #[serde(rename = "other")] - Other, -} - -impl From for ErrorTag { - fn from(error_kind: ErrorKind) -> Self { - match error_kind { - ErrorKind::User => Self::Client, - ErrorKind::ClientDisconnect => Self::Client, - ErrorKind::RateLimit => Self::Proxy, - ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI - ErrorKind::Quota => Self::Proxy, - ErrorKind::Service => Self::Proxy, - ErrorKind::ControlPlane => Self::ControlPlane, - ErrorKind::Postgres => Self::Other, - ErrorKind::Compute => Self::Compute, - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "snake_case")] -struct ProbeErrorData { - tag: ErrorTag, - msg: String, - cold_start_info: Option, -} - impl PqStream { - /// Write the message into an internal buffer, but don't flush the underlying stream. - pub(crate) fn write_message_noflush( - &mut self, - message: &BeMessage<'_>, - ) -> io::Result<&mut Self> { - self.framed - .write_message(message) - .map_err(ProtocolError::into_io_error)?; - Ok(self) + /// Tell the client that we are willing to accept SSL. + /// This is not cancel safe + pub async fn accept_tls(mut self) -> io::Result { + // S for SSL. + self.write.encryption(b'S'); + self.flush().await?; + Ok(self.stream) } - /// Write the message into an internal buffer and flush it. - pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { - self.write_message_noflush(message)?; - self.flush().await?; - Ok(self) + /// Assert that we are using direct TLS. + pub fn accept_direct_tls(self) -> S { + self.stream + } + + /// Write a raw message to the internal buffer. + pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec)) { + self.write.write_raw(size_hint, tag, f); + } + + /// Write the message into an internal buffer + pub fn write_message(&mut self, message: BeMessage<'_>) { + message.write_message(&mut self.write); } /// Flush the output buffer into the underlying stream. - pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> { - self.framed.flush().await?; - Ok(self) + /// + /// This is cancel safe. + pub async fn flush(&mut self) -> io::Result<()> { + self.stream.write_all_buf(&mut self.write).await?; + self.write.reset(); + + self.stream.flush().await?; + + Ok(()) } - /// Writes message with the given error kind to the stream. - /// Used only for probe queries - async fn write_format_message( - &mut self, - msg: &str, - error_kind: ErrorKind, - ctx: Option<&crate::context::RequestContext>, - ) -> String { - let formatted_msg = match ctx { - Some(ctx) if ctx.get_testodrome_id().is_some() => { - serde_json::to_string(&ProbeErrorData { - tag: ErrorTag::from(error_kind), - msg: msg.to_string(), - cold_start_info: Some(ctx.cold_start_info()), - }) - .unwrap_or_default() - } - _ => msg.to_string(), - }; - - // already error case, ignore client IO error - self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None)) - .await - .inspect_err(|e| debug!("write_message failed: {e}")) - .ok(); - - formatted_msg + /// Flush the output buffer into the underlying stream. + /// + /// This is cancel safe. + pub async fn flush_and_into_inner(mut self) -> io::Result { + self.flush().await?; + Ok(self.stream) } - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Allowing string literals is safe under the assumption they might not contain any runtime info. - /// This method exists due to `&str` not implementing `Into`. + /// Write the error message to the client, then re-throw it. + /// + /// Trait [`UserFacingError`] acts as an allowlist for error types. /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub async fn throw_error_str( + pub(crate) async fn throw_error( &mut self, - msg: &'static str, - error_kind: ErrorKind, + error: E, ctx: Option<&crate::context::RequestContext>, - ) -> Result { - self.write_format_message(msg, error_kind, ctx).await; + ) -> ReportedError + where + E: UserFacingError + Into, + { + let error_kind = error.get_error_kind(); + let msg = error.to_string_client(); if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { tracing::info!( @@ -214,39 +197,39 @@ impl PqStream { ); } - Err(ReportedError { - source: anyhow::anyhow!(msg), - error_kind, - }) - } - - /// Write the error message using [`Self::write_format_message`], then re-throw it. - /// Trait [`UserFacingError`] acts as an allowlist for error types. - /// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind. - pub(crate) async fn throw_error( - &mut self, - error: E, - ctx: Option<&crate::context::RequestContext>, - ) -> Result - where - E: UserFacingError + Into, - { - let error_kind = error.get_error_kind(); - let msg = error.to_string_client(); - self.write_format_message(&msg, error_kind, ctx).await; - if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User { - tracing::info!( - kind=error_kind.to_metric_label(), - error=%error, - msg, - "forwarding error to user", - ); + let probe_msg; + let mut msg = &*msg; + if let Some(ctx) = ctx { + if ctx.get_testodrome_id().is_some() { + let tag = match error_kind { + ErrorKind::User => "client", + ErrorKind::ClientDisconnect => "client", + ErrorKind::RateLimit => "proxy", + ErrorKind::ServiceRateLimit => "proxy", + ErrorKind::Quota => "proxy", + ErrorKind::Service => "proxy", + ErrorKind::ControlPlane => "controlplane", + ErrorKind::Postgres => "other", + ErrorKind::Compute => "compute", + }; + probe_msg = typed_json::json!({ + "tag": tag, + "msg": msg, + "cold_start_info": ctx.cold_start_info(), + }) + .to_string(); + msg = &probe_msg; + } } - Err(ReportedError { - source: anyhow::anyhow!(error), - error_kind, - }) + // TODO: either preserve the error code from postgres, or assign error codes to proxy errors. + self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR); + + self.flush() + .await + .unwrap_or_else(|e| tracing::debug!("write_message failed: {e}")); + + ReportedError::new(error) } } diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index f09e916a1d..013b307f0b 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -31,7 +31,9 @@ mod private { type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + Pin::new(&mut self.inner) + .poll(cx) + .map_ok(|s| RustlsStream(Box::new(s))) } } @@ -57,7 +59,7 @@ mod private { } } - pub struct RustlsStream(TlsStream); + pub struct RustlsStream(Box>); impl postgres_client::tls::TlsStream for RustlsStream where diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index af33bcbd20..e3f7d88f7c 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -44,6 +44,7 @@ struct GlobalTimelinesState { // on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as // this map is dropped on restart. tombstones: HashMap, + tenant_tombstones: HashMap, conf: Arc, broker_active_set: Arc, @@ -81,10 +82,25 @@ impl GlobalTimelinesState { } } + fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool { + self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id) + } + + /// Removes all blocking tombstones for the given timeline ID. + /// Returns `true` if there have been actual changes. + fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool { + self.tombstones.remove(ttid).is_some() + || self.tenant_tombstones.remove(&ttid.tenant_id).is_some() + } + fn delete(&mut self, ttid: TenantTimelineId) { self.timelines.remove(&ttid); self.tombstones.insert(ttid, Instant::now()); } + + fn add_tenant_tombstone(&mut self, tenant_id: TenantId) { + self.tenant_tombstones.insert(tenant_id, Instant::now()); + } } /// A struct used to manage access to the global timelines map. @@ -99,6 +115,7 @@ impl GlobalTimelines { state: Mutex::new(GlobalTimelinesState { timelines: HashMap::new(), tombstones: HashMap::new(), + tenant_tombstones: HashMap::new(), conf, broker_active_set: Arc::new(TimelinesSet::default()), global_rate_limiter: RateLimiter::new(1, 1), @@ -245,7 +262,7 @@ impl GlobalTimelines { return Ok(timeline); } - if state.tombstones.contains_key(&ttid) { + if state.has_tombstone(&ttid) { anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate"); } @@ -295,13 +312,14 @@ impl GlobalTimelines { _ => {} } if check_tombstone { - if state.tombstones.contains_key(&ttid) { + if state.has_tombstone(&ttid) { anyhow::bail!("timeline {ttid} is deleted, refusing to recreate"); } } else { // We may be have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust // that the human doing this manual intervention knows what they are doing, and remove its tombstone. - if state.tombstones.remove(&ttid).is_some() { + // It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed. + if state.remove_tombstone(&ttid) { warn!("un-deleted timeline {ttid}"); } } @@ -482,6 +500,7 @@ impl GlobalTimelines { let tli_res = { let state = self.state.lock().unwrap(); + // Do NOT check tenant tombstones here: those were set earlier if state.tombstones.contains_key(ttid) { // Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do. info!("Timeline {ttid} was already deleted"); @@ -557,6 +576,10 @@ impl GlobalTimelines { action: DeleteOrExclude, ) -> Result> { info!("deleting all timelines for tenant {}", tenant_id); + + // Adding a tombstone before getting the timelines to prevent new timeline additions + self.state.lock().unwrap().add_tenant_tombstone(*tenant_id); + let to_delete = self.get_all_for_tenant(*tenant_id); let mut err = None; @@ -600,6 +623,9 @@ impl GlobalTimelines { state .tombstones .retain(|_, v| now.duration_since(*v) < *tombstone_ttl); + state + .tenant_tombstones + .retain(|_, v| now.duration_since(*v) < *tombstone_ttl); } } diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs index f314143952..70fecfbe22 100644 --- a/safekeeper/tests/walproposer_sim/simulation.rs +++ b/safekeeper/tests/walproposer_sim/simulation.rs @@ -87,6 +87,7 @@ impl WalProposer { let config = Config { ttid, safekeepers_list: addrs, + safekeeper_conninfo_options: String::new(), safekeeper_reconnect_timeout: 1000, safekeeper_connection_timeout: 5000, sync_safekeepers, diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 02c02c0e7f..2b1c0db12f 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -482,6 +482,10 @@ async fn handle_tenant_timeline_delete( ForwardOutcome::NotForwarded(_req) => {} }; + service + .maybe_delete_timeline_import(tenant_id, timeline_id) + .await?; + // For timeline deletions, which both implement an "initially return 202, then 404 once // we're done" semantic, we wrap with a retry loop to expose a simpler API upstream. async fn deletion_wrapper(service: Arc, f: F) -> Result, ApiError> diff --git a/storage_controller/src/metrics.rs b/storage_controller/src/metrics.rs index 5ce2fb65e4..ccdbcad139 100644 --- a/storage_controller/src/metrics.rs +++ b/storage_controller/src/metrics.rs @@ -139,6 +139,14 @@ pub(crate) struct StorageControllerMetricGroup { /// HTTP request status counters for handled requests pub(crate) storage_controller_reconcile_long_running: measured::CounterVec, + + /// Indicator of safekeeper reconciler queue depth, broken down by safekeeper, excluding ongoing reconciles. + pub(crate) storage_controller_safkeeper_reconciles_queued: + measured::GaugeVec, + + /// Indicator of completed safekeeper reconciles, broken down by safekeeper. + pub(crate) storage_controller_safkeeper_reconciles_complete: + measured::CounterVec, } impl StorageControllerMetrics { @@ -257,6 +265,17 @@ pub(crate) enum Method { Other, } +#[derive(measured::LabelGroup, Clone)] +#[label(set = SafekeeperReconcilerLabelGroupSet)] +pub(crate) struct SafekeeperReconcilerLabelGroup<'a> { + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) sk_az: &'a str, + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) sk_node_id: &'a str, + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) sk_hostname: &'a str, +} + impl From for Method { fn from(value: hyper::Method) -> Self { if value == hyper::Method::GET { diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 7e4bb627af..790797bae2 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -99,8 +99,8 @@ use crate::tenant_shard::{ ScheduleOptimization, ScheduleOptimizationAction, TenantShard, }; use crate::timeline_import::{ - ImportResult, ShardImportStatuses, TimelineImport, TimelineImportFinalizeError, - TimelineImportState, UpcallClient, + FinalizingImport, ImportResult, ShardImportStatuses, TimelineImport, + TimelineImportFinalizeError, TimelineImportState, UpcallClient, }; const WAITER_FILL_DRAIN_POLL_TIMEOUT: Duration = Duration::from_millis(500); @@ -232,6 +232,9 @@ struct ServiceState { /// Queue of tenants who are waiting for concurrency limits to permit them to reconcile delayed_reconcile_rx: tokio::sync::mpsc::Receiver, + + /// Tracks ongoing timeline import finalization tasks + imports_finalizing: BTreeMap<(TenantId, TimelineId), FinalizingImport>, } /// Transform an error from a pageserver into an error to return to callers of a storage @@ -308,6 +311,7 @@ impl ServiceState { scheduler, ongoing_operation: None, delayed_reconcile_rx, + imports_finalizing: Default::default(), } } @@ -3823,6 +3827,13 @@ impl Service { .await; failpoint_support::sleep_millis_async!("tenant-create-timeline-shared-lock"); let is_import = create_req.is_import(); + let read_only = matches!( + create_req.mode, + models::TimelineCreateRequestMode::Branch { + read_only: true, + .. + } + ); if is_import { // Ensure that there is no split on-going. @@ -3895,13 +3906,13 @@ impl Service { } None - } else if safekeepers { + } else if safekeepers || read_only { // Note that for imported timelines, we do not create the timeline on the safekeepers // straight away. Instead, we do it once the import finalized such that we know what // start LSN to provide for the safekeepers. This is done in // [`Self::finalize_timeline_import`]. let res = self - .tenant_timeline_create_safekeepers(tenant_id, &timeline_info) + .tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only) .instrument(tracing::info_span!("timeline_create_safekeepers", %tenant_id, timeline_id=%timeline_info.timeline_id)) .await?; Some(res) @@ -3915,6 +3926,11 @@ impl Service { }) } + #[instrument(skip_all, fields( + tenant_id=%req.tenant_shard_id.tenant_id, + shard_id=%req.tenant_shard_id.shard_slug(), + timeline_id=%req.timeline_id, + ))] pub(crate) async fn handle_timeline_shard_import_progress( self: &Arc, req: TimelineImportStatusRequest, @@ -3964,6 +3980,11 @@ impl Service { }) } + #[instrument(skip_all, fields( + tenant_id=%req.tenant_shard_id.tenant_id, + shard_id=%req.tenant_shard_id.shard_slug(), + timeline_id=%req.timeline_id, + ))] pub(crate) async fn handle_timeline_shard_import_progress_upcall( self: &Arc, req: PutTimelineImportStatusRequest, @@ -4080,13 +4101,58 @@ impl Service { /// /// If this method gets pre-empted by shut down, it will be called again at start-up (on-going /// imports are stored in the database). + /// + /// # Cancel-Safety + /// Not cancel safe. + /// If the caller stops polling, the import will not be removed from + /// [`ServiceState::imports_finalizing`]. #[instrument(skip_all, fields( tenant_id=%import.tenant_id, timeline_id=%import.timeline_id, ))] + async fn finalize_timeline_import( self: &Arc, import: TimelineImport, + ) -> Result<(), TimelineImportFinalizeError> { + let tenant_timeline = (import.tenant_id, import.timeline_id); + + let (_finalize_import_guard, cancel) = { + let mut locked = self.inner.write().unwrap(); + let gate = Gate::default(); + let cancel = CancellationToken::default(); + + let guard = gate.enter().unwrap(); + + locked.imports_finalizing.insert( + tenant_timeline, + FinalizingImport { + gate, + cancel: cancel.clone(), + }, + ); + + (guard, cancel) + }; + + let res = tokio::select! { + res = self.finalize_timeline_import_impl(import) => { + res + }, + _ = cancel.cancelled() => { + Err(TimelineImportFinalizeError::Cancelled) + } + }; + + let mut locked = self.inner.write().unwrap(); + locked.imports_finalizing.remove(&tenant_timeline); + + res + } + + async fn finalize_timeline_import_impl( + self: &Arc, + import: TimelineImport, ) -> Result<(), TimelineImportFinalizeError> { tracing::info!("Finalizing timeline import"); @@ -4286,6 +4352,46 @@ impl Service { .await; } + /// Delete a timeline import if it exists + /// + /// Firstly, delete the entry from the database. Any updates + /// from pageservers after the update will fail with a 404, so the + /// import cannot progress into finalizing state if it's not there already. + /// Secondly, cancel the finalization if one is in progress. + pub(crate) async fn maybe_delete_timeline_import( + self: &Arc, + tenant_id: TenantId, + timeline_id: TimelineId, + ) -> Result<(), DatabaseError> { + let tenant_has_ongoing_import = { + let locked = self.inner.read().unwrap(); + locked + .tenants + .range(TenantShardId::tenant_range(tenant_id)) + .any(|(_tid, shard)| shard.importing == TimelineImportState::Importing) + }; + + if !tenant_has_ongoing_import { + return Ok(()); + } + + self.persistence + .delete_timeline_import(tenant_id, timeline_id) + .await?; + + let maybe_finalizing = { + let mut locked = self.inner.write().unwrap(); + locked.imports_finalizing.remove(&(tenant_id, timeline_id)) + }; + + if let Some(finalizing) = maybe_finalizing { + finalizing.cancel.cancel(); + finalizing.gate.close().await; + } + + Ok(()) + } + pub(crate) async fn tenant_timeline_archival_config( &self, tenant_id: TenantId, @@ -8521,8 +8627,9 @@ impl Service { Some(ShardCount(new_shard_count)) } - /// Fetches the top tenant shards from every node, in descending order of - /// max logical size. Any node errors will be logged and ignored. + /// Fetches the top tenant shards from every available node, in descending order of + /// max logical size. Offline nodes are skipped, and any errors from available nodes + /// will be logged and ignored. async fn get_top_tenant_shards( &self, request: &TopTenantShardsRequest, @@ -8533,6 +8640,7 @@ impl Service { .unwrap() .nodes .values() + .filter(|node| node.is_available()) .cloned() .collect_vec(); diff --git a/storage_controller/src/service/safekeeper_reconciler.rs b/storage_controller/src/service/safekeeper_reconciler.rs index f756d98c64..fbf0b5c4e3 100644 --- a/storage_controller/src/service/safekeeper_reconciler.rs +++ b/storage_controller/src/service/safekeeper_reconciler.rs @@ -20,7 +20,9 @@ use utils::{ }; use crate::{ - persistence::SafekeeperTimelineOpKind, safekeeper::Safekeeper, + metrics::{METRICS_REGISTRY, SafekeeperReconcilerLabelGroup}, + persistence::SafekeeperTimelineOpKind, + safekeeper::Safekeeper, safekeeper_client::SafekeeperClient, }; @@ -218,7 +220,26 @@ impl ReconcilerHandle { fn schedule_reconcile(&self, req: ScheduleRequest) { let (cancel, token_id) = self.new_token_slot(req.tenant_id, req.timeline_id); let hostname = req.safekeeper.skp.host.clone(); + let sk_az = req.safekeeper.skp.availability_zone_id.clone(); + let sk_node_id = req.safekeeper.get_id().to_string(); + + // We don't have direct access to the queue depth here, so increase it blindly by 1. + // We know that putting into the queue increases the queue depth. The receiver will + // update with the correct value once it processes the next item. To avoid races where we + // reduce before we increase, leaving the gauge with a 1 value for a long time, we + // increase it before putting into the queue. + let queued_gauge = &METRICS_REGISTRY + .metrics_group + .storage_controller_safkeeper_reconciles_queued; + let label_group = SafekeeperReconcilerLabelGroup { + sk_az: &sk_az, + sk_node_id: &sk_node_id, + sk_hostname: &hostname, + }; + queued_gauge.inc(label_group.clone()); + if let Err(err) = self.tx.send((req, cancel, token_id)) { + queued_gauge.set(label_group, 0); tracing::info!("scheduling request onto {hostname} returned error: {err}"); } } @@ -283,6 +304,18 @@ impl SafekeeperReconciler { continue; } + let queued_gauge = &METRICS_REGISTRY + .metrics_group + .storage_controller_safkeeper_reconciles_queued; + queued_gauge.set( + SafekeeperReconcilerLabelGroup { + sk_az: &req.safekeeper.skp.availability_zone_id, + sk_node_id: &req.safekeeper.get_id().to_string(), + sk_hostname: &req.safekeeper.skp.host, + }, + self.rx.len() as i64, + ); + tokio::task::spawn(async move { let kind = req.kind; let tenant_id = req.tenant_id; @@ -511,6 +544,16 @@ impl SafekeeperReconcilerInner { req.generation, ) .await; + + let complete_counter = &METRICS_REGISTRY + .metrics_group + .storage_controller_safkeeper_reconciles_complete; + complete_counter.inc(SafekeeperReconcilerLabelGroup { + sk_az: &req.safekeeper.skp.availability_zone_id, + sk_node_id: &req.safekeeper.get_id().to_string(), + sk_hostname: &req.safekeeper.skp.host, + }); + if let Err(err) = res { tracing::info!( "couldn't remove reconciliation request onto {} from persistence: {err:?}", diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs index cd5ace449d..1f673fe445 100644 --- a/storage_controller/src/service/safekeeper_service.rs +++ b/storage_controller/src/service/safekeeper_service.rs @@ -208,6 +208,7 @@ impl Service { self: &Arc, tenant_id: TenantId, timeline_info: &TimelineInfo, + read_only: bool, ) -> Result { let timeline_id = timeline_info.timeline_id; let pg_version = timeline_info.pg_version * 10000; @@ -220,7 +221,11 @@ impl Service { let start_lsn = timeline_info.last_record_lsn; // Choose initial set of safekeepers respecting affinity - let sks = self.safekeepers_for_new_timeline().await?; + let sks = if !read_only { + self.safekeepers_for_new_timeline().await? + } else { + Vec::new() + }; let sks_persistence = sks.iter().map(|sk| sk.id.0 as i64).collect::>(); // Add timeline to db let mut timeline_persist = TimelinePersistence { @@ -253,6 +258,16 @@ impl Service { ))); } } + let ret = SafekeepersInfo { + generation: timeline_persist.generation as u32, + safekeepers: sks.clone(), + tenant_id, + timeline_id, + }; + if read_only { + return Ok(ret); + } + // Create the timeline on a quorum of safekeepers let remaining = self .tenant_timeline_create_safekeepers_quorum( @@ -316,12 +331,7 @@ impl Service { } } - Ok(SafekeepersInfo { - generation: timeline_persist.generation as u32, - safekeepers: sks, - tenant_id, - timeline_id, - }) + Ok(ret) } pub(crate) async fn tenant_timeline_create_safekeepers_until_success( @@ -336,8 +346,10 @@ impl Service { return Err(TimelineImportFinalizeError::ShuttingDown); } + // This function is only used in non-read-only scenarios + let read_only = false; let res = self - .tenant_timeline_create_safekeepers(tenant_id, &timeline_info) + .tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only) .await; match res { @@ -410,6 +422,18 @@ impl Service { .chain(tl.sk_set.iter()) .collect::>(); + // The timeline has no safekeepers: we need to delete it from the db manually, + // as no safekeeper reconciler will get to it + if all_sks.is_empty() { + if let Err(err) = self + .persistence + .delete_timeline(tenant_id, timeline_id) + .await + { + tracing::warn!(%tenant_id, %timeline_id, "couldn't delete timeline from db: {err}"); + } + } + // Schedule reconciliations for &sk_id in all_sks.iter() { let pending_op = TimelinePendingOpPersistence { diff --git a/storage_controller/src/timeline_import.rs b/storage_controller/src/timeline_import.rs index 909e8e2899..eb50819d02 100644 --- a/storage_controller/src/timeline_import.rs +++ b/storage_controller/src/timeline_import.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; use pageserver_api::models::{ShardImportProgress, ShardImportStatus}; use tokio_util::sync::CancellationToken; +use utils::sync::gate::Gate; use utils::{ id::{TenantId, TimelineId}, shard::ShardIndex, @@ -55,6 +56,8 @@ pub(crate) enum TimelineImportUpdateFollowUp { pub(crate) enum TimelineImportFinalizeError { #[error("Shut down interrupted import finalize")] ShuttingDown, + #[error("Import finalization was cancelled")] + Cancelled, #[error("Mismatched shard detected during import finalize: {0}")] MismatchedShards(ShardIndex), } @@ -164,6 +167,11 @@ impl TimelineImport { } } +pub(crate) struct FinalizingImport { + pub(crate) gate: Gate, + pub(crate) cancel: CancellationToken, +} + pub(crate) type ImportResult = Result<(), String>; pub(crate) struct UpcallClient { diff --git a/test_runner/fixtures/fast_import.py b/test_runner/fixtures/fast_import.py index f9e5f9c1db..bd6dc2583b 100644 --- a/test_runner/fixtures/fast_import.py +++ b/test_runner/fixtures/fast_import.py @@ -1,3 +1,4 @@ +import json import os import shutil import subprocess @@ -11,6 +12,7 @@ from _pytest.config import Config from fixtures.log_helper import log from fixtures.neon_cli import AbstractNeonCli +from fixtures.neon_fixtures import Endpoint, VanillaPostgres from fixtures.pg_version import PgVersion from fixtures.remote_storage import MockS3Server @@ -161,3 +163,57 @@ def fast_import( f.write(fi.cmd.stderr) log.info("Written logs to %s", test_output_dir) + + +def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path): + """ + Mock the import S3 bucket into a local directory for a provided vanilla PG instance. + """ + assert not vanilla_pg.is_running() + + path.mkdir() + # what cplane writes before scheduling fast_import + specpath = path / "spec.json" + specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"})) + # what fast_import writes + vanilla_pg.pgdatadir.rename(path / "pgdata") + statusdir = path / "status" + statusdir.mkdir() + (statusdir / "pgdata").write_text(json.dumps({"done": True})) + (statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True})) + + +def populate_vanilla_pg(vanilla_pg: VanillaPostgres, target_relblock_size: int) -> int: + assert vanilla_pg.is_running() + + vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") + # fillfactor so we don't need to produce that much data + # 900 byte per row is > 10% => 1 row per page + vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""") + + nrows = 0 + while True: + relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") + log.info( + f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages" + ) + if relblock_size >= target_relblock_size: + break + addrows = int((target_relblock_size - relblock_size) // 8192) + assert addrows >= 1, "forward progress" + vanilla_pg.safe_psql( + f"insert into t select generate_series({nrows + 1}, {nrows + addrows})" + ) + nrows += addrows + + return nrows + + +def validate_import_from_vanilla_pg(endpoint: Endpoint, nrows: int): + assert endpoint.safe_psql_many( + [ + "set effective_io_concurrency=32;", + "SET statement_timeout='300s';", + "select count(*), sum(data::bigint)::bigint from t", + ] + ) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]] diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 5c92f2e2d0..ab4885ce6b 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -404,6 +404,29 @@ class PageserverTracingConfig: return ("tracing", value) +@dataclass +class PageserverImportConfig: + import_job_concurrency: int + import_job_soft_size_limit: int + import_job_checkpoint_threshold: int + + @staticmethod + def default() -> PageserverImportConfig: + return PageserverImportConfig( + import_job_concurrency=4, + import_job_soft_size_limit=512 * 1024, + import_job_checkpoint_threshold=4, + ) + + def to_config_key_value(self) -> tuple[str, dict[str, Any]]: + value = { + "import_job_concurrency": self.import_job_concurrency, + "import_job_soft_size_limit": self.import_job_soft_size_limit, + "import_job_checkpoint_threshold": self.import_job_checkpoint_threshold, + } + return ("timeline_import_config", value) + + class NeonEnvBuilder: """ Builder object to create a Neon runtime environment @@ -454,6 +477,7 @@ class NeonEnvBuilder: pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None, pageserver_get_vectored_concurrent_io: str | None = None, pageserver_tracing_config: PageserverTracingConfig | None = None, + pageserver_import_config: PageserverImportConfig | None = None, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -511,6 +535,7 @@ class NeonEnvBuilder: ) self.pageserver_tracing_config = pageserver_tracing_config + self.pageserver_import_config = pageserver_import_config self.pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None = ( pageserver_default_tenant_config_compaction_algorithm @@ -1179,6 +1204,10 @@ class NeonEnv: self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol self.pageserver_get_vectored_concurrent_io = config.pageserver_get_vectored_concurrent_io self.pageserver_tracing_config = config.pageserver_tracing_config + if config.pageserver_import_config is None: + self.pageserver_import_config = PageserverImportConfig.default() + else: + self.pageserver_import_config = config.pageserver_import_config # Create the neon_local's `NeonLocalInitConf` cfg: dict[str, Any] = { @@ -1224,6 +1253,7 @@ class NeonEnv: # Create config for pageserver http_auth_type = "NeonJWT" if config.auth_enabled else "Trust" pg_auth_type = "NeonJWT" if config.auth_enabled else "Trust" + grpc_auth_type = "NeonJWT" if config.auth_enabled else "Trust" for ps_id in range( self.BASE_PAGESERVER_ID, self.BASE_PAGESERVER_ID + config.num_pageservers ): @@ -1250,18 +1280,13 @@ class NeonEnv: else None, "pg_auth_type": pg_auth_type, "http_auth_type": http_auth_type, + "grpc_auth_type": grpc_auth_type, "availability_zone": availability_zone, # Disable pageserver disk syncs in tests: when running tests concurrently, this avoids # the pageserver taking a long time to start up due to syncfs flushing other tests' data "no_sync": True, # Look for gaps in WAL received from safekeepeers "validate_wal_contiguity": True, - # TODO(vlad): make these configurable through the builder - "timeline_import_config": { - "import_job_concurrency": 4, - "import_job_soft_size_limit": 512 * 1024, - "import_job_checkpoint_threshold": 4, - }, } # Batching (https://github.com/neondatabase/neon/issues/9377): @@ -1323,6 +1348,12 @@ class NeonEnv: ps_cfg[key] = value + if self.pageserver_import_config is not None: + key, value = self.pageserver_import_config.to_config_key_value() + + if key not in ps_cfg: + ps_cfg[key] = value + # Create a corresponding NeonPageserver object ps = NeonPageserver( self, ps_id, port=pageserver_port, az_id=ps_cfg["availability_zone"] @@ -2306,6 +2337,22 @@ class NeonStorageController(MetricsGetter, LogUtils): headers=self.headers(TokenScope.ADMIN), ) + def import_status( + self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, generation: int + ): + payload = { + "tenant_shard_id": str(tenant_shard_id), + "timeline_id": str(timeline_id), + "generation": generation, + } + + self.request( + "GET", + f"{self.api}/upcall/v1/timeline_import_status", + headers=self.headers(TokenScope.GENERATIONS_API), + json=payload, + ) + def reconcile_all(self): r = self.request( "POST", @@ -2782,6 +2829,11 @@ class NeonPageserver(PgProtocol, LogUtils): if self.running: self.http_client().configure_failpoints([(name, action)]) + def clear_persistent_failpoint(self, name: str): + del self._persistent_failpoints[name] + if self.running: + self.http_client().configure_failpoints([(name, "off")]) + def timeline_dir( self, tenant_shard_id: TenantId | TenantShardId, diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index c2d176bf5a..c29192c25c 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -675,7 +675,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_delete( self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, **kwargs - ): + ) -> int: """ Note that deletion is not instant, it is scheduled and performed mostly in the background. So if you need to wait for it to complete use `timeline_delete_wait_completed`. @@ -688,6 +688,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter): res_json = res.json() assert res_json is None + return res.status_code + def timeline_gc( self, tenant_id: TenantId | TenantShardId, diff --git a/test_runner/regress/test_disk_usage_eviction.py b/test_runner/regress/test_disk_usage_eviction.py index b29610e021..1420dc59a1 100644 --- a/test_runner/regress/test_disk_usage_eviction.py +++ b/test_runner/regress/test_disk_usage_eviction.py @@ -1,31 +1,41 @@ from __future__ import annotations import enum +import json import time from collections import Counter from dataclasses import dataclass from enum import StrEnum +from threading import Event from typing import TYPE_CHECKING import pytest from fixtures.common_types import Lsn, TenantId, TimelineId +from fixtures.fast_import import mock_import_bucket, populate_vanilla_pg from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnv, NeonEnvBuilder, NeonPageserver, PgBin, + VanillaPostgres, wait_for_last_flush_lsn, ) +from fixtures.pageserver.http import ( + ImportPgdataIdemptencyKey, +) from fixtures.pageserver.utils import wait_for_upload_queue_empty from fixtures.remote_storage import RemoteStorageKind -from fixtures.utils import human_bytes, wait_until +from fixtures.utils import human_bytes, run_only_on_default_postgres, wait_until +from werkzeug.wrappers.response import Response if TYPE_CHECKING: from collections.abc import Iterable from typing import Any from fixtures.pageserver.http import PageserverHttpClient + from pytest_httpserver import HTTPServer + from werkzeug.wrappers.request import Request GLOBAL_LRU_LOG_LINE = "tenant_min_resident_size-respecting LRU would not relieve pressure, evicting more following global LRU policy" @@ -164,6 +174,7 @@ class EvictionEnv: min_avail_bytes, mock_behavior, eviction_order: EvictionOrder, + wait_logical_size: bool = True, ): """ Starts pageserver up with mocked statvfs setup. The startup is @@ -201,11 +212,12 @@ class EvictionEnv: pageserver.start() # we now do initial logical size calculation on startup, which on debug builds can fight with disk usage based eviction - for tenant_id, timeline_id in self.timelines: - tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id) - # Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test - if tenant_ps is not None: - tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id) + if wait_logical_size: + for tenant_id, timeline_id in self.timelines: + tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id) + # Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test + if tenant_ps is not None: + tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id) def statvfs_called(): pageserver.assert_log_contains(".*running mocked statvfs.*") @@ -882,3 +894,121 @@ def test_secondary_mode_eviction(eviction_env_ha: EvictionEnv): assert total_size - post_eviction_total_size >= evict_bytes, ( "we requested at least evict_bytes worth of free space" ) + + +@run_only_on_default_postgres(reason="PG version is irrelevant here") +def test_import_timeline_disk_pressure_eviction( + neon_env_builder: NeonEnvBuilder, + vanilla_pg: VanillaPostgres, + make_httpserver: HTTPServer, + pg_bin: PgBin, +): + """ + TODO + """ + # Set up mock control plane HTTP server to listen for import completions + import_completion_signaled = Event() + + def handler(request: Request) -> Response: + log.info(f"control plane /import_complete request: {request.json}") + import_completion_signaled.set() + return Response(json.dumps({}), status=200) + + cplane_mgmt_api_server = make_httpserver + cplane_mgmt_api_server.expect_request( + "/storage/api/v1/import_complete", method="PUT" + ).respond_with_handler(handler) + + # Plug the cplane mock in + neon_env_builder.control_plane_hooks_api = ( + f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/" + ) + + # The import will specifiy a local filesystem path mocking remote storage + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + + vanilla_pg.start() + target_relblock_size = 1024 * 1024 * 128 + populate_vanilla_pg(vanilla_pg, target_relblock_size) + vanilla_pg.stop() + + env = neon_env_builder.init_configs() + env.start() + + importbucket_path = neon_env_builder.repo_dir / "test_import_completion_bucket" + mock_import_bucket(vanilla_pg, importbucket_path) + + tenant_id = TenantId.generate() + timeline_id = TimelineId.generate() + idempotency = ImportPgdataIdemptencyKey.random() + + eviction_env = EvictionEnv( + timelines=[(tenant_id, timeline_id)], + neon_env=env, + pageserver_http=env.pageserver.http_client(), + layer_size=5 * 1024 * 1024, # Doesn't apply here + pg_bin=pg_bin, # Not used here + pgbench_init_lsns={}, # Not used here + ) + + # Pause before delivering the final notification to storcon. + # This keeps the import in progress. + failpoint_name = "import-timeline-pre-success-notify-pausable" + env.pageserver.add_persistent_failpoint(failpoint_name, "pause") + + env.storage_controller.tenant_create(tenant_id) + env.storage_controller.timeline_create( + tenant_id, + { + "new_timeline_id": str(timeline_id), + "import_pgdata": { + "idempotency_key": str(idempotency), + "location": {"LocalFs": {"path": str(importbucket_path.absolute())}}, + }, + }, + ) + + def hit_failpoint(): + log.info("Checking log for pattern...") + try: + assert env.pageserver.log_contains(f".*at failpoint {failpoint_name}.*") + except Exception: + log.exception("Failed to find pattern in log") + raise + + wait_until(hit_failpoint) + assert not import_completion_signaled.is_set() + + env.pageserver.stop() + + total_size, _, _ = eviction_env.timelines_du(env.pageserver) + blocksize = 512 + total_blocks = (total_size + (blocksize - 1)) // blocksize + + eviction_env.pageserver_start_with_disk_usage_eviction( + env.pageserver, + period="1s", + max_usage_pct=33, + min_avail_bytes=0, + mock_behavior={ + "type": "Success", + "blocksize": blocksize, + "total_blocks": total_blocks, + # Only count layer files towards used bytes in the mock_statvfs. + # This avoids accounting for metadata files & tenant conf in the tests. + "name_filter": ".*__.*", + }, + eviction_order=EvictionOrder.RELATIVE_ORDER_SPARE, + wait_logical_size=False, + ) + + wait_until(lambda: env.pageserver.assert_log_contains(".*disk usage pressure relieved")) + + env.pageserver.clear_persistent_failpoint(failpoint_name) + + def cplane_notified(): + assert import_completion_signaled.is_set() + + wait_until(cplane_notified) + + env.pageserver.allowed_errors.append(r".* running disk usage based eviction due to pressure.*") diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 0472b92145..8d4f908cc0 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -1,7 +1,10 @@ import base64 +import concurrent.futures import json +import random +import threading import time -from enum import Enum +from enum import Enum, StrEnum from pathlib import Path from threading import Event @@ -9,9 +12,22 @@ import psycopg2 import psycopg2.errors import pytest from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId -from fixtures.fast_import import FastImport +from fixtures.fast_import import ( + FastImport, + mock_import_bucket, + populate_vanilla_pg, + validate_import_from_vanilla_pg, +) from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PgProtocol, VanillaPostgres +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + PageserverImportConfig, + PgBin, + PgProtocol, + StorageControllerApiException, + StorageControllerMigrationConfig, + VanillaPostgres, +) from fixtures.pageserver.http import ( ImportPgdataIdemptencyKey, ) @@ -49,24 +65,6 @@ smoke_params = [ ] -def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path): - """ - Mock the import S3 bucket into a local directory for a provided vanilla PG instance. - """ - assert not vanilla_pg.is_running() - - path.mkdir() - # what cplane writes before scheduling fast_import - specpath = path / "spec.json" - specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"})) - # what fast_import writes - vanilla_pg.pgdatadir.rename(path / "pgdata") - statusdir = path / "status" - statusdir.mkdir() - (statusdir / "pgdata").write_text(json.dumps({"done": True})) - (statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True})) - - @skip_in_debug_build("MULTIPLE_RELATION_SEGMENTS has non trivial amount of data") @pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params) def test_pgdata_import_smoke( @@ -121,10 +119,6 @@ def test_pgdata_import_smoke( # Put data in vanilla pg # - vanilla_pg.start() - vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") - - log.info("create relblock data") if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE: target_relblock_size = stripe_size * 8192 elif rel_block_size == RelBlockSize.TWO_STRPES_PER_SHARD: @@ -135,45 +129,8 @@ def test_pgdata_import_smoke( else: raise ValueError - # fillfactor so we don't need to produce that much data - # 900 byte per row is > 10% => 1 row per page - vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""") - - nrows = 0 - while True: - relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") - log.info( - f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages" - ) - if relblock_size >= target_relblock_size: - break - addrows = int((target_relblock_size - relblock_size) // 8192) - assert addrows >= 1, "forward progress" - vanilla_pg.safe_psql( - f"insert into t select generate_series({nrows + 1}, {nrows + addrows})" - ) - nrows += addrows - expect_nrows = nrows - expect_sum = ( - (nrows) * (nrows + 1) // 2 - ) # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n - - def validate_vanilla_equivalence(ep): - # TODO: would be nicer to just compare pgdump - - # Enable IO concurrency for batching on large sequential scan, to avoid making - # this test unnecessarily onerous on CPU. Especially on debug mode, it's still - # pretty onerous though, so increase statement_timeout to avoid timeouts. - assert ep.safe_psql_many( - [ - "set effective_io_concurrency=32;", - "SET statement_timeout='300s';", - "select count(*), sum(data::bigint)::bigint from t", - ] - ) == [[], [], [(expect_nrows, expect_sum)]] - - validate_vanilla_equivalence(vanilla_pg) - + vanilla_pg.start() + rows_inserted = populate_vanilla_pg(vanilla_pg, target_relblock_size) vanilla_pg.stop() # @@ -264,14 +221,14 @@ def test_pgdata_import_smoke( config_lines=ep_config, ) - validate_vanilla_equivalence(ro_endpoint) + validate_import_from_vanilla_pg(ro_endpoint, rows_inserted) # ensure the import survives restarts ro_endpoint.stop() env.pageserver.stop(immediate=True) env.pageserver.start() ro_endpoint.start() - validate_vanilla_equivalence(ro_endpoint) + validate_import_from_vanilla_pg(ro_endpoint, rows_inserted) # # validate the layer files in each shard only have the shard-specific data @@ -311,7 +268,7 @@ def test_pgdata_import_smoke( child_workload = workload.branch(timeline_id=child_timeline_id, branch_name="br-tip") child_workload.validate() - validate_vanilla_equivalence(child_workload.endpoint()) + validate_import_from_vanilla_pg(child_workload.endpoint(), rows_inserted) # ... at the initdb lsn _ = env.create_branch( @@ -326,10 +283,21 @@ def test_pgdata_import_smoke( tenant_id=tenant_id, config_lines=ep_config, ) - validate_vanilla_equivalence(br_initdb_endpoint) + validate_import_from_vanilla_pg(br_initdb_endpoint, rows_inserted) with pytest.raises(psycopg2.errors.UndefinedTable): br_initdb_endpoint.safe_psql(f"select * from {workload.table}") + # The storage controller might be overly eager and attempt to finalize + # the import before the task got a chance to exit. + env.storage_controller.allowed_errors.extend( + [ + ".*Call to node.*management API.*failed.*Import task still running.*", + ] + ) + + for ps in env.pageservers: + ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"]) + @run_only_on_default_postgres(reason="PG version is irrelevant here") def test_import_completion_on_restart( @@ -413,8 +381,12 @@ def test_import_completion_on_restart( @run_only_on_default_postgres(reason="PG version is irrelevant here") -def test_import_respects_tenant_shutdown( - neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer +@pytest.mark.parametrize("action", ["restart", "delete"]) +def test_import_respects_timeline_lifecycle( + neon_env_builder: NeonEnvBuilder, + vanilla_pg: VanillaPostgres, + make_httpserver: HTTPServer, + action: str, ): """ Validate that importing timelines respect the usual timeline life cycle: @@ -482,16 +454,276 @@ def test_import_respects_tenant_shutdown( wait_until(hit_failpoint) assert not import_completion_signaled.is_set() - # Restart the pageserver while an import job is in progress. - # This clears the failpoint and we expect that the import starts up afresh - # after the restart and eventually completes. - env.pageserver.stop() - env.pageserver.start() + if action == "restart": + # Restart the pageserver while an import job is in progress. + # This clears the failpoint and we expect that the import starts up afresh + # after the restart and eventually completes. + env.pageserver.stop() + env.pageserver.start() - def cplane_notified(): - assert import_completion_signaled.is_set() + def cplane_notified(): + assert import_completion_signaled.is_set() - wait_until(cplane_notified) + wait_until(cplane_notified) + elif action == "delete": + status = env.storage_controller.pageserver_api().timeline_delete(tenant_id, timeline_id) + assert status == 200 + + timeline_path = env.pageserver.timeline_dir(tenant_id, timeline_id) + assert not timeline_path.exists(), "Timeline dir exists after deletion" + + shard_zero = TenantShardId(tenant_id, 0, 0) + location = env.storage_controller.inspect(shard_zero) + assert location is not None + generation = location[0] + + with pytest.raises(StorageControllerApiException, match="not found"): + env.storage_controller.import_status(shard_zero, timeline_id, generation) + else: + raise RuntimeError(f"{action} param not recognized") + + # The storage controller might be overly eager and attempt to finalize + # the import before the task got a chance to exit. + env.storage_controller.allowed_errors.extend( + [ + ".*Call to node.*management API.*failed.*Import task still running.*", + ] + ) + + for ps in env.pageservers: + ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"]) + + +@skip_in_debug_build("Validation query takes too long in debug builds") +def test_import_chaos( + neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer +): + """ + Perform a timeline import while injecting chaos in the environment. + We expect that the import completes eventually, that it passes validation and + the resulting timeline can be written to. + """ + TARGET_RELBOCK_SIZE = 512 * 1024 * 1024 # 512 MiB + ALLOWED_IMPORT_RUNTIME = 90 # seconds + SHARD_COUNT = 4 + + neon_env_builder.num_pageservers = SHARD_COUNT + neon_env_builder.pageserver_import_config = PageserverImportConfig( + import_job_concurrency=1, + import_job_soft_size_limit=64 * 1024, + import_job_checkpoint_threshold=4, + ) + + # Set up mock control plane HTTP server to listen for import completions + import_completion_signaled = Event() + # There's some Python magic at play here. A list can be updated from the + # handler thread, but an optional cannot. Hence, use a list with one element. + import_error = [] + + def handler(request: Request) -> Response: + assert request.json is not None + + body = request.json + if "error" in body: + if body["error"]: + import_error.append(body["error"]) + + log.info(f"control plane /import_complete request: {request.json}") + import_completion_signaled.set() + return Response(json.dumps({}), status=200) + + cplane_mgmt_api_server = make_httpserver + cplane_mgmt_api_server.expect_request( + "/storage/api/v1/import_complete", method="PUT" + ).respond_with_handler(handler) + + # Plug the cplane mock in + neon_env_builder.control_plane_hooks_api = ( + f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/" + ) + + # The import will specifiy a local filesystem path mocking remote storage + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + + vanilla_pg.start() + + inserted_rows = populate_vanilla_pg(vanilla_pg, TARGET_RELBOCK_SIZE) + + vanilla_pg.stop() + + env = neon_env_builder.init_configs() + env.start() + + # Pause after every import task to extend the test runtime and allow + # for more chaos injection. + for ps in env.pageservers: + ps.add_persistent_failpoint("import-task-complete-pausable", "sleep(5)") + + env.storage_controller.allowed_errors.extend( + [ + # The shard might have moved or the pageserver hosting the shard restarted + ".*Call to node.*management API.*failed.*", + # Migrations have their time outs set to 0 + ".*Timed out after.*downloading layers.*", + ".*Failed to prepare by downloading layers.*", + # The test may kill the storage controller or pageservers + ".*request was dropped before completing.*", + ] + ) + for ps in env.pageservers: + ps.allowed_errors.extend( + [ + # We might re-write a layer in a different generation if the import + # needs to redo some of the progress since not each job is checkpointed. + ".*was unlinked but was not dangling.*", + # The test may kill the storage controller or pageservers + ".*request was dropped before completing.*", + # Test can SIGTERM pageserver while it is downloading + ".*removing local file.*temp_download.*", + ".*Failed to flush heatmap.*", + # Test can SIGTERM the storage controller while pageserver + # is attempting to upcall. + ".*storage controller upcall failed.*timeline_import_status.*", + # TODO(vlad): TenantManager::reset_tenant returns a blanked anyhow error. + # It should return ResourceUnavailable or something that doesn't error log. + ".*activate_post_import.*InternalServerError.*tenant map is shutting down.*", + # TODO(vlad): How can this happen? + ".*Failed to download a remote file: deserialize index part file.*", + ".*Cancelled request finished with an error.*", + ] + ) + + importbucket_path = neon_env_builder.repo_dir / "test_import_chaos_bucket" + mock_import_bucket(vanilla_pg, importbucket_path) + + tenant_id = TenantId.generate() + timeline_id = TimelineId.generate() + idempotency = ImportPgdataIdemptencyKey.random() + + env.storage_controller.tenant_create( + tenant_id, shard_count=SHARD_COUNT, placement_policy={"Attached": 1} + ) + env.storage_controller.reconcile_until_idle() + + env.storage_controller.timeline_create( + tenant_id, + { + "new_timeline_id": str(timeline_id), + "import_pgdata": { + "idempotency_key": str(idempotency), + "location": {"LocalFs": {"path": str(importbucket_path.absolute())}}, + }, + }, + ) + + def chaos(stop_chaos: threading.Event): + class ChaosType(StrEnum): + MIGRATE_SHARD = "migrate_shard" + RESTART_IMMEDIATE = "restart_immediate" + RESTART = "restart" + STORCON_RESTART_IMMEDIATE = "storcon_restart_immediate" + + while not stop_chaos.is_set(): + chaos_type = random.choices( + population=[ + ChaosType.MIGRATE_SHARD, + ChaosType.RESTART, + ChaosType.RESTART_IMMEDIATE, + ChaosType.STORCON_RESTART_IMMEDIATE, + ], + weights=[0.25, 0.25, 0.25, 0.25], + k=1, + )[0] + + try: + if chaos_type == ChaosType.MIGRATE_SHARD: + target_shard_number = random.randint(0, SHARD_COUNT - 1) + target_shard = TenantShardId(tenant_id, target_shard_number, SHARD_COUNT) + + placements = env.storage_controller.get_tenants_placement() + log.info(f"{placements=}") + target_ps = placements[str(target_shard)]["intent"]["attached"] + if len(placements[str(target_shard)]["intent"]["secondary"]) == 0: + dest_ps = None + else: + dest_ps = placements[str(target_shard)]["intent"]["secondary"][0] + + if target_ps is None or dest_ps is None: + continue + + config = StorageControllerMigrationConfig( + secondary_warmup_timeout="0s", + secondary_download_request_timeout="0s", + prewarm=False, + ) + env.storage_controller.tenant_shard_migrate(target_shard, dest_ps, config) + + log.info( + f"CHAOS: Migrating shard {target_shard} from pageserver {target_ps} to {dest_ps}" + ) + elif chaos_type == ChaosType.RESTART_IMMEDIATE: + target_ps = random.choice(env.pageservers) + log.info(f"CHAOS: Immediate restart of pageserver {target_ps.id}") + target_ps.stop(immediate=True) + target_ps.start() + elif chaos_type == ChaosType.RESTART: + target_ps = random.choice(env.pageservers) + log.info(f"CHAOS: Normal restart of pageserver {target_ps.id}") + target_ps.stop(immediate=False) + target_ps.start() + elif chaos_type == ChaosType.STORCON_RESTART_IMMEDIATE: + log.info("CHAOS: Immediate restart of storage controller") + env.storage_controller.stop(immediate=True) + env.storage_controller.start() + except Exception as e: + log.warning(f"CHAOS: Error during chaos operation {chaos_type}: {e}") + + # Sleep before next chaos event + time.sleep(1) + + log.info("Chaos injector stopped") + + def wait_for_import_completion(): + start = time.time() + done = import_completion_signaled.wait(ALLOWED_IMPORT_RUNTIME) + if not done: + raise TimeoutError(f"Import did not signal completion within {ALLOWED_IMPORT_RUNTIME}") + + end = time.time() + + log.info(f"Import completion signalled after {end - start}s {import_error=}") + + if import_error: + raise RuntimeError(f"Import error: {import_error}") + + with concurrent.futures.ThreadPoolExecutor() as executor: + stop_chaos = threading.Event() + + wait_for_import_completion_fut = executor.submit(wait_for_import_completion) + chaos_fut = executor.submit(chaos, stop_chaos) + + try: + wait_for_import_completion_fut.result() + except Exception as e: + raise e + finally: + stop_chaos.set() + chaos_fut.result() + + import_branch_name = "imported" + env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id) + endpoint = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id) + + # Validate the imported data is legit + validate_import_from_vanilla_pg(endpoint, inserted_rows) + + endpoint.stop() + + # Validate writes + workload = Workload(env, tenant_id, timeline_id, branch_name=import_branch_name) + workload.init() + workload.write_rows(64) + workload.validate() def test_fast_import_with_pageserver_ingest( diff --git a/test_runner/regress/test_layers_from_future.py b/test_runner/regress/test_layers_from_future.py index b4eba2779d..f3fcdb0d14 100644 --- a/test_runner/regress/test_layers_from_future.py +++ b/test_runner/regress/test_layers_from_future.py @@ -20,6 +20,9 @@ from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind from fixtures.utils import query_scalar, wait_until +@pytest.mark.skip( + reason="We won't create future layers any more after https://github.com/neondatabase/neon/pull/10548" +) @pytest.mark.parametrize( "attach_mode", ["default_generation", "same_generation"], diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index f2523ec9b5..8d18311f3d 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, ".*downloading failed, possibly for shutdown", # {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n' ".*page_service.*will not become active.*", + # the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state + ".*wal_connection_manager.*layer file download failed: No file found.*", + ".*wal_connection_manager.*could not ingest record.*", ] ) @@ -156,6 +159,45 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, env.pageservers[2].id: ("Detached", None), } + # Track all the attached locations with mode and generation + history: list[tuple[int, str, int | None]] = [] + + def may_read(pageserver: NeonPageserver, mode: str, generation: int | None) -> bool: + # Rules for when a pageserver may read: + # - our generation is higher than any previous + # - our generation is equal to previous, but no other pageserver + # in that generation has been AttachedSingle (i.e. allowed to compact/GC) + # - our generation is equal to previous, and the previous holder of this + # generation was the same node as we're attaching now. + # + # If these conditions are not met, then a read _might_ work, but the pageserver might + # also hit errors trying to download layers. + highest_historic_generation = max([i[2] for i in history if i[2] is not None], default=None) + + if generation is None: + # We're not in an attached state, we may not read + return False + elif highest_historic_generation is not None and generation < highest_historic_generation: + # We are in an outdated generation, we may not read + return False + elif highest_historic_generation is not None and generation == highest_historic_generation: + # We are re-using a generation: if any pageserver other than this one + # has held AttachedSingle mode, this node may not read (because some other + # node may be doing GC/compaction). + if any( + i[1] == "AttachedSingle" + and i[2] == highest_historic_generation + and i[0] != pageserver.id + for i in history + ): + log.info( + f"Skipping read on {pageserver.id} because other pageserver has been in AttachedSingle mode in generation {highest_historic_generation}" + ) + return False + + # Fall through: we have passed conditions for readability + return True + latest_attached = env.pageservers[0].id for _i in range(0, 64): @@ -199,9 +241,10 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, assert len(tenants) == 1 assert tenants[0]["generation"] == new_generation - log.info("Entering postgres...") - workload.churn_rows(rng.randint(128, 256), pageserver.id) - workload.validate(pageserver.id) + if may_read(pageserver, last_state_ps[0], last_state_ps[1]): + log.info("Entering postgres...") + workload.churn_rows(rng.randint(128, 256), pageserver.id) + workload.validate(pageserver.id) elif last_state_ps[0].startswith("Attached"): # The `storage_controller` will only re-attach on startup when a pageserver was the # holder of the latest generation: otherwise the pageserver will revert to detached @@ -241,18 +284,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, location_conf["generation"] = generation pageserver.tenant_location_configure(tenant_id, location_conf) + last_state[pageserver.id] = (mode, generation) - # It's only valid to connect to the last generation. Newer generations may yank layer - # files used in older generations. - last_generation = max( - [s[1] for s in last_state.values() if s[1] is not None], default=None - ) + may_read_this_generation = may_read(pageserver, mode, generation) + history.append((pageserver.id, mode, generation)) - if mode.startswith("Attached") and generation == last_generation: - # This is a basic test: we are validating that he endpoint works properly _between_ - # configuration changes. A stronger test would be to validate that clients see - # no errors while we are making the changes. + # This is a basic test: we are validating that he endpoint works properly _between_ + # configuration changes. A stronger test would be to validate that clients see + # no errors while we are making the changes. + if may_read_this_generation: workload.churn_rows( rng.randint(128, 256), pageserver.id, upload=mode != "AttachedStale" ) @@ -265,9 +306,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, assert gc_summary["remote_storage_errors"] == 0 assert gc_summary["indices_deleted"] > 0 - # Attach all pageservers + # Attach all pageservers, in a higher generation than any previous. We will use the same + # gen for all, and AttachedMulti mode so that they do not interfere with one another. + generation = env.storage_controller.attach_hook_issue(tenant_id, env.pageservers[0].id) for ps in env.pageservers: - location_conf = {"mode": "AttachedMulti", "secondary_conf": None, "tenant_conf": {}} + location_conf = { + "mode": "AttachedMulti", + "secondary_conf": None, + "tenant_conf": {}, + "generation": generation, + } ps.tenant_location_configure(tenant_id, location_conf) # Confirm that all are readable diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index af018f7b5d..346ef0951d 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -4158,17 +4158,12 @@ def test_storcon_create_delete_sk_down( env.storage_controller.stop() env.storage_controller.start() - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - with env.endpoints.create("main", tenant_id=tenant_id, config_lines=config_lines) as ep: + with env.endpoints.create("main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") - with env.endpoints.create( - "child_of_main", tenant_id=tenant_id, config_lines=config_lines - ) as ep: + with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") @@ -4197,10 +4192,10 @@ def test_storcon_create_delete_sk_down( # ensure the safekeeper deleted the timeline def timeline_deleted_on_active_sks(): env.safekeepers[0].assert_log_contains( - f"deleting timeline {tenant_id}/{child_timeline_id} from disk" + f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)" ) env.safekeepers[2].assert_log_contains( - f"deleting timeline {tenant_id}/{child_timeline_id} from disk" + f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)" ) wait_until(timeline_deleted_on_active_sks) @@ -4215,7 +4210,7 @@ def test_storcon_create_delete_sk_down( # ensure that there is log msgs for the third safekeeper too def timeline_deleted_on_sk(): env.safekeepers[1].assert_log_contains( - f"deleting timeline {tenant_id}/{child_timeline_id} from disk" + f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)" ) wait_until(timeline_deleted_on_sk) @@ -4249,17 +4244,12 @@ def test_storcon_few_sk( env.safekeepers[0].assert_log_contains(f"creating new timeline {tenant_id}/{timeline_id}") - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - with env.endpoints.create("main", tenant_id=tenant_id, config_lines=config_lines) as ep: + with env.endpoints.create("main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=safekeeper_list) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") - with env.endpoints.create( - "child_of_main", tenant_id=tenant_id, config_lines=config_lines - ) as ep: + with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=safekeeper_list) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") diff --git a/test_runner/regress/test_timeline_detach_ancestor.py b/test_runner/regress/test_timeline_detach_ancestor.py index d42c5d403e..f0810270b1 100644 --- a/test_runner/regress/test_timeline_detach_ancestor.py +++ b/test_runner/regress/test_timeline_detach_ancestor.py @@ -10,6 +10,7 @@ from queue import Empty, Queue from threading import Barrier import pytest +import requests from fixtures.common_types import Lsn, TimelineArchivalState, TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import ( @@ -401,8 +402,25 @@ def test_ancestor_detach_behavior_v2(neon_env_builder: NeonEnvBuilder, snapshots "earlier", ancestor_branch_name="main", ancestor_start_lsn=branchpoint_pipe ) - snapshot_branchpoint_old = env.create_branch( - "snapshot_branchpoint_old", ancestor_branch_name="main", ancestor_start_lsn=branchpoint_y + snapshot_branchpoint_old = TimelineId.generate() + + env.storage_controller.timeline_create( + env.initial_tenant, + { + "new_timeline_id": str(snapshot_branchpoint_old), + "ancestor_start_lsn": str(branchpoint_y), + "ancestor_timeline_id": str(env.initial_timeline), + "read_only": True, + }, + ) + sk = env.safekeepers[0] + assert sk + with pytest.raises(requests.exceptions.HTTPError, match="Not Found"): + sk.http_client().timeline_status( + tenant_id=env.initial_tenant, timeline_id=snapshot_branchpoint_old + ) + env.neon_cli.mappings_map_branch( + "snapshot_branchpoint_old", env.initial_tenant, snapshot_branchpoint_old ) snapshot_branchpoint = env.create_branch( diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index a9a6699e5c..6a7c7a8bef 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -2012,10 +2012,7 @@ def test_explicit_timeline_creation(neon_env_builder: NeonEnvBuilder): tenant_id = env.initial_tenant timeline_id = env.initial_timeline - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create("main", config_lines=config_lines) + ep = env.endpoints.create("main") # expected to fail because timeline is not created on safekeepers with pytest.raises(Exception, match=r".*timed out.*"): @@ -2043,10 +2040,7 @@ def test_explicit_timeline_creation_storcon(neon_env_builder: NeonEnvBuilder): } env = neon_env_builder.init_start() - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create("main", config_lines=config_lines) + ep = env.endpoints.create("main") # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index c5dd34f64f..4070f99568 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -637,10 +637,7 @@ async def quorum_sanity_single( # create timeline on `members_sks` Safekeeper.create_timeline(tenant_id, timeline_id, env.pageservers[0], mconf, members_sks) - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create(branch_name, config_lines=config_lines) + ep = env.endpoints.create(branch_name) ep.start(safekeeper_generation=1, safekeepers=compute_sks_ids) ep.safe_psql("create table t(key int, value text)") diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 9e1123ac0e..2b07889871 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -18,6 +18,8 @@ license.workspace = true ahash = { version = "0.8" } anstream = { version = "0.6" } anyhow = { version = "1", features = ["backtrace"] } +axum = { version = "0.8", features = ["ws"] } +axum-core = { version = "0.5", default-features = false, features = ["tracing"] } base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] } base64-647d43efb71741da = { package = "base64", version = "0.21" } base64ct = { version = "1", default-features = false, features = ["std"] } @@ -39,10 +41,8 @@ env_logger = { version = "0.11" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } form_urlencoded = { version = "1" } futures-channel = { version = "0.3", features = ["sink"] } -futures-core = { version = "0.3" } futures-executor = { version = "0.3" } futures-io = { version = "0.3" } -futures-task = { version = "0.3", default-features = false, features = ["std"] } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } @@ -52,7 +52,7 @@ hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper-582f2526e08bb6a0 = { package = "hyper", version = "0.14", features = ["client", "http1", "http2", "runtime", "server", "stream"] } hyper-dff4ba8e3ae991db = { package = "hyper", version = "1", features = ["full"] } -hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2", "server", "service"] } +hyper-util = { version = "0.1", features = ["client-legacy", "server-auto", "service"] } indexmap = { version = "2", features = ["serde"] } itertools = { version = "0.12" } lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } @@ -72,7 +72,6 @@ num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } p256 = { version = "0.13", features = ["jwk"] } parquet = { version = "53", default-features = false, features = ["zstd"] } -percent-encoding = { version = "2" } prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] } rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } @@ -98,7 +97,7 @@ tikv-jemalloc-sys = { version = "0.6", features = ["profiling", "stats", "unpref time = { version = "0.3", features = ["macros", "serde-well-known"] } tokio = { version = "1", features = ["full", "test-util"] } tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] } -tokio-stream = { version = "0.1" } +tokio-stream = { version = "0.1", features = ["net"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] } toml_edit = { version = "0.22", features = ["serde"] } tower = { version = "0.5", default-features = false, features = ["balance", "buffer", "limit", "log"] }