diff --git a/.config/hakari.toml b/.config/hakari.toml index 9913ecc9c0..b5990d090e 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -23,10 +23,30 @@ platforms = [ ] [final-excludes] -# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but -# it is built primarly in separate repo neondatabase/autoscaling and thus is excluded -# from depending on workspace-hack because most of the dependencies are not used. -workspace-members = ["vm_monitor"] +workspace-members = [ + # vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but + # it is built primarly in separate repo neondatabase/autoscaling and thus is excluded + # from depending on workspace-hack because most of the dependencies are not used. + "vm_monitor", + # All of these exist in libs and are not usually built independently. + # Putting workspace hack there adds a bottleneck for cargo builds. + "compute_api", + "consumption_metrics", + "desim", + "metrics", + "pageserver_api", + "postgres_backend", + "postgres_connection", + "postgres_ffi", + "pq_proto", + "remote_storage", + "safekeeper_api", + "tenant_size_model", + "tracing-utils", + "utils", + "wal_craft", + "walproposer", +] # Write out exact versions rather than a semver range. (Defaults to false.) # exact-versions = true diff --git a/.github/actions/run-python-test-set/action.yml b/.github/actions/run-python-test-set/action.yml index 4ccf190c6a..6c2cee0971 100644 --- a/.github/actions/run-python-test-set/action.yml +++ b/.github/actions/run-python-test-set/action.yml @@ -43,7 +43,7 @@ inputs: pg_version: description: 'Postgres version to use for tests' required: false - default: 'v14' + default: 'v16' benchmark_durations: description: 'benchmark durations JSON' required: false @@ -169,10 +169,8 @@ runs: EXTRA_PARAMS="--durations-path $TEST_OUTPUT/benchmark_durations.json $EXTRA_PARAMS" fi - if [[ "${{ inputs.build_type }}" == "debug" ]]; then + if [[ $BUILD_TYPE == "debug" && $RUNNER_ARCH == 'X64' ]]; then cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run) - elif [[ "${{ inputs.build_type }}" == "release" ]]; then - cov_prefix=() else cov_prefix=() fi diff --git a/.github/workflows/_benchmarking_preparation.yml b/.github/workflows/_benchmarking_preparation.yml index 7229776cd6..a52e43b4da 100644 --- a/.github/workflows/_benchmarking_preparation.yml +++ b/.github/workflows/_benchmarking_preparation.yml @@ -48,6 +48,8 @@ jobs: echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT + - uses: actions/checkout@v4 + - name: Download Neon artifact uses: ./.github/actions/download with: diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index af76e51ebc..5e9fff0e6a 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -94,11 +94,16 @@ jobs: # We run tests with addtional features, that are turned off by default (e.g. in release builds), see # corresponding Cargo.toml files for their descriptions. 
- name: Set env variables + env: + ARCH: ${{ inputs.arch }} run: | CARGO_FEATURES="--features testing" - if [[ $BUILD_TYPE == "debug" ]]; then + if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run" CARGO_FLAGS="--locked" + elif [[ $BUILD_TYPE == "debug" ]]; then + cov_prefix="" + CARGO_FLAGS="--locked" elif [[ $BUILD_TYPE == "release" ]]; then cov_prefix="" CARGO_FLAGS="--locked --release" @@ -158,6 +163,8 @@ jobs: # Do install *before* running rust tests because they might recompile the # binaries with different features/flags. - name: Install rust binaries + env: + ARCH: ${{ inputs.arch }} run: | # Install target binaries mkdir -p /tmp/neon/bin/ @@ -172,7 +179,7 @@ jobs: done # Install test executables and write list of all binaries (for code coverage) - if [[ $BUILD_TYPE == "debug" ]]; then + if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then # Keep bloated coverage data files away from the rest of the artifact mkdir -p /tmp/coverage/ @@ -243,8 +250,8 @@ jobs: uses: ./.github/actions/save-coverage-data regress-tests: - # Run test on x64 only - if: inputs.arch == 'x64' + # Don't run regression tests on debug arm64 builds + if: inputs.build-type != 'debug' || inputs.arch != 'arm64' needs: [ build-neon ] runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', inputs.arch == 'arm64' && 'large-arm64' || 'large')) }} container: diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index ee6d3ba005..1e7f3598c2 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -198,7 +198,7 @@ jobs: strategy: fail-fast: false matrix: - arch: [ x64 ] + arch: [ x64, arm64 ] # Do not build or run tests in debug for release branches build-type: ${{ fromJson((startsWith(github.ref_name, 'release') && github.event_name == 'push') && '["release"]' || '["debug", "release"]') }} include: @@ -280,6 +280,7 @@ jobs: save_perf_report: ${{ github.ref_name == 'main' }} extra_params: --splits 5 --group ${{ matrix.pytest_split_group }} benchmark_durations: ${{ needs.get-benchmarks-durations.outputs.json }} + pg_version: v16 env: VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}" PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}" @@ -985,10 +986,10 @@ jobs: GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }} run: | if [[ "$GITHUB_REF_NAME" == "main" ]]; then - gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false + gh workflow --repo neondatabase/infra run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false gh workflow --repo neondatabase/azure run deploy.yml -f dockerTag=${{needs.tag.outputs.build-tag}} elif [[ "$GITHUB_REF_NAME" == "release" ]]; then - gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \ + gh workflow --repo neondatabase/infra run deploy-dev.yml --ref main \ -f deployPgSniRouter=false \ -f deployProxy=false \ -f deployStorage=true \ @@ -998,14 +999,14 @@ jobs: -f dockerTag=${{needs.tag.outputs.build-tag}} \ -f deployPreprodRegion=true - gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main \ + gh workflow --repo neondatabase/infra run deploy-prod.yml --ref main \ -f deployStorage=true \ -f deployStorageBroker=true \ -f deployStorageController=true \ -f branch=main \ -f dockerTag=${{needs.tag.outputs.build-tag}} elif [[ "$GITHUB_REF_NAME" 
== "release-proxy" ]]; then - gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \ + gh workflow --repo neondatabase/infra run deploy-dev.yml --ref main \ -f deployPgSniRouter=true \ -f deployProxy=true \ -f deployStorage=false \ @@ -1015,7 +1016,7 @@ jobs: -f dockerTag=${{needs.tag.outputs.build-tag}} \ -f deployPreprodRegion=true - gh workflow --repo neondatabase/aws run deploy-proxy-prod.yml --ref main \ + gh workflow --repo neondatabase/infra run deploy-proxy-prod.yml --ref main \ -f deployPgSniRouter=true \ -f deployProxy=true \ -f branch=main \ diff --git a/Cargo.lock b/Cargo.lock index dee15b6aa7..441ca1ff86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1208,7 +1208,6 @@ dependencies = [ "serde_json", "serde_with", "utils", - "workspace_hack", ] [[package]] @@ -1321,7 +1320,6 @@ dependencies = [ "serde", "serde_with", "utils", - "workspace_hack", ] [[package]] @@ -1670,14 +1668,13 @@ dependencies = [ "smallvec", "tracing", "utils", - "workspace_hack", ] [[package]] name = "diesel" -version = "2.2.1" +version = "2.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62d6dcd069e7b5fe49a302411f759d4cf1cf2c27fe798ef46fb8baefc053dd2b" +checksum = "65e13bab2796f412722112327f3e575601a3e9cdcbe426f0d30dbf43f3f5dc71" dependencies = [ "bitflags 2.4.1", "byteorder", @@ -3147,7 +3144,6 @@ dependencies = [ "rand 0.8.5", "rand_distr", "twox-hash", - "workspace_hack", ] [[package]] @@ -3791,7 +3787,6 @@ dependencies = [ "strum_macros", "thiserror", "utils", - "workspace_hack", ] [[package]] @@ -4193,7 +4188,6 @@ dependencies = [ "tokio-rustls 0.25.0", "tokio-util", "tracing", - "workspace_hack", ] [[package]] @@ -4206,7 +4200,6 @@ dependencies = [ "postgres", "tokio-postgres", "url", - "workspace_hack", ] [[package]] @@ -4229,7 +4222,6 @@ dependencies = [ "serde", "thiserror", "utils", - "workspace_hack", ] [[package]] @@ -4267,7 +4259,6 @@ dependencies = [ "thiserror", "tokio", "tracing", - "workspace_hack", ] [[package]] @@ -4832,7 +4823,6 @@ dependencies = [ "toml_edit 0.19.10", "tracing", "utils", - "workspace_hack", ] [[package]] @@ -5357,7 +5347,6 @@ dependencies = [ "serde", "serde_with", "utils", - "workspace_hack", ] [[package]] @@ -5601,11 +5590,12 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -6193,7 +6183,6 @@ dependencies = [ "anyhow", "serde", "serde_json", - "workspace_hack", ] [[package]] @@ -6794,7 +6783,6 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", - "workspace_hack", ] [[package]] @@ -7012,7 +7000,6 @@ dependencies = [ "url", "uuid", "walkdir", - "workspace_hack", ] [[package]] @@ -7091,7 +7078,6 @@ dependencies = [ "postgres_ffi", "regex", "utils", - "workspace_hack", ] [[package]] @@ -7112,7 +7098,6 @@ dependencies = [ "bindgen", "postgres_ffi", "utils", - "workspace_hack", ] [[package]] @@ -7669,8 +7654,6 @@ dependencies = [ "tokio", "tokio-rustls 0.24.0", "tokio-util", - "toml_datetime", - "toml_edit 0.19.10", "tonic", "tower", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 963841e340..e038c0b4ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,7 +113,7 @@ md5 = "0.7.0" measured = { version = "0.0.22", features=["lasso"] } measured-process = { version = "0.0.22" } 
memoffset = "0.8" -nix = { version = "0.27", features = ["fs", "process", "socket", "signal", "poll"] } +nix = { version = "0.27", features = ["dir", "fs", "process", "socket", "signal", "poll"] } notify = "6.0.0" num_cpus = "1.15" num-traits = "0.2.15" diff --git a/Dockerfile.compute-node b/Dockerfile.compute-node index f57e21ac89..2c762a21ba 100644 --- a/Dockerfile.compute-node +++ b/Dockerfile.compute-node @@ -945,7 +945,7 @@ COPY --from=hll-pg-build /hll.tar.gz /ext-src COPY --from=plpgsql-check-pg-build /plpgsql_check.tar.gz /ext-src #COPY --from=timescaledb-pg-build /timescaledb.tar.gz /ext-src COPY --from=pg-hint-plan-pg-build /pg_hint_plan.tar.gz /ext-src -COPY patches/pg_hintplan.patch /ext-src +COPY patches/pg_hint_plan.patch /ext-src COPY --from=pg-cron-pg-build /pg_cron.tar.gz /ext-src COPY patches/pg_cron.patch /ext-src #COPY --from=pg-pgx-ulid-build /home/nonroot/pgx_ulid.tar.gz /ext-src @@ -967,7 +967,7 @@ RUN cd /ext-src/pgvector-src && patch -p1 <../pgvector.patch RUN cd /ext-src/rum-src && patch -p1 <../rum.patch # cmake is required for the h3 test RUN apt-get update && apt-get install -y cmake -RUN patch -p1 < /ext-src/pg_hintplan.patch +RUN cd /ext-src/pg_hint_plan-src && patch -p1 < /ext-src/pg_hint_plan.patch COPY --chmod=755 docker-compose/run-tests.sh /run-tests.sh RUN patch -p1 =1.3](https://python-poetry.org/)) in the project directory. +Python (3.9 or higher), and install the python3 packages using `./scripts/pysync` (requires [poetry>=1.8](https://python-poetry.org/)) in the project directory. #### Running neon database @@ -262,7 +262,7 @@ By default, this runs both debug and release modes, and all supported postgres v testing locally, it is convenient to run just one set of permutations, like this: ```sh -DEFAULT_PG_VERSION=15 BUILD_TYPE=release ./scripts/pytest +DEFAULT_PG_VERSION=16 BUILD_TYPE=release ./scripts/pytest ``` ## Flamegraphs diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index edd88dc71c..1d66532d49 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -54,7 +54,7 @@ const DEFAULT_PAGESERVER_ID: NodeId = NodeId(1); const DEFAULT_BRANCH_NAME: &str = "main"; project_git_version!(GIT_VERSION); -const DEFAULT_PG_VERSION: &str = "15"; +const DEFAULT_PG_VERSION: &str = "16"; const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/"; diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 807519c88d..74caba2b56 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -27,7 +27,7 @@ use crate::pageserver::PageServerNode; use crate::pageserver::PAGESERVER_REMOTE_STORAGE_DIR; use crate::safekeeper::SafekeeperNode; -pub const DEFAULT_PG_VERSION: u32 = 15; +pub const DEFAULT_PG_VERSION: u32 = 16; // // This data structures represents neon_local CLI config diff --git a/control_plane/src/storage_controller.rs b/control_plane/src/storage_controller.rs index 2c077595a1..27d8e2de0c 100644 --- a/control_plane/src/storage_controller.rs +++ b/control_plane/src/storage_controller.rs @@ -217,7 +217,7 @@ impl StorageController { Ok(exitcode.success()) } - /// Create our database if it doesn't exist, and run migrations. + /// Create our database if it doesn't exist /// /// This function is equivalent to the `diesel setup` command in the diesel CLI. 
We implement /// the same steps by hand to avoid imposing a dependency on installing diesel-cli for developers @@ -382,7 +382,6 @@ impl StorageController { ) .await?; - // Run migrations on every startup, in case something changed. self.setup_database(postgres_port).await?; } @@ -454,6 +453,11 @@ impl StorageController { let jwt_token = encode_from_key_file(&claims, private_key).expect("failed to generate jwt token"); args.push(format!("--jwt-token={jwt_token}")); + + let peer_claims = Claims::new(None, Scope::Admin); + let peer_jwt_token = encode_from_key_file(&peer_claims, private_key) + .expect("failed to generate jwt token"); + args.push(format!("--peer-jwt-token={peer_jwt_token}")); } if let Some(public_key) = &self.public_key { diff --git a/control_plane/storcon_cli/src/main.rs b/control_plane/storcon_cli/src/main.rs index e27491c1c8..35510ccbca 100644 --- a/control_plane/storcon_cli/src/main.rs +++ b/control_plane/storcon_cli/src/main.rs @@ -147,9 +147,9 @@ enum Command { #[arg(long)] threshold: humantime::Duration, }, - // Drain a set of specified pageservers by moving the primary attachments to pageservers + // Migrate away from a set of specified pageservers by moving the primary attachments to pageservers // outside of the specified set. - Drain { + BulkMigrate { // Set of pageserver node ids to drain. #[arg(long)] nodes: Vec, @@ -163,6 +163,34 @@ enum Command { #[arg(long)] dry_run: Option, }, + /// Start draining the specified pageserver. + /// The drain is complete when the schedulling policy returns to active. + StartDrain { + #[arg(long)] + node_id: NodeId, + }, + /// Cancel draining the specified pageserver and wait for `timeout` + /// for the operation to be canceled. May be retried. + CancelDrain { + #[arg(long)] + node_id: NodeId, + #[arg(long)] + timeout: humantime::Duration, + }, + /// Start filling the specified pageserver. + /// The drain is complete when the schedulling policy returns to active. + StartFill { + #[arg(long)] + node_id: NodeId, + }, + /// Cancel filling the specified pageserver and wait for `timeout` + /// for the operation to be canceled. May be retried. + CancelFill { + #[arg(long)] + node_id: NodeId, + #[arg(long)] + timeout: humantime::Duration, + }, } #[derive(Parser)] @@ -249,6 +277,34 @@ impl FromStr for NodeAvailabilityArg { } } +async fn wait_for_scheduling_policy( + client: Client, + node_id: NodeId, + timeout: Duration, + f: F, +) -> anyhow::Result +where + F: Fn(NodeSchedulingPolicy) -> bool, +{ + let waiter = tokio::time::timeout(timeout, async move { + loop { + let node = client + .dispatch::<(), NodeDescribeResponse>( + Method::GET, + format!("control/v1/node/{node_id}"), + None, + ) + .await?; + + if f(node.scheduling) { + return Ok::(node.scheduling); + } + } + }); + + Ok(waiter.await??) 
+} + #[tokio::main] async fn main() -> anyhow::Result<()> { let cli = Cli::parse(); @@ -628,7 +684,7 @@ async fn main() -> anyhow::Result<()> { }) .await?; } - Command::Drain { + Command::BulkMigrate { nodes, concurrency, max_shards, @@ -657,7 +713,7 @@ async fn main() -> anyhow::Result<()> { } if nodes.len() != node_to_drain_descs.len() { - anyhow::bail!("Drain requested for node which doesn't exist.") + anyhow::bail!("Bulk migration requested away from node which doesn't exist.") } node_to_fill_descs.retain(|desc| { @@ -669,7 +725,7 @@ async fn main() -> anyhow::Result<()> { }); if node_to_fill_descs.is_empty() { - anyhow::bail!("There are no nodes to drain to") + anyhow::bail!("There are no nodes to migrate to") } // Set the node scheduling policy to draining for the nodes which @@ -690,7 +746,7 @@ async fn main() -> anyhow::Result<()> { .await?; } - // Perform the drain: move each tenant shard scheduled on a node to + // Perform the migration: move each tenant shard scheduled on a node to // be drained to a node which is being filled. A simple round robin // strategy is used to pick the new node. let tenants = storcon_client @@ -703,13 +759,13 @@ async fn main() -> anyhow::Result<()> { let mut selected_node_idx = 0; - struct DrainMove { + struct MigrationMove { tenant_shard_id: TenantShardId, from: NodeId, to: NodeId, } - let mut moves: Vec = Vec::new(); + let mut moves: Vec = Vec::new(); let shards = tenants .into_iter() @@ -739,7 +795,7 @@ async fn main() -> anyhow::Result<()> { continue; } - moves.push(DrainMove { + moves.push(MigrationMove { tenant_shard_id: shard.tenant_shard_id, from: shard .node_attached @@ -816,6 +872,67 @@ async fn main() -> anyhow::Result<()> { failure ); } + Command::StartDrain { node_id } => { + storcon_client + .dispatch::<(), ()>( + Method::PUT, + format!("control/v1/node/{node_id}/drain"), + None, + ) + .await?; + println!("Drain started for {node_id}"); + } + Command::CancelDrain { node_id, timeout } => { + storcon_client + .dispatch::<(), ()>( + Method::DELETE, + format!("control/v1/node/{node_id}/drain"), + None, + ) + .await?; + + println!("Waiting for node {node_id} to quiesce on scheduling policy ..."); + + let final_policy = + wait_for_scheduling_policy(storcon_client, node_id, *timeout, |sched| { + use NodeSchedulingPolicy::*; + matches!(sched, Active | PauseForRestart) + }) + .await?; + + println!( + "Drain was cancelled for node {node_id}. Schedulling policy is now {final_policy:?}" + ); + } + Command::StartFill { node_id } => { + storcon_client + .dispatch::<(), ()>(Method::PUT, format!("control/v1/node/{node_id}/fill"), None) + .await?; + + println!("Fill started for {node_id}"); + } + Command::CancelFill { node_id, timeout } => { + storcon_client + .dispatch::<(), ()>( + Method::DELETE, + format!("control/v1/node/{node_id}/fill"), + None, + ) + .await?; + + println!("Waiting for node {node_id} to quiesce on scheduling policy ..."); + + let final_policy = + wait_for_scheduling_policy(storcon_client, node_id, *timeout, |sched| { + use NodeSchedulingPolicy::*; + matches!(sched, Active) + }) + .await?; + + println!( + "Fill was cancelled for node {node_id}. 
Schedulling policy is now {final_policy:?}" + ); + } } Ok(()) diff --git a/docker-compose/run-tests.sh b/docker-compose/run-tests.sh index 58b2581197..3fc0b90071 100644 --- a/docker-compose/run-tests.sh +++ b/docker-compose/run-tests.sh @@ -3,7 +3,7 @@ set -x cd /ext-src || exit 2 FAILED= -LIST=$( (echo "${SKIP//","/"\n"}"; ls -d -- *-src) | sort | uniq -u) +LIST=$( (echo -e "${SKIP//","/"\n"}"; ls -d -- *-src) | sort | uniq -u) for d in ${LIST} do [ -d "${d}" ] || continue diff --git a/docs/rfcs/033-storage-controller-drain-and-fill.md b/docs/rfcs/033-storage-controller-drain-and-fill.md index 77c84cd2a5..733f7c0bd8 100644 --- a/docs/rfcs/033-storage-controller-drain-and-fill.md +++ b/docs/rfcs/033-storage-controller-drain-and-fill.md @@ -14,7 +14,7 @@ picked tenant (which requested on-demand activation) for around 30 seconds during the restart at 2024-04-03 16:37 UTC. Note that lots of shutdowns on loaded pageservers do not finish within the -[10 second systemd enforced timeout](https://github.com/neondatabase/aws/blob/0a5280b383e43c063d43cbf87fa026543f6d6ad4/.github/ansible/systemd/pageserver.service#L16). This means we are shutting down without flushing ephemeral layers +[10 second systemd enforced timeout](https://github.com/neondatabase/infra/blob/0a5280b383e43c063d43cbf87fa026543f6d6ad4/.github/ansible/systemd/pageserver.service#L16). This means we are shutting down without flushing ephemeral layers and have to reingest data in order to serve requests after restarting, potentially making first request latencies worse. This problem is not yet very acutely felt in storage controller managed pageservers since diff --git a/docs/updating-postgres.md b/docs/updating-postgres.md index 1868bbf5f7..7913b0a9e2 100644 --- a/docs/updating-postgres.md +++ b/docs/updating-postgres.md @@ -21,30 +21,21 @@ _Example: 15.4 is the new minor version to upgrade to from 15.3._ 1. Create a new branch based on the stable branch you are updating. ```shell - git checkout -b my-branch REL_15_STABLE_neon + git checkout -b my-branch-15 REL_15_STABLE_neon ``` -1. Tag the last commit on the stable branch you are updating. +1. Find the upstream release tags you're looking for. They are of the form `REL_X_Y`. - ```shell - git tag REL_15_3_neon - ``` - -1. Push the new tag to the Neon Postgres repository. - - ```shell - git push origin REL_15_3_neon - ``` - -1. Find the release tags you're looking for. They are of the form `REL_X_Y`. - -1. Rebase the branch you created on the tag and resolve any conflicts. +1. Merge the upstream tag into the branch you created on the tag and resolve any conflicts. ```shell git fetch upstream REL_15_4 - git rebase REL_15_4 + git merge REL_15_4 ``` + In the commit message of the merge commit, mention if there were + any non-trivial conflicts or other issues. + 1. Run the Postgres test suite to make sure our commits have not affected Postgres in a negative way. @@ -57,7 +48,7 @@ Postgres in a negative way. 1. Push your branch to the Neon Postgres repository. ```shell - git push origin my-branch + git push origin my-branch-15 ``` 1. Clone the Neon repository if you have not done so already. @@ -74,7 +65,7 @@ branch. 1. Update the Git submodule. ```shell - git submodule set-branch --branch my-branch vendor/postgres-v15 + git submodule set-branch --branch my-branch-15 vendor/postgres-v15 git submodule update --remote vendor/postgres-v15 ``` @@ -89,14 +80,12 @@ minor Postgres release. 1. Create a pull request, and wait for CI to go green. -1. 
Force push the rebased Postgres branches into the Neon Postgres repository. +1. Push the Postgres branches with the merge commits into the Neon Postgres repository. ```shell - git push --force origin my-branch:REL_15_STABLE_neon + git push origin my-branch-15:REL_15_STABLE_neon ``` - It may require disabling various branch protections. - 1. Update your Neon PR to point at the branches. ```shell diff --git a/libs/compute_api/Cargo.toml b/libs/compute_api/Cargo.toml index b377bd2cce..8aaa481f8c 100644 --- a/libs/compute_api/Cargo.toml +++ b/libs/compute_api/Cargo.toml @@ -14,5 +14,3 @@ regex.workspace = true utils = { path = "../utils" } remote_storage = { version = "0.1", path = "../remote_storage/" } - -workspace_hack.workspace = true diff --git a/libs/consumption_metrics/Cargo.toml b/libs/consumption_metrics/Cargo.toml index 3f290821c2..a40b74b952 100644 --- a/libs/consumption_metrics/Cargo.toml +++ b/libs/consumption_metrics/Cargo.toml @@ -6,10 +6,8 @@ license = "Apache-2.0" [dependencies] anyhow.workspace = true -chrono.workspace = true +chrono = { workspace = true, features = ["serde"] } rand.workspace = true serde.workspace = true serde_with.workspace = true utils.workspace = true - -workspace_hack.workspace = true diff --git a/libs/desim/Cargo.toml b/libs/desim/Cargo.toml index 6f442d8243..0c4be90267 100644 --- a/libs/desim/Cargo.toml +++ b/libs/desim/Cargo.toml @@ -14,5 +14,3 @@ parking_lot.workspace = true hex.workspace = true scopeguard.workspace = true smallvec = { workspace = true, features = ["write"] } - -workspace_hack.workspace = true diff --git a/libs/metrics/Cargo.toml b/libs/metrics/Cargo.toml index 0bd804051c..f87e7b8e3a 100644 --- a/libs/metrics/Cargo.toml +++ b/libs/metrics/Cargo.toml @@ -12,8 +12,6 @@ chrono.workspace = true twox-hash.workspace = true measured.workspace = true -workspace_hack.workspace = true - [target.'cfg(target_os = "linux")'.dependencies] procfs.workspace = true measured-process.workspace = true diff --git a/libs/pageserver_api/Cargo.toml b/libs/pageserver_api/Cargo.toml index 3bba89c76d..cb28359ac3 100644 --- a/libs/pageserver_api/Cargo.toml +++ b/libs/pageserver_api/Cargo.toml @@ -21,11 +21,9 @@ hex.workspace = true humantime.workspace = true thiserror.workspace = true humantime-serde.workspace = true -chrono.workspace = true +chrono = { workspace = true, features = ["serde"] } itertools.workspace = true -workspace_hack.workspace = true - [dev-dependencies] bincode.workspace = true rand.workspace = true diff --git a/libs/pageserver_api/src/controller_api.rs b/libs/pageserver_api/src/controller_api.rs index a50707a1b8..a9a57d77ce 100644 --- a/libs/pageserver_api/src/controller_api.rs +++ b/libs/pageserver_api/src/controller_api.rs @@ -8,6 +8,7 @@ use std::time::{Duration, Instant}; use serde::{Deserialize, Serialize}; use utils::id::{NodeId, TenantId}; +use crate::models::PageserverUtilization; use crate::{ models::{ShardParameters, TenantConfig}, shard::{ShardStripeSize, TenantShardId}, @@ -140,23 +141,11 @@ pub struct TenantShardMigrateRequest { pub node_id: NodeId, } -/// Utilisation score indicating how good a candidate a pageserver -/// is for scheduling the next tenant. See [`crate::models::PageserverUtilization`]. -/// Lower values are better. 
-#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Debug)] -pub struct UtilizationScore(pub u64); - -impl UtilizationScore { - pub fn worst() -> Self { - UtilizationScore(u64::MAX) - } -} - -#[derive(Serialize, Clone, Copy, Debug)] +#[derive(Serialize, Clone, Debug)] #[serde(into = "NodeAvailabilityWrapper")] pub enum NodeAvailability { // Normal, happy state - Active(UtilizationScore), + Active(PageserverUtilization), // Node is warming up, but we expect it to become available soon. Covers // the time span between the re-attach response being composed on the storage controller // and the first successful heartbeat after the processing of the re-attach response @@ -195,7 +184,9 @@ impl From for NodeAvailability { match val { // Assume the worst utilisation score to begin with. It will later be updated by // the heartbeats. - NodeAvailabilityWrapper::Active => NodeAvailability::Active(UtilizationScore::worst()), + NodeAvailabilityWrapper::Active => { + NodeAvailability::Active(PageserverUtilization::full()) + } NodeAvailabilityWrapper::WarmingUp => NodeAvailability::WarmingUp(Instant::now()), NodeAvailabilityWrapper::Offline => NodeAvailability::Offline, } diff --git a/libs/pageserver_api/src/key.rs b/libs/pageserver_api/src/key.rs index 2fdd7de38f..77da58d63e 100644 --- a/libs/pageserver_api/src/key.rs +++ b/libs/pageserver_api/src/key.rs @@ -236,6 +236,15 @@ impl Key { field5: u8::MAX, field6: u32::MAX, }; + /// A key slightly smaller than [`Key::MAX`] for use in layer key ranges to avoid them to be confused with L0 layers + pub const NON_L0_MAX: Key = Key { + field1: u8::MAX, + field2: u32::MAX, + field3: u32::MAX, + field4: u32::MAX, + field5: u8::MAX, + field6: u32::MAX - 1, + }; pub fn from_hex(s: &str) -> Result { if s.len() != 36 { diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index ab4adfbebe..d39ac75707 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -348,7 +348,7 @@ impl AuxFilePolicy { /// If a tenant writes aux files without setting `switch_aux_policy`, this value will be used. pub fn default_tenant_config() -> Self { - Self::V1 + Self::V2 } } @@ -718,6 +718,7 @@ pub struct TimelineInfo { pub pg_version: u32, pub state: TimelineState, + pub is_archived: bool, pub walreceiver_status: String, @@ -1062,7 +1063,7 @@ impl TryFrom for PagestreamBeMessageTag { } } -// In the V2 protocol version, a GetPage request contains two LSN values: +// A GetPage request contains two LSN values: // // request_lsn: Get the page version at this point in time. Lsn::Max is a special value that means // "get the latest version present". It's used by the primary server, which knows that no one else @@ -1075,7 +1076,7 @@ impl TryFrom for PagestreamBeMessageTag { // passing an earlier LSN can speed up the request, by allowing the pageserver to process the // request without waiting for 'request_lsn' to arrive. // -// The legacy V1 interface contained only one LSN, and a boolean 'latest' flag. The V1 interface was +// The now-defunct V1 interface contained only one LSN, and a boolean 'latest' flag. The V1 interface was // sufficient for the primary; the 'lsn' was equivalent to the 'not_modified_since' value, and // 'latest' was set to true. 
The V2 interface was added because there was no correct way for a // standby to request a page at a particular non-latest LSN, and also include the @@ -1083,15 +1084,11 @@ impl TryFrom for PagestreamBeMessageTag { // request, if the standby knows that the page hasn't been modified since, and risk getting an error // if that LSN has fallen behind the GC horizon, or requesting the current replay LSN, which could // require the pageserver unnecessarily to wait for the WAL to arrive up to that point. The new V2 -// interface allows sending both LSNs, and let the pageserver do the right thing. There is no +// interface allows sending both LSNs, and let the pageserver do the right thing. There was no // difference in the responses between V1 and V2. // -// The Request structs below reflect the V2 interface. If V1 is used, the parse function -// maps the old format requests to the new format. -// #[derive(Clone, Copy)] pub enum PagestreamProtocolVersion { - V1, V2, } @@ -1230,36 +1227,17 @@ impl PagestreamFeMessage { bytes.into() } - pub fn parse( - body: &mut R, - protocol_version: PagestreamProtocolVersion, - ) -> anyhow::Result { + pub fn parse(body: &mut R) -> anyhow::Result { // these correspond to the NeonMessageTag enum in pagestore_client.h // // TODO: consider using protobuf or serde bincode for less error prone // serialization. let msg_tag = body.read_u8()?; - let (request_lsn, not_modified_since) = match protocol_version { - PagestreamProtocolVersion::V2 => ( - Lsn::from(body.read_u64::()?), - Lsn::from(body.read_u64::()?), - ), - PagestreamProtocolVersion::V1 => { - // In the old protocol, each message starts with a boolean 'latest' flag, - // followed by 'lsn'. Convert that to the two LSNs, 'request_lsn' and - // 'not_modified_since', used in the new protocol version. - let latest = body.read_u8()? != 0; - let request_lsn = Lsn::from(body.read_u64::()?); - if latest { - (Lsn::MAX, request_lsn) // get latest version - } else { - (request_lsn, request_lsn) // get version at specified LSN - } - } - }; + // these two fields are the same for every request type + let request_lsn = Lsn::from(body.read_u64::()?); + let not_modified_since = Lsn::from(body.read_u64::()?); - // The rest of the messages are the same between V1 and V2 match msg_tag { 0 => Ok(PagestreamFeMessage::Exists(PagestreamExistsRequest { request_lsn, @@ -1467,9 +1445,7 @@ mod tests { ]; for msg in messages { let bytes = msg.serialize(); - let reconstructed = - PagestreamFeMessage::parse(&mut bytes.reader(), PagestreamProtocolVersion::V2) - .unwrap(); + let reconstructed = PagestreamFeMessage::parse(&mut bytes.reader()).unwrap(); assert!(msg == reconstructed); } } diff --git a/libs/pageserver_api/src/models/utilization.rs b/libs/pageserver_api/src/models/utilization.rs index 0fec221276..844a0cda5d 100644 --- a/libs/pageserver_api/src/models/utilization.rs +++ b/libs/pageserver_api/src/models/utilization.rs @@ -38,7 +38,7 @@ pub struct PageserverUtilization { pub max_shard_count: u32, /// Cached result of [`Self::score`] - pub utilization_score: u64, + pub utilization_score: Option, /// When was this snapshot captured, pageserver local time. /// @@ -50,6 +50,8 @@ fn unity_percent() -> Percent { Percent::new(0).unwrap() } +pub type RawScore = u64; + impl PageserverUtilization { const UTILIZATION_FULL: u64 = 1000000; @@ -62,7 +64,7 @@ impl PageserverUtilization { /// - Negative values are forbidden /// - Values over UTILIZATION_FULL indicate an overloaded node, which may show degraded performance due to /// layer eviction. 
- pub fn score(&self) -> u64 { + pub fn score(&self) -> RawScore { let disk_usable_capacity = ((self.disk_usage_bytes + self.free_space_bytes) * self.disk_usable_pct.get() as u64) / 100; @@ -74,8 +76,30 @@ impl PageserverUtilization { std::cmp::max(disk_utilization_score, shard_utilization_score) } - pub fn refresh_score(&mut self) { - self.utilization_score = self.score(); + pub fn cached_score(&mut self) -> RawScore { + match self.utilization_score { + None => { + let s = self.score(); + self.utilization_score = Some(s); + s + } + Some(s) => s, + } + } + + /// If a node is currently hosting more work than it can comfortably handle. This does not indicate that + /// it will fail, but it is a strong signal that more work should not be added unless there is no alternative. + pub fn is_overloaded(score: RawScore) -> bool { + score >= Self::UTILIZATION_FULL + } + + pub fn adjust_shard_count_max(&mut self, shard_count: u32) { + if self.shard_count < shard_count { + self.shard_count = shard_count; + + // Dirty cache: this will be calculated next time someone retrives the score + self.utilization_score = None; + } } /// A utilization structure that has a full utilization score: use this as a placeholder when @@ -88,7 +112,38 @@ impl PageserverUtilization { disk_usable_pct: Percent::new(100).unwrap(), shard_count: 1, max_shard_count: 1, - utilization_score: Self::UTILIZATION_FULL, + utilization_score: Some(Self::UTILIZATION_FULL), + captured_at: serde_system_time::SystemTime(SystemTime::now()), + } + } +} + +/// Test helper +pub mod test_utilization { + use super::PageserverUtilization; + use std::time::SystemTime; + use utils::{ + serde_percent::Percent, + serde_system_time::{self}, + }; + + // Parameters of the imaginary node used for test utilization instances + const TEST_DISK_SIZE: u64 = 1024 * 1024 * 1024 * 1024; + const TEST_SHARDS_MAX: u32 = 1000; + + /// Unit test helper. Unconditionally compiled because cfg(test) doesn't carry across crates. Do + /// not abuse this function from non-test code. + /// + /// Emulates a node with a 1000 shard limit and a 1TB disk. 
+ pub fn simple(shard_count: u32, disk_wanted_bytes: u64) -> PageserverUtilization { + PageserverUtilization { + disk_usage_bytes: disk_wanted_bytes, + free_space_bytes: TEST_DISK_SIZE - std::cmp::min(disk_wanted_bytes, TEST_DISK_SIZE), + disk_wanted_bytes, + disk_usable_pct: Percent::new(100).unwrap(), + shard_count, + max_shard_count: TEST_SHARDS_MAX, + utilization_score: None, captured_at: serde_system_time::SystemTime(SystemTime::now()), } } @@ -120,7 +175,7 @@ mod tests { disk_usage_bytes: u64::MAX, free_space_bytes: 0, disk_wanted_bytes: u64::MAX, - utilization_score: 13, + utilization_score: Some(13), disk_usable_pct: Percent::new(90).unwrap(), shard_count: 100, max_shard_count: 200, diff --git a/libs/postgres_backend/Cargo.toml b/libs/postgres_backend/Cargo.toml index c7611b9f21..f6854328fc 100644 --- a/libs/postgres_backend/Cargo.toml +++ b/libs/postgres_backend/Cargo.toml @@ -18,7 +18,6 @@ tokio-rustls.workspace = true tracing.workspace = true pq_proto.workspace = true -workspace_hack.workspace = true [dev-dependencies] once_cell.workspace = true diff --git a/libs/postgres_connection/Cargo.toml b/libs/postgres_connection/Cargo.toml index fbfea80ae2..19027d13ff 100644 --- a/libs/postgres_connection/Cargo.toml +++ b/libs/postgres_connection/Cargo.toml @@ -11,7 +11,5 @@ postgres.workspace = true tokio-postgres.workspace = true url.workspace = true -workspace_hack.workspace = true - [dev-dependencies] once_cell.workspace = true diff --git a/libs/postgres_ffi/Cargo.toml b/libs/postgres_ffi/Cargo.toml index 86e72f6bdd..ee69878f69 100644 --- a/libs/postgres_ffi/Cargo.toml +++ b/libs/postgres_ffi/Cargo.toml @@ -19,8 +19,6 @@ thiserror.workspace = true serde.workspace = true utils.workspace = true -workspace_hack.workspace = true - [dev-dependencies] env_logger.workspace = true postgres.workspace = true diff --git a/libs/postgres_ffi/wal_craft/Cargo.toml b/libs/postgres_ffi/wal_craft/Cargo.toml index 0edc642402..29dd01a936 100644 --- a/libs/postgres_ffi/wal_craft/Cargo.toml +++ b/libs/postgres_ffi/wal_craft/Cargo.toml @@ -14,8 +14,6 @@ postgres.workspace = true postgres_ffi.workspace = true camino-tempfile.workspace = true -workspace_hack.workspace = true - [dev-dependencies] regex.workspace = true utils.workspace = true diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index 8afabe670e..66bbe03ebc 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -11,9 +11,7 @@ itertools.workspace = true pin-project-lite.workspace = true postgres-protocol.workspace = true rand.workspace = true -tokio.workspace = true +tokio = { workspace = true, features = ["io-util"] } tracing.workspace = true thiserror.workspace = true serde.workspace = true - -workspace_hack.workspace = true diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 414bce1b26..02adee058f 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -32,7 +32,7 @@ scopeguard.workspace = true metrics.workspace = true utils.workspace = true pin-project-lite.workspace = true -workspace_hack.workspace = true + azure_core.workspace = true azure_identity.workspace = true azure_storage.workspace = true @@ -46,3 +46,4 @@ sync_wrapper = { workspace = true, features = ["futures"] } camino-tempfile.workspace = true test-context.workspace = true rand.workspace = true +tokio = { workspace = true, features = ["test-util"] } diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs index 3c77d5a227..cb7479f6cd 100644 --- 
a/libs/remote_storage/src/azure_blob.rs +++ b/libs/remote_storage/src/azure_blob.rs @@ -383,6 +383,48 @@ impl RemoteStorage for AzureBlobStorage { } } + async fn head_object( + &self, + key: &RemotePath, + cancel: &CancellationToken, + ) -> Result { + let kind = RequestKind::Head; + let _permit = self.permit(kind, cancel).await?; + + let started_at = start_measuring_requests(kind); + + let blob_client = self.client.blob_client(self.relative_path_to_name(key)); + let properties_future = blob_client.get_properties().into_future(); + + let properties_future = tokio::time::timeout(self.timeout, properties_future); + + let res = tokio::select! { + res = properties_future => res, + _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()), + }; + + if let Ok(inner) = &res { + // do not incl. timeouts as errors in metrics but cancellations + let started_at = ScopeGuard::into_inner(started_at); + crate::metrics::BUCKET_METRICS + .req_seconds + .observe_elapsed(kind, inner, started_at); + } + + let data = match res { + Ok(Ok(data)) => Ok(data), + Ok(Err(sdk)) => Err(to_download_error(sdk)), + Err(_timeout) => Err(DownloadError::Timeout), + }?; + + let properties = data.blob.properties; + Ok(ListingObject { + key: key.to_owned(), + last_modified: SystemTime::from(properties.last_modified), + size: properties.content_length, + }) + } + async fn upload( &self, from: impl Stream> + Send + Sync + 'static, diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 2c9e298f79..cc1d3e0ae4 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -150,7 +150,7 @@ pub enum ListingMode { NoDelimiter, } -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Clone)] pub struct ListingObject { pub key: RemotePath, pub last_modified: SystemTime, @@ -215,6 +215,13 @@ pub trait RemoteStorage: Send + Sync + 'static { Ok(combined) } + /// Obtain metadata information about an object. + async fn head_object( + &self, + key: &RemotePath, + cancel: &CancellationToken, + ) -> Result; + /// Streams the local file contents into remote into the remote storage entry. /// /// If the operation fails because of timeout or cancellation, the root cause of the error will be @@ -363,6 +370,20 @@ impl GenericRemoteStorage> { } } + // See [`RemoteStorage::head_object`]. 
+ pub async fn head_object( + &self, + key: &RemotePath, + cancel: &CancellationToken, + ) -> Result { + match self { + Self::LocalFs(s) => s.head_object(key, cancel).await, + Self::AwsS3(s) => s.head_object(key, cancel).await, + Self::AzureBlob(s) => s.head_object(key, cancel).await, + Self::Unreliable(s) => s.head_object(key, cancel).await, + } + } + /// See [`RemoteStorage::upload`] pub async fn upload( &self, @@ -598,6 +619,7 @@ impl ConcurrencyLimiter { RequestKind::Delete => &self.write, RequestKind::Copy => &self.write, RequestKind::TimeTravel => &self.write, + RequestKind::Head => &self.read, } } diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index 99b4aa4061..c3ef18cab1 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -445,6 +445,20 @@ impl RemoteStorage for LocalFs { } } + async fn head_object( + &self, + key: &RemotePath, + _cancel: &CancellationToken, + ) -> Result { + let target_file_path = key.with_base(&self.storage_root); + let metadata = file_metadata(&target_file_path).await?; + Ok(ListingObject { + key: key.clone(), + last_modified: metadata.modified()?, + size: metadata.len(), + }) + } + async fn upload( &self, data: impl Stream> + Send + Sync, diff --git a/libs/remote_storage/src/metrics.rs b/libs/remote_storage/src/metrics.rs index bbb51590f3..f1aa4c433b 100644 --- a/libs/remote_storage/src/metrics.rs +++ b/libs/remote_storage/src/metrics.rs @@ -13,6 +13,7 @@ pub(crate) enum RequestKind { List = 3, Copy = 4, TimeTravel = 5, + Head = 6, } use scopeguard::ScopeGuard; @@ -27,6 +28,7 @@ impl RequestKind { List => "list_objects", Copy => "copy_object", TimeTravel => "time_travel_recover", + Head => "head_object", } } const fn as_index(&self) -> usize { @@ -34,7 +36,8 @@ impl RequestKind { } } -pub(crate) struct RequestTyped([C; 6]); +const REQUEST_KIND_COUNT: usize = 7; +pub(crate) struct RequestTyped([C; REQUEST_KIND_COUNT]); impl RequestTyped { pub(crate) fn get(&self, kind: RequestKind) -> &C { @@ -43,8 +46,8 @@ impl RequestTyped { fn build_with(mut f: impl FnMut(RequestKind) -> C) -> Self { use RequestKind::*; - let mut it = [Get, Put, Delete, List, Copy, TimeTravel].into_iter(); - let arr = std::array::from_fn::(|index| { + let mut it = [Get, Put, Delete, List, Copy, TimeTravel, Head].into_iter(); + let arr = std::array::from_fn::(|index| { let next = it.next().unwrap(); assert_eq!(index, next.as_index()); f(next) diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs index 1f25da813d..11f6598cbf 100644 --- a/libs/remote_storage/src/s3_bucket.rs +++ b/libs/remote_storage/src/s3_bucket.rs @@ -23,7 +23,7 @@ use aws_config::{ use aws_sdk_s3::{ config::{AsyncSleep, IdentityCache, Region, SharedAsyncSleep}, error::SdkError, - operation::get_object::GetObjectError, + operation::{get_object::GetObjectError, head_object::HeadObjectError}, types::{Delete, DeleteMarkerEntry, ObjectIdentifier, ObjectVersion, StorageClass}, Client, }; @@ -604,6 +604,78 @@ impl RemoteStorage for S3Bucket { } } + async fn head_object( + &self, + key: &RemotePath, + cancel: &CancellationToken, + ) -> Result { + let kind = RequestKind::Head; + let _permit = self.permit(kind, cancel).await?; + + let started_at = start_measuring_requests(kind); + + let head_future = self + .client + .head_object() + .bucket(self.bucket_name()) + .key(self.relative_path_to_s3_object(key)) + .send(); + + let head_future = tokio::time::timeout(self.timeout, head_future); + + let res = tokio::select! 
{ + res = head_future => res, + _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()), + }; + + let res = res.map_err(|_e| DownloadError::Timeout)?; + + // do not incl. timeouts as errors in metrics but cancellations + let started_at = ScopeGuard::into_inner(started_at); + crate::metrics::BUCKET_METRICS + .req_seconds + .observe_elapsed(kind, &res, started_at); + + let data = match res { + Ok(object_output) => object_output, + Err(SdkError::ServiceError(e)) if matches!(e.err(), HeadObjectError::NotFound(_)) => { + // Count this in the AttemptOutcome::Ok bucket, because 404 is not + // an error: we expect to sometimes fetch an object and find it missing, + // e.g. when probing for timeline indices. + crate::metrics::BUCKET_METRICS.req_seconds.observe_elapsed( + kind, + AttemptOutcome::Ok, + started_at, + ); + return Err(DownloadError::NotFound); + } + Err(e) => { + crate::metrics::BUCKET_METRICS.req_seconds.observe_elapsed( + kind, + AttemptOutcome::Err, + started_at, + ); + + return Err(DownloadError::Other( + anyhow::Error::new(e).context("s3 head object"), + )); + } + }; + + let (Some(last_modified), Some(size)) = (data.last_modified, data.content_length) else { + return Err(DownloadError::Other(anyhow!( + "head_object doesn't contain last_modified or content_length" + )))?; + }; + Ok(ListingObject { + key: key.to_owned(), + last_modified: SystemTime::try_from(last_modified).map_err(|e| { + DownloadError::Other(anyhow!("can't convert time '{last_modified}': {e}")) + })?, + size: size as u64, + }) + } + async fn upload( &self, from: impl Stream> + Send + Sync + 'static, diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index 13f873dcdb..c7eb634af3 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -30,6 +30,7 @@ pub struct UnreliableWrapper { #[derive(Debug, Hash, Eq, PartialEq)] enum RemoteOp { ListPrefixes(Option), + HeadObject(RemotePath), Upload(RemotePath), Download(RemotePath), Delete(RemotePath), @@ -137,6 +138,16 @@ impl RemoteStorage for UnreliableWrapper { self.inner.list(prefix, mode, max_keys, cancel).await } + async fn head_object( + &self, + key: &RemotePath, + cancel: &CancellationToken, + ) -> Result { + self.attempt(RemoteOp::HeadObject(key.clone())) + .map_err(DownloadError::Other)?; + self.inner.head_object(key, cancel).await + } + async fn upload( &self, data: impl Stream> + Send + Sync + 'static, diff --git a/libs/safekeeper_api/Cargo.toml b/libs/safekeeper_api/Cargo.toml index 327d98ee77..e1f4bcca46 100644 --- a/libs/safekeeper_api/Cargo.toml +++ b/libs/safekeeper_api/Cargo.toml @@ -9,5 +9,3 @@ serde.workspace = true serde_with.workspace = true const_format.workspace = true utils.workspace = true - -workspace_hack.workspace = true diff --git a/libs/tenant_size_model/Cargo.toml b/libs/tenant_size_model/Cargo.toml index 15e78932a8..8aa3c54f62 100644 --- a/libs/tenant_size_model/Cargo.toml +++ b/libs/tenant_size_model/Cargo.toml @@ -9,5 +9,3 @@ license.workspace = true anyhow.workspace = true serde.workspace = true serde_json.workspace = true - -workspace_hack.workspace = true diff --git a/libs/tracing-utils/Cargo.toml b/libs/tracing-utils/Cargo.toml index 512a748124..5ea8db6b42 100644 --- a/libs/tracing-utils/Cargo.toml +++ b/libs/tracing-utils/Cargo.toml @@ -14,5 +14,3 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tracing.workspace = true tracing-opentelemetry.workspace = true tracing-subscriber.workspace = true - 
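As an aside on the `head_object` API added to `libs/remote_storage` above, a minimal caller sketch. The helper name and the plain `GenericRemoteStorage` receiver are assumptions; the method signature, the `ListingObject` fields (`key`, `last_modified`, `size`), and the `DownloadError::NotFound` mapping for missing objects are taken from the hunks above.

```rust
use remote_storage::{DownloadError, GenericRemoteStorage, ListingObject, RemotePath};
use tokio_util::sync::CancellationToken;

/// Hypothetical helper: fetch an object's size without downloading its body.
/// Returns Ok(None) when the object does not exist.
async fn object_size(
    storage: &GenericRemoteStorage,
    key: &RemotePath,
    cancel: &CancellationToken,
) -> Result<Option<u64>, DownloadError> {
    match storage.head_object(key, cancel).await {
        // `ListingObject` carries the key, last-modified time and size.
        Ok(ListingObject { size, .. }) => Ok(Some(size)),
        // The S3 backend surfaces a missing object as `DownloadError::NotFound`.
        Err(DownloadError::NotFound) => Ok(None),
        Err(e) => Err(e),
    }
}
```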
-workspace_hack.workspace = true diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index ec05f849cf..6e593eeac1 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -39,7 +39,7 @@ thiserror.workspace = true tokio.workspace = true tokio-tar.workspace = true tokio-util.workspace = true -toml_edit.workspace = true +toml_edit = { workspace = true, features = ["serde"] } tracing.workspace = true tracing-error.workspace = true tracing-subscriber = { workspace = true, features = ["json", "registry"] } @@ -54,7 +54,6 @@ walkdir.workspace = true pq_proto.workspace = true postgres_connection.workspace = true metrics.workspace = true -workspace_hack.workspace = true const_format.workspace = true @@ -71,6 +70,7 @@ criterion.workspace = true hex-literal.workspace = true camino-tempfile.workspace = true serde_assert.workspace = true +tokio = { workspace = true, features = ["test-util"] } [[bench]] name = "benchmarks" diff --git a/libs/walproposer/Cargo.toml b/libs/walproposer/Cargo.toml index 73aa073c44..2d442dc429 100644 --- a/libs/walproposer/Cargo.toml +++ b/libs/walproposer/Cargo.toml @@ -9,8 +9,6 @@ anyhow.workspace = true utils.workspace = true postgres_ffi.workspace = true -workspace_hack.workspace = true - [build-dependencies] anyhow.workspace = true bindgen.workspace = true diff --git a/libs/walproposer/build.rs b/libs/walproposer/build.rs index 3126b170a4..7bb077062b 100644 --- a/libs/walproposer/build.rs +++ b/libs/walproposer/build.rs @@ -95,6 +95,7 @@ fn main() -> anyhow::Result<()> { .allowlist_var("ERROR") .allowlist_var("FATAL") .allowlist_var("PANIC") + .allowlist_var("PG_VERSION_NUM") .allowlist_var("WPEVENT") .allowlist_var("WL_LATCH_SET") .allowlist_var("WL_SOCKET_READABLE") diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index 37b1e0fa87..ba75171db2 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -282,7 +282,11 @@ mod tests { use std::cell::UnsafeCell; use utils::id::TenantTimelineId; - use crate::{api_bindings::Level, bindings::NeonWALReadResult, walproposer::Wrapper}; + use crate::{ + api_bindings::Level, + bindings::{NeonWALReadResult, PG_VERSION_NUM}, + walproposer::Wrapper, + }; use super::ApiImpl; @@ -489,41 +493,79 @@ mod tests { let (sender, receiver) = sync_channel(1); + // Messages definitions are at walproposer.h + // xxx: it would be better to extract them from safekeeper crate and + // use serialization/deserialization here. 
+ let greeting_tag = (b'g' as u64).to_ne_bytes(); + let proto_version = 2_u32.to_ne_bytes(); + let pg_version: [u8; 4] = PG_VERSION_NUM.to_ne_bytes(); + let proposer_id = [0; 16]; + let system_id = 0_u64.to_ne_bytes(); + let tenant_id = ttid.tenant_id.as_arr(); + let timeline_id = ttid.timeline_id.as_arr(); + let pg_tli = 1_u32.to_ne_bytes(); + let wal_seg_size = 16777216_u32.to_ne_bytes(); + let proposer_greeting = [ + greeting_tag.as_slice(), + proto_version.as_slice(), + pg_version.as_slice(), + proposer_id.as_slice(), + system_id.as_slice(), + tenant_id.as_slice(), + timeline_id.as_slice(), + pg_tli.as_slice(), + wal_seg_size.as_slice(), + ] + .concat(); + + let voting_tag = (b'v' as u64).to_ne_bytes(); + let vote_request_term = 3_u64.to_ne_bytes(); + let proposer_id = [0; 16]; + let vote_request = [ + voting_tag.as_slice(), + vote_request_term.as_slice(), + proposer_id.as_slice(), + ] + .concat(); + + let acceptor_greeting_term = 2_u64.to_ne_bytes(); + let acceptor_greeting_node_id = 1_u64.to_ne_bytes(); + let acceptor_greeting = [ + greeting_tag.as_slice(), + acceptor_greeting_term.as_slice(), + acceptor_greeting_node_id.as_slice(), + ] + .concat(); + + let vote_response_term = 3_u64.to_ne_bytes(); + let vote_given = 1_u64.to_ne_bytes(); + let flush_lsn = 0x539_u64.to_ne_bytes(); + let truncate_lsn = 0x539_u64.to_ne_bytes(); + let th_len = 1_u32.to_ne_bytes(); + let th_term = 2_u64.to_ne_bytes(); + let th_lsn = 0x539_u64.to_ne_bytes(); + let timeline_start_lsn = 0x539_u64.to_ne_bytes(); + let vote_response = [ + voting_tag.as_slice(), + vote_response_term.as_slice(), + vote_given.as_slice(), + flush_lsn.as_slice(), + truncate_lsn.as_slice(), + th_len.as_slice(), + th_term.as_slice(), + th_lsn.as_slice(), + timeline_start_lsn.as_slice(), + ] + .concat(); + let my_impl: Box = Box::new(MockImpl { wait_events: Cell::new(WaitEventsData { sk: std::ptr::null_mut(), event_mask: 0, }), - expected_messages: vec![ - // TODO: When updating Postgres versions, this test will cause - // problems. Postgres version in message needs updating. 
- // - // Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160003, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 }) - vec![ - 103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 76, 143, 54, 6, 60, 108, 110, - 147, 188, 32, 214, 90, 130, 15, 61, 158, 76, 143, 54, 6, 60, 108, 110, 147, - 188, 32, 214, 90, 130, 15, 61, 1, 0, 0, 0, 0, 0, 0, 1, - ], - // VoteRequest(VoteRequest { term: 3 }) - vec![ - 118, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - ], - ], + expected_messages: vec![proposer_greeting, vote_request], expected_ptr: AtomicUsize::new(0), - safekeeper_replies: vec![ - // Greeting(AcceptorGreeting { term: 2, node_id: NodeId(1) }) - vec![ - 103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, - ], - // VoteResponse(VoteResponse { term: 3, vote_given: 1, flush_lsn: 0/539, truncate_lsn: 0/539, term_history: [(2, 0/539)], timeline_start_lsn: 0/539 }) - vec![ - 118, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 57, - 5, 0, 0, 0, 0, 0, 0, 57, 5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, - 0, 57, 5, 0, 0, 0, 0, 0, 0, 57, 5, 0, 0, 0, 0, 0, 0, - ], - ], + safekeeper_replies: vec![acceptor_greeting, vote_response], replies_ptr: AtomicUsize::new(0), sync_channel: sender, shmem: UnsafeCell::new(crate::api_bindings::empty_shmem()), diff --git a/pageserver/benches/bench_ingest.rs b/pageserver/benches/bench_ingest.rs index 0336302de0..bd99f5289d 100644 --- a/pageserver/benches/bench_ingest.rs +++ b/pageserver/benches/bench_ingest.rs @@ -10,6 +10,7 @@ use pageserver::{ page_cache, repository::Value, task_mgr::TaskKind, + tenant::storage_layer::inmemory_layer::SerializedBatch, tenant::storage_layer::InMemoryLayer, virtual_file, }; @@ -67,12 +68,16 @@ async fn ingest( let layer = InMemoryLayer::create(conf, timeline_id, tenant_shard_id, lsn, entered, &ctx).await?; - let data = Value::Image(Bytes::from(vec![0u8; put_size])).ser()?; + let data = Value::Image(Bytes::from(vec![0u8; put_size])); + let data_ser_size = data.serialized_size().unwrap() as usize; let ctx = RequestContext::new( pageserver::task_mgr::TaskKind::WalReceiverConnectionHandler, pageserver::context::DownloadBehavior::Download, ); + const BATCH_SIZE: usize = 16; + let mut batch = Vec::new(); + for i in 0..put_count { lsn += put_size as u64; @@ -95,7 +100,17 @@ async fn ingest( } } - layer.put_value(key.to_compact(), lsn, &data, &ctx).await?; + batch.push((key.to_compact(), lsn, data_ser_size, data.clone())); + if batch.len() >= BATCH_SIZE { + let this_batch = std::mem::take(&mut batch); + let serialized = SerializedBatch::from_values(this_batch); + layer.put_batch(serialized, &ctx).await?; + } + } + if !batch.is_empty() { + let this_batch = std::mem::take(&mut batch); + let serialized = SerializedBatch::from_values(this_batch); + layer.put_batch(serialized, &ctx).await?; } layer.freeze(lsn + 1).await; diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index da0c11d9bf..7d404e50a5 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -126,10 +126,56 @@ fn main() -> anyhow::Result<()> { info!(?conf.virtual_file_direct_io, "starting with virtual_file Direct IO settings"); info!(?conf.compact_level0_phase1_value_access, "starting 
with setting for compact_level0_phase1_value_access"); + // The tenants directory contains all the pageserver local disk state. + // Create if not exists and make sure all the contents are durable before proceeding. + // Ensuring durability eliminates a whole bug class where we come up after an unclean shutdown. + // After unclea shutdown, we don't know if all the filesystem content we can read via syscalls is actually durable or not. + // Examples for that: OOM kill, systemd killing us during shutdown, self abort due to unrecoverable IO error. let tenants_path = conf.tenants_path(); - if !tenants_path.exists() { - utils::crashsafe::create_dir_all(conf.tenants_path()) - .with_context(|| format!("Failed to create tenants root dir at '{tenants_path}'"))?; + { + let open = || { + nix::dir::Dir::open( + tenants_path.as_std_path(), + nix::fcntl::OFlag::O_DIRECTORY | nix::fcntl::OFlag::O_RDONLY, + nix::sys::stat::Mode::empty(), + ) + }; + let dirfd = match open() { + Ok(dirfd) => dirfd, + Err(e) => match e { + nix::errno::Errno::ENOENT => { + utils::crashsafe::create_dir_all(&tenants_path).with_context(|| { + format!("Failed to create tenants root dir at '{tenants_path}'") + })?; + open().context("open tenants dir after creating it")? + } + e => anyhow::bail!(e), + }, + }; + + let started = Instant::now(); + // Linux guarantees durability for syncfs. + // POSIX doesn't have syncfs, and further does not actually guarantee durability of sync(). + #[cfg(target_os = "linux")] + { + use std::os::fd::AsRawFd; + nix::unistd::syncfs(dirfd.as_raw_fd()).context("syncfs")?; + } + #[cfg(target_os = "macos")] + { + // macOS is not a production platform for Neon, don't even bother. + drop(dirfd); + } + #[cfg(not(any(target_os = "linux", target_os = "macos")))] + { + compile_error!("Unsupported OS"); + } + + let elapsed = started.elapsed(); + info!( + elapsed_ms = elapsed.as_millis(), + "made tenant directory contents durable" + ); } // Initialize up failpoints support diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 3ac5ac539f..0ebaf78840 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -50,7 +50,6 @@ pub mod defaults { DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_HTTP_LISTEN_PORT, DEFAULT_PG_LISTEN_ADDR, DEFAULT_PG_LISTEN_PORT, }; - use pageserver_api::models::ImageCompressionAlgorithm; pub use storage_broker::DEFAULT_ENDPOINT as BROKER_DEFAULT_ENDPOINT; pub const DEFAULT_WAIT_LSN_TIMEOUT: &str = "300 s"; @@ -90,8 +89,7 @@ pub mod defaults { pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 128 * 1024; // 128 KiB - pub const DEFAULT_IMAGE_COMPRESSION: ImageCompressionAlgorithm = - ImageCompressionAlgorithm::Disabled; + pub const DEFAULT_IMAGE_COMPRESSION: &str = "zstd(1)"; pub const DEFAULT_VALIDATE_VECTORED_GET: bool = false; @@ -478,7 +476,7 @@ impl PageServerConfigBuilder { max_vectored_read_bytes: Set(MaxVectoredReadBytes( NonZeroUsize::new(DEFAULT_MAX_VECTORED_READ_BYTES).unwrap(), )), - image_compression: Set(DEFAULT_IMAGE_COMPRESSION), + image_compression: Set(DEFAULT_IMAGE_COMPRESSION.parse().unwrap()), ephemeral_bytes_per_memory_kb: Set(DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB), l0_flush: Set(L0FlushConfig::default()), compact_level0_phase1_value_access: Set(CompactL0Phase1ValueAccess::default()), @@ -1065,7 +1063,7 @@ impl PageServerConf { NonZeroUsize::new(defaults::DEFAULT_MAX_VECTORED_READ_BYTES) .expect("Invalid default constant"), ), - image_compression: defaults::DEFAULT_IMAGE_COMPRESSION, + image_compression: 
defaults::DEFAULT_IMAGE_COMPRESSION.parse().unwrap(), ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB, l0_flush: L0FlushConfig::default(), compact_level0_phase1_value_access: CompactL0Phase1ValueAccess::default(), @@ -1305,7 +1303,7 @@ background_task_maximum_delay = '334 s' NonZeroUsize::new(defaults::DEFAULT_MAX_VECTORED_READ_BYTES) .expect("Invalid default constant") ), - image_compression: defaults::DEFAULT_IMAGE_COMPRESSION, + image_compression: defaults::DEFAULT_IMAGE_COMPRESSION.parse().unwrap(), ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB, l0_flush: L0FlushConfig::default(), compact_level0_phase1_value_access: CompactL0Phase1ValueAccess::default(), @@ -1378,7 +1376,7 @@ background_task_maximum_delay = '334 s' NonZeroUsize::new(defaults::DEFAULT_MAX_VECTORED_READ_BYTES) .expect("Invalid default constant") ), - image_compression: defaults::DEFAULT_IMAGE_COMPRESSION, + image_compression: defaults::DEFAULT_IMAGE_COMPRESSION.parse().unwrap(), ephemeral_bytes_per_memory_kb: defaults::DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB, l0_flush: L0FlushConfig::default(), compact_level0_phase1_value_access: CompactL0Phase1ValueAccess::default(), diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index a4da8506d6..cbcc162b32 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -318,6 +318,24 @@ impl From for ApiError { } } +impl From for ApiError { + fn from(value: crate::tenant::TimelineArchivalError) -> Self { + use crate::tenant::TimelineArchivalError::*; + match value { + NotFound => ApiError::NotFound(anyhow::anyhow!("timeline not found").into()), + Timeout => ApiError::Timeout("hit pageserver internal timeout".into()), + HasUnarchivedChildren(children) => ApiError::PreconditionFailed( + format!( + "Cannot archive timeline which has non-archived child timelines: {children:?}" + ) + .into_boxed_str(), + ), + a @ AlreadyInProgress => ApiError::Conflict(a.to_string()), + Other(e) => ApiError::InternalServerError(e), + } + } +} + impl From for ApiError { fn from(value: crate::tenant::mgr::DeleteTimelineError) -> Self { use crate::tenant::mgr::DeleteTimelineError::*; @@ -405,6 +423,8 @@ async fn build_timeline_info_common( let current_logical_size = timeline.get_current_logical_size(logical_size_task_priority, ctx); let current_physical_size = Some(timeline.layer_size_sum().await); let state = timeline.current_state(); + // Report is_archived = false if the timeline is still loading + let is_archived = timeline.is_archived().unwrap_or(false); let remote_consistent_lsn_projected = timeline .get_remote_consistent_lsn_projected() .unwrap_or(Lsn(0)); @@ -445,6 +465,7 @@ async fn build_timeline_info_common( pg_version: timeline.pg_version, state, + is_archived, walreceiver_status, @@ -686,9 +707,7 @@ async fn timeline_archival_config_handler( tenant .apply_timeline_archival_config(timeline_id, request_data.state) - .await - .context("applying archival config") - .map_err(ApiError::InternalServerError)?; + .await?; Ok::<_, ApiError>(()) } .instrument(info_span!("timeline_archival_config", @@ -1706,11 +1725,6 @@ async fn timeline_compact_handler( flags |= CompactFlags::ForceImageLayerCreation; } if Some(true) == parse_query_param::<_, bool>(&request, "enhanced_gc_bottom_most_compaction")? 
{ - if !cfg!(feature = "testing") { - return Err(ApiError::InternalServerError(anyhow!( - "enhanced_gc_bottom_most_compaction is only available in testing mode" - ))); - } flags |= CompactFlags::EnhancedGcBottomMostCompaction; } let wait_until_uploaded = @@ -2942,7 +2956,7 @@ pub fn make_router( ) .put( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/compact", - |r| testing_api_handler("run timeline compaction", r, timeline_compact_handler), + |r| api_handler(r, timeline_compact_handler), ) .put( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/checkpoint", diff --git a/pageserver/src/l0_flush.rs b/pageserver/src/l0_flush.rs index 10187f2ba3..313a7961a6 100644 --- a/pageserver/src/l0_flush.rs +++ b/pageserver/src/l0_flush.rs @@ -1,15 +1,10 @@ use std::{num::NonZeroUsize, sync::Arc}; -use crate::tenant::ephemeral_file; - #[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize)] #[serde(tag = "mode", rename_all = "kebab-case", deny_unknown_fields)] pub enum L0FlushConfig { - PageCached, #[serde(rename_all = "snake_case")] - Direct { - max_concurrency: NonZeroUsize, - }, + Direct { max_concurrency: NonZeroUsize }, } impl Default for L0FlushConfig { @@ -25,14 +20,12 @@ impl Default for L0FlushConfig { pub struct L0FlushGlobalState(Arc); pub enum Inner { - PageCached, Direct { semaphore: tokio::sync::Semaphore }, } impl L0FlushGlobalState { pub fn new(config: L0FlushConfig) -> Self { match config { - L0FlushConfig::PageCached => Self(Arc::new(Inner::PageCached)), L0FlushConfig::Direct { max_concurrency } => { let semaphore = tokio::sync::Semaphore::new(max_concurrency.get()); Self(Arc::new(Inner::Direct { semaphore })) @@ -44,13 +37,3 @@ impl L0FlushGlobalState { &self.0 } } - -impl L0FlushConfig { - pub(crate) fn prewarm_on_write(&self) -> ephemeral_file::PrewarmPageCacheOnWrite { - use L0FlushConfig::*; - match self { - PageCached => ephemeral_file::PrewarmPageCacheOnWrite::Yes, - Direct { .. } => ephemeral_file::PrewarmPageCacheOnWrite::No, - } - } -} diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index 5aee13cfc6..dbfc9f3544 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -49,7 +49,7 @@ use tracing::{info, info_span}; /// backwards-compatible changes to the metadata format. pub const STORAGE_FORMAT_VERSION: u16 = 3; -pub const DEFAULT_PG_VERSION: u32 = 15; +pub const DEFAULT_PG_VERSION: u32 = 16; // Magic constants used to identify different kinds of files pub const IMAGE_FILE_MAGIC: u16 = 0x5A60; @@ -88,6 +88,8 @@ pub async fn shutdown_pageserver( ) { use std::time::Duration; + let started_at = std::time::Instant::now(); + // If the orderly shutdown below takes too long, we still want to make // sure that all walredo processes are killed and wait()ed on by us, not systemd. 
// @@ -241,7 +243,10 @@ pub async fn shutdown_pageserver( walredo_extraordinary_shutdown_thread.join().unwrap(); info!("walredo_extraordinary_shutdown_thread done"); - info!("Shut down successfully completed"); + info!( + elapsed_ms = started_at.elapsed().as_millis(), + "Shut down successfully completed" + ); std::process::exit(exit_code); } diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index cd2cd43f27..c4011d593c 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -1552,7 +1552,6 @@ pub(crate) static LIVE_CONNECTIONS: Lazy = Lazy::new(|| { #[derive(Clone, Copy, enum_map::Enum, IntoStaticStr)] pub(crate) enum ComputeCommandKind { PageStreamV2, - PageStream, Basebackup, Fullbackup, LeaseLsn, @@ -1803,6 +1802,23 @@ pub(crate) static SECONDARY_RESIDENT_PHYSICAL_SIZE: Lazy = Lazy::n .expect("failed to define a metric") }); +pub(crate) static NODE_UTILIZATION_SCORE: Lazy = Lazy::new(|| { + register_uint_gauge!( + "pageserver_utilization_score", + "The utilization score we report to the storage controller for scheduling, where 0 is empty, 1000000 is full, and anything above is considered overloaded", + ) + .expect("failed to define a metric") +}); + +pub(crate) static SECONDARY_HEATMAP_TOTAL_SIZE: Lazy = Lazy::new(|| { + register_uint_gauge_vec!( + "pageserver_secondary_heatmap_total_size", + "The total size in bytes of all layers in the most recently downloaded heatmap.", + &["tenant_id", "shard_id"] + ) + .expect("failed to define a metric") +}); + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum RemoteOpKind { Upload, @@ -1853,16 +1869,64 @@ pub(crate) static TENANT_TASK_EVENTS: Lazy = Lazy::new(|| { .expect("Failed to register tenant_task_events metric") }); -pub(crate) static BACKGROUND_LOOP_SEMAPHORE_WAIT_GAUGE: Lazy = Lazy::new(|| { - register_int_counter_pair_vec!( - "pageserver_background_loop_semaphore_wait_start_count", - "Counter for background loop concurrency-limiting semaphore acquire calls started", - "pageserver_background_loop_semaphore_wait_finish_count", - "Counter for background loop concurrency-limiting semaphore acquire calls finished", - &["task"], - ) - .unwrap() -}); +pub struct BackgroundLoopSemaphoreMetrics { + counters: EnumMap, + durations: EnumMap, +} + +pub(crate) static BACKGROUND_LOOP_SEMAPHORE: Lazy = Lazy::new( + || { + let counters = register_int_counter_pair_vec!( + "pageserver_background_loop_semaphore_wait_start_count", + "Counter for background loop concurrency-limiting semaphore acquire calls started", + "pageserver_background_loop_semaphore_wait_finish_count", + "Counter for background loop concurrency-limiting semaphore acquire calls finished", + &["task"], + ) + .unwrap(); + + let durations = register_counter_vec!( + "pageserver_background_loop_semaphore_wait_duration_seconds", + "Sum of wall clock time spent waiting on the background loop concurrency-limiting semaphore acquire calls", + &["task"], + ) + .unwrap(); + + BackgroundLoopSemaphoreMetrics { + counters: enum_map::EnumMap::from_array(std::array::from_fn(|i| { + let kind = ::from_usize(i); + counters.with_label_values(&[kind.into()]) + })), + durations: enum_map::EnumMap::from_array(std::array::from_fn(|i| { + let kind = ::from_usize(i); + durations.with_label_values(&[kind.into()]) + })), + } + }, +); + +impl BackgroundLoopSemaphoreMetrics { + pub(crate) fn measure_acquisition(&self, task: BackgroundLoopKind) -> impl Drop + '_ { + struct Record<'a> { + metrics: &'a BackgroundLoopSemaphoreMetrics, + task: BackgroundLoopKind, + 
_counter_guard: metrics::IntCounterPairGuard, + start: Instant, + } + impl Drop for Record<'_> { + fn drop(&mut self) { + let elapsed = self.start.elapsed().as_secs_f64(); + self.metrics.durations[self.task].inc_by(elapsed); + } + } + Record { + metrics: self, + task, + _counter_guard: self.counters[task].guard(), + start: Instant::now(), + } + } +} pub(crate) static BACKGROUND_LOOP_PERIOD_OVERRUN_COUNT: Lazy = Lazy::new(|| { register_int_counter_vec!( @@ -2544,6 +2608,7 @@ use std::time::{Duration, Instant}; use crate::context::{PageContentKind, RequestContext}; use crate::task_mgr::TaskKind; use crate::tenant::mgr::TenantSlot; +use crate::tenant::tasks::BackgroundLoopKind; /// Maintain a per timeline gauge in addition to the global gauge. pub(crate) struct PerTimelineRemotePhysicalSizeGauge { diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 81294291a9..cb1ab70147 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -557,7 +557,7 @@ impl PageServerHandler { pgb: &mut PostgresBackend, tenant_id: TenantId, timeline_id: TimelineId, - protocol_version: PagestreamProtocolVersion, + _protocol_version: PagestreamProtocolVersion, ctx: RequestContext, ) -> Result<(), QueryError> where @@ -601,8 +601,7 @@ impl PageServerHandler { fail::fail_point!("ps::handle-pagerequest-message"); // parse request - let neon_fe_msg = - PagestreamFeMessage::parse(&mut copy_data_bytes.reader(), protocol_version)?; + let neon_fe_msg = PagestreamFeMessage::parse(&mut copy_data_bytes.reader())?; // invoke handler function let (handler_result, span) = match neon_fe_msg { @@ -1275,35 +1274,6 @@ where ctx, ) .await?; - } else if let Some(params) = parts.strip_prefix(&["pagestream"]) { - if params.len() != 2 { - return Err(QueryError::Other(anyhow::anyhow!( - "invalid param number for pagestream command" - ))); - } - let tenant_id = TenantId::from_str(params[0]) - .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; - let timeline_id = TimelineId::from_str(params[1]) - .with_context(|| format!("Failed to parse timeline id from {}", params[1]))?; - - tracing::Span::current() - .record("tenant_id", field::display(tenant_id)) - .record("timeline_id", field::display(timeline_id)); - - self.check_permission(Some(tenant_id))?; - - COMPUTE_COMMANDS_COUNTERS - .for_command(ComputeCommandKind::PageStream) - .inc(); - - self.handle_pagerequests( - pgb, - tenant_id, - timeline_id, - PagestreamProtocolVersion::V1, - ctx, - ) - .await?; } else if let Some(params) = parts.strip_prefix(&["basebackup"]) { if params.len() < 2 { return Err(QueryError::Other(anyhow::anyhow!( diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index 4f7eb1a00c..b7110d69b6 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -15,12 +15,11 @@ use crate::{aux_file, repository::*}; use anyhow::{ensure, Context}; use bytes::{Buf, Bytes, BytesMut}; use enum_map::Enum; -use itertools::Itertools; use pageserver_api::key::{ dbdir_key_range, rel_block_to_key, rel_dir_to_key, rel_key_range, rel_size_to_key, relmap_file_key, repl_origin_key, repl_origin_key_range, slru_block_to_key, slru_dir_to_key, slru_segment_key_range, slru_segment_size_to_key, twophase_file_key, twophase_key_range, - AUX_FILES_KEY, CHECKPOINT_KEY, CONTROLFILE_KEY, DBDIR_KEY, TWOPHASEDIR_KEY, + CompactKey, AUX_FILES_KEY, CHECKPOINT_KEY, CONTROLFILE_KEY, DBDIR_KEY, TWOPHASEDIR_KEY, }; use pageserver_api::keyspace::SparseKeySpace; use 
pageserver_api::models::AuxFilePolicy; @@ -37,7 +36,6 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, info, trace, warn}; use utils::bin_ser::DeserializeError; use utils::pausable_failpoint; -use utils::vec_map::{VecMap, VecMapOrdering}; use utils::{bin_ser::BeSer, lsn::Lsn}; /// Max delta records appended to the AUX_FILES_KEY (for aux v1). The write path will write a full image once this threshold is reached. @@ -174,6 +172,7 @@ impl Timeline { pending_deletions: Vec::new(), pending_nblocks: 0, pending_directory_entries: Vec::new(), + pending_bytes: 0, lsn, } } @@ -727,7 +726,17 @@ impl Timeline { ) -> Result, PageReconstructError> { let current_policy = self.last_aux_file_policy.load(); match current_policy { - Some(AuxFilePolicy::V1) | None => self.list_aux_files_v1(lsn, ctx).await, + Some(AuxFilePolicy::V1) => { + warn!("this timeline is using deprecated aux file policy V1 (policy=V1)"); + self.list_aux_files_v1(lsn, ctx).await + } + None => { + let res = self.list_aux_files_v1(lsn, ctx).await?; + if !res.is_empty() { + warn!("this timeline is using deprecated aux file policy V1 (policy=None)"); + } + Ok(res) + } Some(AuxFilePolicy::V2) => self.list_aux_files_v2(lsn, ctx).await, Some(AuxFilePolicy::CrossValidation) => { let v1_result = self.list_aux_files_v1(lsn, ctx).await; @@ -1022,21 +1031,33 @@ pub struct DatadirModification<'a> { // The put-functions add the modifications here, and they are flushed to the // underlying key-value store by the 'finish' function. pending_lsns: Vec, - pending_updates: HashMap>, + pending_updates: HashMap>, pending_deletions: Vec<(Range, Lsn)>, pending_nblocks: i64, /// For special "directory" keys that store key-value maps, track the size of the map /// if it was updated in this modification. pending_directory_entries: Vec<(DirectoryKind, usize)>, + + /// An **approximation** of how large our EphemeralFile write will be when committed. + pending_bytes: usize, } impl<'a> DatadirModification<'a> { + // When a DatadirModification is committed, we do a monolithic serialization of all its contents. WAL records can + // contain multiple pages, so the pageserver's record-based batch size isn't sufficient to bound this allocation: we + // additionally specify a limit on how much payload a DatadirModification may contain before it should be committed. + pub(crate) const MAX_PENDING_BYTES: usize = 8 * 1024 * 1024; + /// Get the current lsn pub(crate) fn get_lsn(&self) -> Lsn { self.lsn } + pub(crate) fn approx_pending_bytes(&self) -> usize { + self.pending_bytes + } + /// Set the current lsn pub(crate) fn set_lsn(&mut self, lsn: Lsn) -> anyhow::Result<()> { ensure!( @@ -1576,6 +1597,7 @@ impl<'a> DatadirModification<'a> { if aux_files_key_v1.is_empty() { None } else { + warn!("this timeline is using deprecated aux file policy V1"); self.tline.do_switch_aux_policy(AuxFilePolicy::V1)?; Some(AuxFilePolicy::V1) } @@ -1769,21 +1791,25 @@ impl<'a> DatadirModification<'a> { // Flush relation and SLRU data blocks, keep metadata. let mut retained_pending_updates = HashMap::<_, Vec<_>>::new(); for (key, values) in self.pending_updates.drain() { - for (lsn, value) in values { + let mut write_batch = Vec::new(); + for (lsn, value_ser_size, value) in values { if key.is_rel_block_key() || key.is_slru_block_key() { // This bails out on first error without modifying pending_updates. // That's Ok, cf this function's doc comment. 
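For context on the MAX_PENDING_BYTES / approx_pending_bytes() additions above: the call site that enforces the cap lives in the WAL ingest path and is not part of this hunk. A rough sketch of the intended usage, assuming the usual DatadirModification commit entry point; the record source and per-record helper below are hypothetical stand-ins:

    // Sketch only: keep a single DatadirModification's pending payload bounded during ingest.
    while let Some(record) = next_record().await {               // hypothetical record source
        apply_record(&mut modification, record)?;                // hypothetical per-record handler
        if modification.approx_pending_bytes() > DatadirModification::MAX_PENDING_BYTES {
            // Commit early so the monolithic serialization stays around the 8 MiB cap.
            modification.commit(ctx).await?;
        }
    }
    modification.commit(ctx).await?;                             // flush whatever remains
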
- writer.put(key, lsn, &value, ctx).await?; + write_batch.push((key.to_compact(), lsn, value_ser_size, value)); } else { - retained_pending_updates - .entry(key) - .or_default() - .push((lsn, value)); + retained_pending_updates.entry(key).or_default().push(( + lsn, + value_ser_size, + value, + )); } } + writer.put_batch(write_batch, ctx).await?; } self.pending_updates = retained_pending_updates; + self.pending_bytes = 0; if pending_nblocks != 0 { writer.update_current_logical_size(pending_nblocks * i64::from(BLCKSZ)); @@ -1809,17 +1835,20 @@ impl<'a> DatadirModification<'a> { self.pending_nblocks = 0; if !self.pending_updates.is_empty() { - // The put_batch call below expects expects the inputs to be sorted by Lsn, - // so we do that first. - let lsn_ordered_batch: VecMap = VecMap::from_iter( - self.pending_updates - .drain() - .map(|(key, vals)| vals.into_iter().map(move |(lsn, val)| (lsn, (key, val)))) - .kmerge_by(|lhs, rhs| lhs.0 < rhs.0), - VecMapOrdering::GreaterOrEqual, - ); + // Ordering: the items in this batch do not need to be in any global order, but values for + // a particular Key must be in Lsn order relative to one another. InMemoryLayer relies on + // this to do efficient updates to its index. + let batch: Vec<(CompactKey, Lsn, usize, Value)> = self + .pending_updates + .drain() + .flat_map(|(key, values)| { + values.into_iter().map(move |(lsn, val_ser_size, value)| { + (key.to_compact(), lsn, val_ser_size, value) + }) + }) + .collect::>(); - writer.put_batch(lsn_ordered_batch, ctx).await?; + writer.put_batch(batch, ctx).await?; } if !self.pending_deletions.is_empty() { @@ -1844,6 +1873,8 @@ impl<'a> DatadirModification<'a> { writer.update_directory_entries_count(kind, count as u64); } + self.pending_bytes = 0; + Ok(()) } @@ -1860,7 +1891,7 @@ impl<'a> DatadirModification<'a> { // Note: we don't check pending_deletions. It is an error to request a // value that has been removed, deletion only avoids leaking storage. 
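The ordering comment a few lines above is the one invariant the new batched path relies on: values may interleave across keys in any order, but each key's LSNs must be non-decreasing. A small illustrative check, not part of the patch and assuming CompactKey's usual Copy/Eq/Hash derives:

    use std::collections::HashMap;

    /// Returns true iff every key's LSNs appear in non-decreasing order within the batch.
    /// No ordering across different keys is required.
    fn batch_respects_per_key_lsn_order(batch: &[(CompactKey, Lsn, usize, Value)]) -> bool {
        let mut last_lsn: HashMap<CompactKey, Lsn> = HashMap::new();
        batch.iter().all(|(key, lsn, _size, _value)| {
            let ordered = last_lsn.get(key).map_or(true, |prev| prev <= lsn);
            last_lsn.insert(*key, *lsn);
            ordered
        })
    }
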
if let Some(values) = self.pending_updates.get(&key) { - if let Some((_, value)) = values.last() { + if let Some((_, _, value)) = values.last() { return if let Value::Image(img) = value { Ok(img.clone()) } else { @@ -1888,13 +1919,17 @@ impl<'a> DatadirModification<'a> { fn put(&mut self, key: Key, val: Value) { let values = self.pending_updates.entry(key).or_default(); // Replace the previous value if it exists at the same lsn - if let Some((last_lsn, last_value)) = values.last_mut() { + if let Some((last_lsn, last_value_ser_size, last_value)) = values.last_mut() { if *last_lsn == self.lsn { + *last_value_ser_size = val.serialized_size().unwrap() as usize; *last_value = val; return; } } - values.push((self.lsn, val)); + + let val_serialized_size = val.serialized_size().unwrap() as usize; + self.pending_bytes += val_serialized_size; + values.push((self.lsn, val_serialized_size, val)); } fn delete(&mut self, key_range: Range) { @@ -2024,7 +2059,7 @@ mod tests { let (tenant, ctx) = harness.load().await; let tline = tenant - .create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx) + .create_empty_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx) .await?; let tline = tline.raw_timeline().unwrap(); diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index 5cd78874c1..ed9e001fd2 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -393,7 +393,7 @@ struct PageServerTask { /// Tasks may optionally be launched for a particular tenant/timeline, enabling /// later cancelling tasks for that tenant/timeline in [`shutdown_tasks`] - tenant_shard_id: Option, + tenant_shard_id: TenantShardId, timeline_id: Option, mutable: Mutex, @@ -405,7 +405,7 @@ struct PageServerTask { pub fn spawn( runtime: &tokio::runtime::Handle, kind: TaskKind, - tenant_shard_id: Option, + tenant_shard_id: TenantShardId, timeline_id: Option, name: &str, future: F, @@ -550,7 +550,7 @@ pub async fn shutdown_tasks( let tasks = TASKS.lock().unwrap(); for task in tasks.values() { if (kind.is_none() || Some(task.kind) == kind) - && (tenant_shard_id.is_none() || task.tenant_shard_id == tenant_shard_id) + && (tenant_shard_id.is_none() || Some(task.tenant_shard_id) == tenant_shard_id) && (timeline_id.is_none() || task.timeline_id == timeline_id) { task.cancel.cancel(); @@ -573,13 +573,8 @@ pub async fn shutdown_tasks( }; if let Some(mut join_handle) = join_handle { if log_all { - if tenant_shard_id.is_none() { - // there are quite few of these - info!(name = task.name, kind = ?task_kind, "stopping global task"); - } else { - // warn to catch these in tests; there shouldn't be any - warn!(name = task.name, tenant_shard_id = ?tenant_shard_id, timeline_id = ?timeline_id, kind = ?task_kind, "stopping left-over"); - } + // warn to catch these in tests; there shouldn't be any + warn!(name = task.name, tenant_shard_id = ?tenant_shard_id, timeline_id = ?timeline_id, kind = ?task_kind, "stopping left-over"); } if tokio::time::timeout(std::time::Duration::from_secs(1), &mut join_handle) .await diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 8ab8d08ce1..0364d521b6 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -501,6 +501,38 @@ impl Debug for DeleteTimelineError { } } +#[derive(thiserror::Error)] +pub enum TimelineArchivalError { + #[error("NotFound")] + NotFound, + + #[error("Timeout")] + Timeout, + + #[error("HasUnarchivedChildren")] + HasUnarchivedChildren(Vec), + + #[error("Timeline archival is already in progress")] + AlreadyInProgress, + 
+ #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl Debug for TimelineArchivalError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NotFound => write!(f, "NotFound"), + Self::Timeout => write!(f, "Timeout"), + Self::HasUnarchivedChildren(c) => { + f.debug_tuple("HasUnarchivedChildren").field(c).finish() + } + Self::AlreadyInProgress => f.debug_tuple("AlreadyInProgress").finish(), + Self::Other(e) => f.debug_tuple("Other").field(e).finish(), + } + } +} + pub enum SetStoppingError { AlreadyStopping(completion::Barrier), Broken, @@ -798,7 +830,7 @@ impl Tenant { task_mgr::spawn( &tokio::runtime::Handle::current(), TaskKind::Attach, - Some(tenant_shard_id), + tenant_shard_id, None, "attach tenant", async move { @@ -1326,24 +1358,50 @@ impl Tenant { &self, timeline_id: TimelineId, state: TimelineArchivalState, - ) -> anyhow::Result<()> { - let timeline = self - .get_timeline(timeline_id, false) - .context("Cannot apply timeline archival config to inexistent timeline")?; + ) -> Result<(), TimelineArchivalError> { + info!("setting timeline archival config"); + let timeline = { + let timelines = self.timelines.lock().unwrap(); + + let timeline = match timelines.get(&timeline_id) { + Some(t) => t, + None => return Err(TimelineArchivalError::NotFound), + }; + + // Ensure that there are no non-archived child timelines + let children: Vec = timelines + .iter() + .filter_map(|(id, entry)| { + if entry.get_ancestor_timeline_id() != Some(timeline_id) { + return None; + } + if entry.is_archived() == Some(true) { + return None; + } + Some(*id) + }) + .collect(); + + if !children.is_empty() && state == TimelineArchivalState::Archived { + return Err(TimelineArchivalError::HasUnarchivedChildren(children)); + } + Arc::clone(timeline) + }; let upload_needed = timeline .remote_client .schedule_index_upload_for_timeline_archival_state(state)?; if upload_needed { + info!("Uploading new state"); const MAX_WAIT: Duration = Duration::from_secs(10); let Ok(v) = tokio::time::timeout(MAX_WAIT, timeline.remote_client.wait_completion()).await else { tracing::warn!("reached timeout for waiting on upload queue"); - bail!("reached timeout for upload queue flush"); + return Err(TimelineArchivalError::Timeout); }; - v?; + v.map_err(|e| TimelineArchivalError::Other(anyhow::anyhow!(e)))?; } Ok(()) } @@ -3741,13 +3799,21 @@ impl Tenant { /// less than this (via eviction and on-demand downloads), but this function enables /// the Tenant to advertise how much storage it would prefer to have to provide fast I/O /// by keeping important things on local disk. + /// + /// This is a heuristic, not a guarantee: tenants that are long-idle will actually use less + /// than they report here, due to layer eviction. Tenants with many active branches may + /// actually use more than they report here. pub(crate) fn local_storage_wanted(&self) -> u64 { - let mut wanted = 0; let timelines = self.timelines.lock().unwrap(); - for timeline in timelines.values() { - wanted += timeline.metrics.visible_physical_size_gauge.get(); - } - wanted + + // Heuristic: we use the max() of the timelines' visible sizes, rather than the sum. This + // reflects the observation that on tenants with multiple large branches, typically only one + // of them is used actively enough to occupy space on disk. 
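A quick worked example of the max-vs-sum heuristic described just above (illustrative, not part of the patch):

    // A tenant with three branches whose visible sizes are 10 GiB, 8 GiB and 2 GiB
    // now advertises 10 GiB (the max); the old sum-based calculation advertised 20 GiB.
    let visible_sizes: [u64; 3] = [10 << 30, 8 << 30, 2 << 30];
    let wanted = visible_sizes.iter().copied().max().unwrap_or(0);
    assert_eq!(wanted, 10u64 << 30);
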
+ timelines + .values() + .map(|t| t.metrics.visible_physical_size_gauge.get()) + .max() + .unwrap_or(0) } } @@ -5932,10 +5998,10 @@ mod tests { .await .unwrap(); - // the default aux file policy to switch is v1 if not set by the admins + // the default aux file policy to switch is v2 if not set by the admins assert_eq!( harness.tenant_conf.switch_aux_file_policy, - AuxFilePolicy::V1 + AuxFilePolicy::default_tenant_config() ); let (tenant, ctx) = harness.load().await; @@ -5979,8 +6045,8 @@ mod tests { ); assert_eq!( tline.last_aux_file_policy.load(), - Some(AuxFilePolicy::V1), - "aux file is written with switch_aux_file_policy unset (which is v1), so we should keep v1" + Some(AuxFilePolicy::V2), + "aux file is written with switch_aux_file_policy unset (which is v2), so we should use v2 there" ); // we can read everything from the storage @@ -6002,8 +6068,8 @@ mod tests { assert_eq!( tline.last_aux_file_policy.load(), - Some(AuxFilePolicy::V1), - "keep v1 storage format when new files are written" + Some(AuxFilePolicy::V2), + "keep v2 storage format when new files are written" ); let files = tline.list_aux_files(lsn, &ctx).await.unwrap(); @@ -6019,7 +6085,7 @@ mod tests { // child copies the last flag even if that is not on remote storage yet assert_eq!(child.get_switch_aux_file_policy(), AuxFilePolicy::V2); - assert_eq!(child.last_aux_file_policy.load(), Some(AuxFilePolicy::V1)); + assert_eq!(child.last_aux_file_policy.load(), Some(AuxFilePolicy::V2)); let files = child.list_aux_files(lsn, &ctx).await.unwrap(); assert_eq!(files.get("pg_logical/mappings/test1"), None); @@ -7005,18 +7071,14 @@ mod tests { vec![ // Image layer at GC horizon PersistentLayerKey { - key_range: { - let mut key = Key::MAX; - key.field6 -= 1; - Key::MIN..key - }, + key_range: Key::MIN..Key::NON_L0_MAX, lsn_range: Lsn(0x30)..Lsn(0x31), is_delta: false }, - // The delta layer that is cut in the middle + // The delta layer covers the full range (with the layer key hack to avoid being recognized as L0) PersistentLayerKey { - key_range: get_key(3)..get_key(4), - lsn_range: Lsn(0x30)..Lsn(0x41), + key_range: Key::MIN..Key::NON_L0_MAX, + lsn_range: Lsn(0x30)..Lsn(0x48), is_delta: true }, // The delta3 layer that should not be picked for the compaction @@ -7996,6 +8058,214 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_simple_bottom_most_compaction_with_retain_lsns_single_key() -> anyhow::Result<()> + { + let harness = + TenantHarness::create("test_simple_bottom_most_compaction_with_retain_lsns_single_key") + .await?; + let (tenant, ctx) = harness.load().await; + + fn get_key(id: u32) -> Key { + // using aux key here b/c they are guaranteed to be inside `collect_keyspace`. 
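The expected_result arrays in the test data that follows encode a simple reconstruction rule: start from the key's image at 0x10 and append every wal_append record at or below the read LSN. For key 1 read at Lsn(0x50) that works out as below (illustrative sketch of the rule, not part of the patch):

    // Key 1 has appends at 0x20, 0x28, 0x30 and 0x38; reading at 0x50 sees all of them,
    // while reading at the GC horizon (0x30) stops after the 0x30 append.
    let mut reconstructed = String::from("value 1@0x10");
    for (lsn, suffix) in [(0x20, "@0x20"), (0x28, "@0x28"), (0x30, "@0x30"), (0x38, "@0x38")] {
        if lsn <= 0x50 {
            reconstructed.push_str(suffix);
        }
    }
    assert_eq!(reconstructed, "value 1@0x10@0x20@0x28@0x30@0x38");
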
+ let mut key = Key::from_hex("620000000033333333444444445500000000").unwrap(); + key.field6 = id; + key + } + + let img_layer = (0..10) + .map(|id| (get_key(id), Bytes::from(format!("value {id}@0x10")))) + .collect_vec(); + + let delta1 = vec![ + ( + get_key(1), + Lsn(0x20), + Value::WalRecord(NeonWalRecord::wal_append("@0x20")), + ), + ( + get_key(1), + Lsn(0x28), + Value::WalRecord(NeonWalRecord::wal_append("@0x28")), + ), + ]; + let delta2 = vec![ + ( + get_key(1), + Lsn(0x30), + Value::WalRecord(NeonWalRecord::wal_append("@0x30")), + ), + ( + get_key(1), + Lsn(0x38), + Value::WalRecord(NeonWalRecord::wal_append("@0x38")), + ), + ]; + let delta3 = vec![ + ( + get_key(8), + Lsn(0x48), + Value::WalRecord(NeonWalRecord::wal_append("@0x48")), + ), + ( + get_key(9), + Lsn(0x48), + Value::WalRecord(NeonWalRecord::wal_append("@0x48")), + ), + ]; + + let tline = tenant + .create_test_timeline_with_layers( + TIMELINE_ID, + Lsn(0x10), + DEFAULT_PG_VERSION, + &ctx, + vec![ + // delta1 and delta 2 only contain a single key but multiple updates + DeltaLayerTestDesc::new_with_inferred_key_range(Lsn(0x10)..Lsn(0x30), delta1), + DeltaLayerTestDesc::new_with_inferred_key_range(Lsn(0x30)..Lsn(0x50), delta2), + DeltaLayerTestDesc::new_with_inferred_key_range(Lsn(0x10)..Lsn(0x50), delta3), + ], // delta layers + vec![(Lsn(0x10), img_layer)], // image layers + Lsn(0x50), + ) + .await?; + { + // Update GC info + let mut guard = tline.gc_info.write().unwrap(); + *guard = GcInfo { + retain_lsns: vec![ + (Lsn(0x10), tline.timeline_id), + (Lsn(0x20), tline.timeline_id), + ], + cutoffs: GcCutoffs { + time: Lsn(0x30), + space: Lsn(0x30), + }, + leases: Default::default(), + within_ancestor_pitr: false, + }; + } + + let expected_result = [ + Bytes::from_static(b"value 0@0x10"), + Bytes::from_static(b"value 1@0x10@0x20@0x28@0x30@0x38"), + Bytes::from_static(b"value 2@0x10"), + Bytes::from_static(b"value 3@0x10"), + Bytes::from_static(b"value 4@0x10"), + Bytes::from_static(b"value 5@0x10"), + Bytes::from_static(b"value 6@0x10"), + Bytes::from_static(b"value 7@0x10"), + Bytes::from_static(b"value 8@0x10@0x48"), + Bytes::from_static(b"value 9@0x10@0x48"), + ]; + + let expected_result_at_gc_horizon = [ + Bytes::from_static(b"value 0@0x10"), + Bytes::from_static(b"value 1@0x10@0x20@0x28@0x30"), + Bytes::from_static(b"value 2@0x10"), + Bytes::from_static(b"value 3@0x10"), + Bytes::from_static(b"value 4@0x10"), + Bytes::from_static(b"value 5@0x10"), + Bytes::from_static(b"value 6@0x10"), + Bytes::from_static(b"value 7@0x10"), + Bytes::from_static(b"value 8@0x10"), + Bytes::from_static(b"value 9@0x10"), + ]; + + let expected_result_at_lsn_20 = [ + Bytes::from_static(b"value 0@0x10"), + Bytes::from_static(b"value 1@0x10@0x20"), + Bytes::from_static(b"value 2@0x10"), + Bytes::from_static(b"value 3@0x10"), + Bytes::from_static(b"value 4@0x10"), + Bytes::from_static(b"value 5@0x10"), + Bytes::from_static(b"value 6@0x10"), + Bytes::from_static(b"value 7@0x10"), + Bytes::from_static(b"value 8@0x10"), + Bytes::from_static(b"value 9@0x10"), + ]; + + let expected_result_at_lsn_10 = [ + Bytes::from_static(b"value 0@0x10"), + Bytes::from_static(b"value 1@0x10"), + Bytes::from_static(b"value 2@0x10"), + Bytes::from_static(b"value 3@0x10"), + Bytes::from_static(b"value 4@0x10"), + Bytes::from_static(b"value 5@0x10"), + Bytes::from_static(b"value 6@0x10"), + Bytes::from_static(b"value 7@0x10"), + Bytes::from_static(b"value 8@0x10"), + Bytes::from_static(b"value 9@0x10"), + ]; + + let verify_result = || async { + let gc_horizon = { + let 
gc_info = tline.gc_info.read().unwrap(); + gc_info.cutoffs.time + }; + for idx in 0..10 { + assert_eq!( + tline + .get(get_key(idx as u32), Lsn(0x50), &ctx) + .await + .unwrap(), + &expected_result[idx] + ); + assert_eq!( + tline + .get(get_key(idx as u32), gc_horizon, &ctx) + .await + .unwrap(), + &expected_result_at_gc_horizon[idx] + ); + assert_eq!( + tline + .get(get_key(idx as u32), Lsn(0x20), &ctx) + .await + .unwrap(), + &expected_result_at_lsn_20[idx] + ); + assert_eq!( + tline + .get(get_key(idx as u32), Lsn(0x10), &ctx) + .await + .unwrap(), + &expected_result_at_lsn_10[idx] + ); + } + }; + + verify_result().await; + + let cancel = CancellationToken::new(); + let mut dryrun_flags = EnumSet::new(); + dryrun_flags.insert(CompactFlags::DryRun); + + tline + .compact_with_gc(&cancel, dryrun_flags, &ctx) + .await + .unwrap(); + // We expect layer map to be the same b/c the dry run flag, but we don't know whether there will be other background jobs + // cleaning things up, and therefore, we don't do sanity checks on the layer map during unit tests. + verify_result().await; + + tline + .compact_with_gc(&cancel, EnumSet::new(), &ctx) + .await + .unwrap(); + verify_result().await; + + // compact again + tline + .compact_with_gc(&cancel, EnumSet::new(), &ctx) + .await + .unwrap(); + verify_result().await; + + Ok(()) + } + #[tokio::test] async fn test_simple_bottom_most_compaction_on_branch() -> anyhow::Result<()> { let harness = TenantHarness::create("test_simple_bottom_most_compaction_on_branch").await?; diff --git a/pageserver/src/tenant/ephemeral_file.rs b/pageserver/src/tenant/ephemeral_file.rs index 770f3ca5f0..44f0fc7ab1 100644 --- a/pageserver/src/tenant/ephemeral_file.rs +++ b/pageserver/src/tenant/ephemeral_file.rs @@ -21,7 +21,6 @@ pub struct EphemeralFile { } mod page_caching; -pub(crate) use page_caching::PrewarmOnWrite as PrewarmPageCacheOnWrite; mod zero_padded_read_write; impl EphemeralFile { @@ -52,12 +51,10 @@ impl EphemeralFile { ) .await?; - let prewarm = conf.l0_flush.prewarm_on_write(); - Ok(EphemeralFile { _tenant_shard_id: tenant_shard_id, _timeline_id: timeline_id, - rw: page_caching::RW::new(file, prewarm, gate_guard), + rw: page_caching::RW::new(file, gate_guard), }) } @@ -82,6 +79,8 @@ impl EphemeralFile { self.rw.read_blk(blknum, ctx).await } + #[cfg(test)] + // This is a test helper: outside of tests, we are always written to via a pre-serialized batch. pub(crate) async fn write_blob( &mut self, srcbuf: &[u8], @@ -89,17 +88,30 @@ impl EphemeralFile { ) -> Result { let pos = self.rw.bytes_written(); - // Write the length field - if srcbuf.len() < 0x80 { - // short one-byte length header - let len_buf = [srcbuf.len() as u8]; + let mut len_bytes = std::io::Cursor::new(Vec::new()); + crate::tenant::storage_layer::inmemory_layer::SerializedBatch::write_blob_length( + srcbuf.len(), + &mut len_bytes, + ); + let len_bytes = len_bytes.into_inner(); - self.rw.write_all_borrowed(&len_buf, ctx).await?; - } else { - let mut len_buf = u32::to_be_bytes(srcbuf.len() as u32); - len_buf[0] |= 0x80; - self.rw.write_all_borrowed(&len_buf, ctx).await?; - } + // Write the length field + self.rw.write_all_borrowed(&len_bytes, ctx).await?; + + // Write the payload + self.rw.write_all_borrowed(srcbuf, ctx).await?; + + Ok(pos) + } + + /// Returns the offset at which the first byte of the input was written, for use + /// in constructing indices over the written value. 
+ pub(crate) async fn write_raw( + &mut self, + srcbuf: &[u8], + ctx: &RequestContext, + ) -> Result { + let pos = self.rw.bytes_written(); // Write the payload self.rw.write_all_borrowed(srcbuf, ctx).await?; diff --git a/pageserver/src/tenant/ephemeral_file/page_caching.rs b/pageserver/src/tenant/ephemeral_file/page_caching.rs index 7355b3b5a3..48926354f1 100644 --- a/pageserver/src/tenant/ephemeral_file/page_caching.rs +++ b/pageserver/src/tenant/ephemeral_file/page_caching.rs @@ -1,15 +1,15 @@ //! Wrapper around [`super::zero_padded_read_write::RW`] that uses the //! [`crate::page_cache`] to serve reads that need to go to the underlying [`VirtualFile`]. +//! +//! Subject to removal in use crate::context::RequestContext; use crate::page_cache::{self, PAGE_SZ}; use crate::tenant::block_io::BlockLease; -use crate::virtual_file::owned_buffers_io::io_buf_ext::FullSlice; +use crate::virtual_file::owned_buffers_io::util::size_tracking_writer; use crate::virtual_file::VirtualFile; -use once_cell::sync::Lazy; -use std::io::{self, ErrorKind}; -use std::ops::{Deref, Range}; +use std::io::{self}; use tokio_epoll_uring::BoundedBuf; use tracing::*; @@ -18,33 +18,17 @@ use super::zero_padded_read_write; /// See module-level comment. pub struct RW { page_cache_file_id: page_cache::FileId, - rw: super::zero_padded_read_write::RW, + rw: super::zero_padded_read_write::RW>, /// Gate guard is held on as long as we need to do operations in the path (delete on drop). _gate_guard: utils::sync::gate::GateGuard, } -/// When we flush a block to the underlying [`crate::virtual_file::VirtualFile`], -/// should we pre-warm the [`crate::page_cache`] with the contents? -#[derive(Clone, Copy)] -pub enum PrewarmOnWrite { - Yes, - No, -} - impl RW { - pub fn new( - file: VirtualFile, - prewarm_on_write: PrewarmOnWrite, - _gate_guard: utils::sync::gate::GateGuard, - ) -> Self { + pub fn new(file: VirtualFile, _gate_guard: utils::sync::gate::GateGuard) -> Self { let page_cache_file_id = page_cache::next_file_id(); Self { page_cache_file_id, - rw: super::zero_padded_read_write::RW::new(PreWarmingWriter::new( - page_cache_file_id, - file, - prewarm_on_write, - )), + rw: super::zero_padded_read_write::RW::new(size_tracking_writer::Writer::new(file)), _gate_guard, } } @@ -84,10 +68,10 @@ impl RW { let vec = Vec::with_capacity(size); // read from disk what we've already flushed - let writer = self.rw.as_writer(); - let flushed_range = writer.written_range(); - let mut vec = writer - .file + let file_size_tracking_writer = self.rw.as_writer(); + let flushed_range = 0..usize::try_from(file_size_tracking_writer.bytes_written()).unwrap(); + let mut vec = file_size_tracking_writer + .as_inner() .read_exact_at( vec.slice(0..(flushed_range.end - flushed_range.start)), u64::try_from(flushed_range.start).unwrap(), @@ -122,7 +106,7 @@ impl RW { format!( "ephemeral file: read immutable page #{}: {}: {:#}", blknum, - self.rw.as_writer().file.path, + self.rw.as_writer().as_inner().path, e, ), ) @@ -132,7 +116,7 @@ impl RW { } page_cache::ReadBufResult::NotFound(write_guard) => { let write_guard = writer - .file + .as_inner() .read_exact_at_page(write_guard, blknum as u64 * PAGE_SZ as u64, ctx) .await?; let read_guard = write_guard.mark_valid(); @@ -154,137 +138,16 @@ impl Drop for RW { // unlink the file // we are clear to do this, because we have entered a gate - let res = std::fs::remove_file(&self.rw.as_writer().file.path); + let path = &self.rw.as_writer().as_inner().path; + let res = std::fs::remove_file(path); if let Err(e) = res { if 
e.kind() != std::io::ErrorKind::NotFound { // just never log the not found errors, we cannot do anything for them; on detach // the tenant directory is already gone. // // not found files might also be related to https://github.com/neondatabase/neon/issues/2442 - error!( - "could not remove ephemeral file '{}': {}", - self.rw.as_writer().file.path, - e - ); + error!("could not remove ephemeral file '{path}': {e}"); } } } } - -struct PreWarmingWriter { - prewarm_on_write: PrewarmOnWrite, - nwritten_blocks: u32, - page_cache_file_id: page_cache::FileId, - file: VirtualFile, -} - -impl PreWarmingWriter { - fn new( - page_cache_file_id: page_cache::FileId, - file: VirtualFile, - prewarm_on_write: PrewarmOnWrite, - ) -> Self { - Self { - prewarm_on_write, - nwritten_blocks: 0, - page_cache_file_id, - file, - } - } - - /// Return the byte range within `file` that has been written though `write_all`. - /// - /// The returned range would be invalidated by another `write_all`. To prevent that, we capture `&_`. - fn written_range(&self) -> (impl Deref> + '_) { - let nwritten_blocks = usize::try_from(self.nwritten_blocks).unwrap(); - struct Wrapper(Range); - impl Deref for Wrapper { - type Target = Range; - fn deref(&self) -> &Range { - &self.0 - } - } - Wrapper(0..nwritten_blocks * PAGE_SZ) - } -} - -impl crate::virtual_file::owned_buffers_io::write::OwnedAsyncWriter for PreWarmingWriter { - async fn write_all( - &mut self, - buf: FullSlice, - ctx: &RequestContext, - ) -> std::io::Result<(usize, FullSlice)> { - let buflen = buf.len(); - assert_eq!( - buflen % PAGE_SZ, - 0, - "{buflen} ; we know TAIL_SZ is a PAGE_SZ multiple, and write_buffered_borrowed is used" - ); - - // Do the IO. - let buf = match self.file.write_all(buf, ctx).await { - (buf, Ok(nwritten)) => { - assert_eq!(nwritten, buflen); - buf - } - (_, Err(e)) => { - return Err(std::io::Error::new( - ErrorKind::Other, - // order error before path because path is long and error is short - format!( - "ephemeral_file: write_blob: write-back tail self.nwritten_blocks={}, buflen={}, {:#}: {}", - self.nwritten_blocks, buflen, e, self.file.path, - ), - )); - } - }; - - let nblocks = buflen / PAGE_SZ; - let nblocks32 = u32::try_from(nblocks).unwrap(); - - if matches!(self.prewarm_on_write, PrewarmOnWrite::Yes) { - // Pre-warm page cache with the contents. - // At least in isolated bulk ingest benchmarks (test_bulk_insert.py), the pre-warming - // benefits the code that writes InMemoryLayer=>L0 layers. - - let cache = page_cache::get(); - static CTX: Lazy = Lazy::new(|| { - RequestContext::new( - crate::task_mgr::TaskKind::EphemeralFilePreWarmPageCache, - crate::context::DownloadBehavior::Error, - ) - }); - for blknum_in_buffer in 0..nblocks { - let blk_in_buffer = - &buf[blknum_in_buffer * PAGE_SZ..(blknum_in_buffer + 1) * PAGE_SZ]; - let blknum = self - .nwritten_blocks - .checked_add(blknum_in_buffer as u32) - .unwrap(); - match cache - .read_immutable_buf(self.page_cache_file_id, blknum, &CTX) - .await - { - Err(e) => { - error!("ephemeral_file write_blob failed to get immutable buf to pre-warm page cache: {e:?}"); - // fail gracefully, it's not the end of the world if we can't pre-warm the cache here - } - Ok(v) => match v { - page_cache::ReadBufResult::Found(_guard) => { - // This function takes &mut self, so, it shouldn't be possible to reach this point. 
- unreachable!("we just wrote block {blknum} to the VirtualFile, which is owned by Self, \ - and this function takes &mut self, so, no concurrent read_blk is possible"); - } - page_cache::ReadBufResult::NotFound(mut write_guard) => { - write_guard.copy_from_slice(blk_in_buffer); - let _ = write_guard.mark_valid(); - } - }, - } - } - } - - self.nwritten_blocks = self.nwritten_blocks.checked_add(nblocks32).unwrap(); - Ok((buflen, buf)) - } -} diff --git a/pageserver/src/tenant/layer_map.rs b/pageserver/src/tenant/layer_map.rs index 844f117ea2..707233b003 100644 --- a/pageserver/src/tenant/layer_map.rs +++ b/pageserver/src/tenant/layer_map.rs @@ -464,7 +464,7 @@ impl LayerMap { pub(self) fn insert_historic_noflush(&mut self, layer_desc: PersistentLayerDesc) { // TODO: See #3869, resulting #4088, attempted fix and repro #4094 - if Self::is_l0(&layer_desc.key_range) { + if Self::is_l0(&layer_desc.key_range, layer_desc.is_delta) { self.l0_delta_layers.push(layer_desc.clone().into()); } @@ -483,7 +483,7 @@ impl LayerMap { self.historic .remove(historic_layer_coverage::LayerKey::from(layer_desc)); let layer_key = layer_desc.key(); - if Self::is_l0(&layer_desc.key_range) { + if Self::is_l0(&layer_desc.key_range, layer_desc.is_delta) { let len_before = self.l0_delta_layers.len(); let mut l0_delta_layers = std::mem::take(&mut self.l0_delta_layers); l0_delta_layers.retain(|other| other.key() != layer_key); @@ -600,8 +600,8 @@ impl LayerMap { } /// Check if the key range resembles that of an L0 layer. - pub fn is_l0(key_range: &Range) -> bool { - key_range == &(Key::MIN..Key::MAX) + pub fn is_l0(key_range: &Range, is_delta_layer: bool) -> bool { + is_delta_layer && key_range == &(Key::MIN..Key::MAX) } /// This function determines which layers are counted in `count_deltas`: @@ -628,7 +628,7 @@ impl LayerMap { /// than just the current partition_range. 
pub fn is_reimage_worthy(layer: &PersistentLayerDesc, partition_range: &Range) -> bool { // Case 1 - if !Self::is_l0(&layer.key_range) { + if !Self::is_l0(&layer.key_range, layer.is_delta) { return true; } diff --git a/pageserver/src/tenant/metadata.rs b/pageserver/src/tenant/metadata.rs index 6073abc8c3..190316df42 100644 --- a/pageserver/src/tenant/metadata.rs +++ b/pageserver/src/tenant/metadata.rs @@ -565,7 +565,7 @@ mod tests { ); let expected_bytes = vec![ /* TimelineMetadataHeader */ - 4, 37, 101, 34, 0, 70, 0, 4, // checksum, size, format_version (4 + 2 + 2) + 74, 104, 158, 105, 0, 70, 0, 4, // checksum, size, format_version (4 + 2 + 2) /* TimelineMetadataBodyV2 */ 0, 0, 0, 0, 0, 0, 2, 0, // disk_consistent_lsn (8 bytes) 1, 0, 0, 0, 0, 0, 0, 1, 0, // prev_record_lsn (9 bytes) @@ -574,7 +574,7 @@ mod tests { 0, 0, 0, 0, 0, 0, 0, 0, // ancestor_lsn (8 bytes) 0, 0, 0, 0, 0, 0, 0, 0, // latest_gc_cutoff_lsn (8 bytes) 0, 0, 0, 0, 0, 0, 0, 0, // initdb_lsn (8 bytes) - 0, 0, 0, 15, // pg_version (4 bytes) + 0, 0, 0, 16, // pg_version (4 bytes) /* padding bytes */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index b4d7ad1e97..71b766e4c7 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -1728,7 +1728,7 @@ impl RemoteTimelineClient { task_mgr::spawn( &self.runtime, TaskKind::RemoteUploadTask, - Some(self.tenant_shard_id), + self.tenant_shard_id, Some(self.timeline_id), "remote upload", async move { diff --git a/pageserver/src/tenant/secondary.rs b/pageserver/src/tenant/secondary.rs index 3132a28b12..1331c07d05 100644 --- a/pageserver/src/tenant/secondary.rs +++ b/pageserver/src/tenant/secondary.rs @@ -8,6 +8,7 @@ use std::{sync::Arc, time::SystemTime}; use crate::{ context::RequestContext, disk_usage_eviction_task::DiskUsageEvictionInfo, + metrics::SECONDARY_HEATMAP_TOTAL_SIZE, task_mgr::{self, TaskKind, BACKGROUND_RUNTIME}, }; @@ -105,6 +106,9 @@ pub(crate) struct SecondaryTenant { // Sum of layer sizes on local disk pub(super) resident_size_metric: UIntGauge, + + // Sum of layer sizes in the most recently downloaded heatmap + pub(super) heatmap_total_size_metric: UIntGauge, } impl Drop for SecondaryTenant { @@ -112,6 +116,7 @@ impl Drop for SecondaryTenant { let tenant_id = self.tenant_shard_id.tenant_id.to_string(); let shard_id = format!("{}", self.tenant_shard_id.shard_slug()); let _ = SECONDARY_RESIDENT_PHYSICAL_SIZE.remove_label_values(&[&tenant_id, &shard_id]); + let _ = SECONDARY_HEATMAP_TOTAL_SIZE.remove_label_values(&[&tenant_id, &shard_id]); } } @@ -128,6 +133,10 @@ impl SecondaryTenant { .get_metric_with_label_values(&[&tenant_id, &shard_id]) .unwrap(); + let heatmap_total_size_metric = SECONDARY_HEATMAP_TOTAL_SIZE + .get_metric_with_label_values(&[&tenant_id, &shard_id]) + .unwrap(); + Arc::new(Self { tenant_shard_id, // todo: shall we make this a descendent of the @@ -145,6 +154,7 @@ impl SecondaryTenant { progress: std::sync::Mutex::default(), resident_size_metric, + heatmap_total_size_metric, }) } diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 8cff1d2864..90e1c01dbd 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -829,6 +829,12 @@ impl<'a> 
TenantDownloader<'a> { layers_downloaded: 0, bytes_downloaded: 0, }; + + // Also expose heatmap bytes_total as a metric + self.secondary_state + .heatmap_total_size_metric + .set(heatmap_stats.bytes); + // Accumulate list of things to delete while holding the detail lock, for execution after dropping the lock let mut delete_layers = Vec::new(); let mut delete_timelines = Vec::new(); diff --git a/pageserver/src/tenant/storage_layer.rs b/pageserver/src/tenant/storage_layer.rs index 04f89db401..a1202ad507 100644 --- a/pageserver/src/tenant/storage_layer.rs +++ b/pageserver/src/tenant/storage_layer.rs @@ -2,13 +2,12 @@ pub mod delta_layer; pub mod image_layer; -pub(crate) mod inmemory_layer; +pub mod inmemory_layer; pub(crate) mod layer; mod layer_desc; mod layer_name; pub mod merge_iterator; -#[cfg(test)] pub mod split_writer; use crate::context::{AccessStatsBehavior, RequestContext}; diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index 6c2391d72d..f4a2957972 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -36,6 +36,7 @@ use crate::tenant::block_io::{BlockBuf, BlockCursor, BlockLease, BlockReader, Fi use crate::tenant::disk_btree::{ DiskBtreeBuilder, DiskBtreeIterator, DiskBtreeReader, VisitDirection, }; +use crate::tenant::storage_layer::layer::S3_UPLOAD_LIMIT; use crate::tenant::timeline::GetVectoredError; use crate::tenant::vectored_blob_io::{ BlobFlag, MaxVectoredReadBytes, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead, @@ -232,6 +233,18 @@ pub struct DeltaLayerInner { max_vectored_read_bytes: Option, } +impl DeltaLayerInner { + pub(crate) fn layer_dbg_info(&self) -> String { + format!( + "delta {}..{} {}..{}", + self.key_range().start, + self.key_range().end, + self.lsn_range().start, + self.lsn_range().end + ) + } +} + impl std::fmt::Debug for DeltaLayerInner { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("DeltaLayerInner") @@ -556,7 +569,6 @@ impl DeltaLayerWriterInner { // 5GB limit for objects without multipart upload (which we don't want to use) // Make it a little bit below to account for differing GB units // https://docs.aws.amazon.com/AmazonS3/latest/userguide/upload-objects.html - const S3_UPLOAD_LIMIT: u64 = 4_500_000_000; ensure!( metadata.len() <= S3_UPLOAD_LIMIT, "Created delta layer file at {} of size {} above limit {S3_UPLOAD_LIMIT}!", @@ -690,12 +702,10 @@ impl DeltaLayerWriter { self.inner.take().unwrap().finish(key_end, ctx).await } - #[cfg(test)] pub(crate) fn num_keys(&self) -> usize { self.inner.as_ref().unwrap().num_keys } - #[cfg(test)] pub(crate) fn estimated_size(&self) -> u64 { let inner = self.inner.as_ref().unwrap(); inner.blob_writer.size() + inner.tree.borrow_writer().size() + PAGE_SZ as u64 @@ -1527,6 +1537,10 @@ pub struct DeltaLayerIterator<'a> { } impl<'a> DeltaLayerIterator<'a> { + pub(crate) fn layer_dbg_info(&self) -> String { + self.delta_layer.layer_dbg_info() + } + /// Retrieve a batch of key-value pairs into the iterator buffer. 
async fn next_batch(&mut self) -> anyhow::Result<()> { assert!(self.key_values_batch.is_empty()); diff --git a/pageserver/src/tenant/storage_layer/image_layer.rs b/pageserver/src/tenant/storage_layer/image_layer.rs index 9a19e4e2c7..3cb2b1c83a 100644 --- a/pageserver/src/tenant/storage_layer/image_layer.rs +++ b/pageserver/src/tenant/storage_layer/image_layer.rs @@ -167,6 +167,17 @@ pub struct ImageLayerInner { max_vectored_read_bytes: Option, } +impl ImageLayerInner { + pub(crate) fn layer_dbg_info(&self) -> String { + format!( + "image {}..{} {}", + self.key_range().start, + self.key_range().end, + self.lsn() + ) + } +} + impl std::fmt::Debug for ImageLayerInner { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ImageLayerInner") @@ -705,10 +716,6 @@ struct ImageLayerWriterInner { } impl ImageLayerWriterInner { - fn size(&self) -> u64 { - self.tree.borrow_writer().size() + self.blob_writer.size() - } - /// /// Start building a new image layer. /// @@ -843,13 +850,19 @@ impl ImageLayerWriterInner { res?; } + let final_key_range = if let Some(end_key) = end_key { + self.key_range.start..end_key + } else { + self.key_range.clone() + }; + // Fill in the summary on blk 0 let summary = Summary { magic: IMAGE_FILE_MAGIC, format_version: STORAGE_FORMAT_VERSION, tenant_id: self.tenant_shard_id.tenant_id, timeline_id: self.timeline_id, - key_range: self.key_range.clone(), + key_range: final_key_range.clone(), lsn: self.lsn, index_start_blk, index_root_blk, @@ -870,11 +883,7 @@ impl ImageLayerWriterInner { let desc = PersistentLayerDesc::new_img( self.tenant_shard_id, self.timeline_id, - if let Some(end_key) = end_key { - self.key_range.start..end_key - } else { - self.key_range.clone() - }, + final_key_range, self.lsn, metadata.len(), ); @@ -963,14 +972,12 @@ impl ImageLayerWriter { self.inner.as_mut().unwrap().put_image(key, img, ctx).await } - #[cfg(test)] /// Estimated size of the image layer. pub(crate) fn estimated_size(&self) -> u64 { let inner = self.inner.as_ref().unwrap(); inner.blob_writer.size() + inner.tree.borrow_writer().size() + PAGE_SZ as u64 } - #[cfg(test)] pub(crate) fn num_keys(&self) -> usize { self.inner.as_ref().unwrap().num_keys } @@ -986,7 +993,6 @@ impl ImageLayerWriter { self.inner.take().unwrap().finish(timeline, ctx, None).await } - #[cfg(test)] /// Finish writing the image layer with an end key, used in [`super::split_writer::SplitImageLayerWriter`]. The end key determines the end of the image layer's covered range and is exclusive. pub(super) async fn finish_with_end_key( mut self, @@ -1000,10 +1006,6 @@ impl ImageLayerWriter { .finish(timeline, ctx, Some(end_key)) .await } - - pub(crate) fn size(&self) -> u64 { - self.inner.as_ref().unwrap().size() - } } impl Drop for ImageLayerWriter { @@ -1024,6 +1026,10 @@ pub struct ImageLayerIterator<'a> { } impl<'a> ImageLayerIterator<'a> { + pub(crate) fn layer_dbg_info(&self) -> String { + self.image_layer.layer_dbg_info() + } + /// Retrieve a batch of key-value pairs into the iterator buffer. 
async fn next_batch(&mut self) -> anyhow::Result<()> { assert!(self.key_values_batch.is_empty()); diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer.rs b/pageserver/src/tenant/storage_layer/inmemory_layer.rs index 748d79c149..a71b4dd83b 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer.rs @@ -13,7 +13,7 @@ use crate::tenant::ephemeral_file::EphemeralFile; use crate::tenant::timeline::GetVectoredError; use crate::tenant::PageReconstructError; use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt; -use crate::{l0_flush, page_cache, walrecord}; +use crate::{l0_flush, page_cache}; use anyhow::{anyhow, Result}; use camino::Utf8PathBuf; use pageserver_api::key::CompactKey; @@ -33,7 +33,7 @@ use std::fmt::Write; use std::ops::Range; use std::sync::atomic::Ordering as AtomicOrdering; use std::sync::atomic::{AtomicU64, AtomicUsize}; -use tokio::sync::{RwLock, RwLockWriteGuard}; +use tokio::sync::RwLock; use super::{ DeltaLayerWriter, PersistentLayerDesc, ValueReconstructSituation, ValuesReconstructState, @@ -249,9 +249,7 @@ impl InMemoryLayer { /// debugging function to print out the contents of the layer /// /// this is likely completly unused - pub async fn dump(&self, verbose: bool, ctx: &RequestContext) -> Result<()> { - let inner = self.inner.read().await; - + pub async fn dump(&self, _verbose: bool, _ctx: &RequestContext) -> Result<()> { let end_str = self.end_lsn_or_max(); println!( @@ -259,39 +257,6 @@ impl InMemoryLayer { self.timeline_id, self.start_lsn, end_str, ); - if !verbose { - return Ok(()); - } - - let cursor = inner.file.block_cursor(); - let mut buf = Vec::new(); - for (key, vec_map) in inner.index.iter() { - for (lsn, pos) in vec_map.as_slice() { - let mut desc = String::new(); - cursor.read_blob_into_buf(*pos, &mut buf, ctx).await?; - let val = Value::des(&buf); - match val { - Ok(Value::Image(img)) => { - write!(&mut desc, " img {} bytes", img.len())?; - } - Ok(Value::WalRecord(rec)) => { - let wal_desc = walrecord::describe_wal_record(&rec).unwrap(); - write!( - &mut desc, - " rec {} bytes will_init: {} {}", - buf.len(), - rec.will_init(), - wal_desc - )?; - } - Err(err) => { - write!(&mut desc, " DESERIALIZATION ERROR: {}", err)?; - } - } - println!(" key {} at {}: {}", key, lsn, desc); - } - } - Ok(()) } @@ -355,6 +320,82 @@ impl InMemoryLayer { } } +/// Offset of a particular Value within a serialized batch. +struct SerializedBatchOffset { + key: CompactKey, + lsn: Lsn, + /// offset in bytes from the start of the batch's buffer to the Value's serialized size header. + offset: u64, +} + +pub struct SerializedBatch { + /// Blobs serialized in EphemeralFile's native format, ready for passing to [`EphemeralFile::write_raw`]. + pub(crate) raw: Vec, + + /// Index of values in [`Self::raw`], using offsets relative to the start of the buffer. 
+ offsets: Vec, + + /// The highest LSN of any value in the batch + pub(crate) max_lsn: Lsn, +} + +impl SerializedBatch { + /// Write a blob length in the internal format of the EphemeralFile + pub(crate) fn write_blob_length(len: usize, cursor: &mut std::io::Cursor>) { + use std::io::Write; + + if len < 0x80 { + // short one-byte length header + let len_buf = [len as u8]; + + cursor + .write_all(&len_buf) + .expect("Writing to Vec is infallible"); + } else { + let mut len_buf = u32::to_be_bytes(len as u32); + len_buf[0] |= 0x80; + cursor + .write_all(&len_buf) + .expect("Writing to Vec is infallible"); + } + } + + pub fn from_values(batch: Vec<(CompactKey, Lsn, usize, Value)>) -> Self { + // Pre-allocate a big flat buffer to write into. This should be large but not huge: it is soft-limited in practice by + // [`crate::pgdatadir_mapping::DatadirModification::MAX_PENDING_BYTES`] + let buffer_size = batch.iter().map(|i| i.2).sum::() + 4 * batch.len(); + let mut cursor = std::io::Cursor::new(Vec::::with_capacity(buffer_size)); + + let mut offsets: Vec = Vec::with_capacity(batch.len()); + let mut max_lsn: Lsn = Lsn(0); + for (key, lsn, val_ser_size, val) in batch { + let relative_off = cursor.position(); + + Self::write_blob_length(val_ser_size, &mut cursor); + val.ser_into(&mut cursor) + .expect("Writing into in-memory buffer is infallible"); + + offsets.push(SerializedBatchOffset { + key, + lsn, + offset: relative_off, + }); + max_lsn = std::cmp::max(max_lsn, lsn); + } + + let buffer = cursor.into_inner(); + + // Assert that we didn't do any extra allocations while building buffer. + debug_assert!(buffer.len() <= buffer_size); + + Self { + raw: buffer, + offsets, + max_lsn, + } + } +} + fn inmem_layer_display(mut f: impl Write, start_lsn: Lsn, end_lsn: Lsn) -> std::fmt::Result { write!(f, "inmem-{:016X}-{:016X}", start_lsn.0, end_lsn.0) } @@ -415,37 +456,20 @@ impl InMemoryLayer { }) } - // Write operations - - /// Common subroutine of the public put_wal_record() and put_page_image() functions. - /// Adds the page version to the in-memory tree - pub async fn put_value( + // Write path. + pub async fn put_batch( &self, - key: CompactKey, - lsn: Lsn, - buf: &[u8], + serialized_batch: SerializedBatch, ctx: &RequestContext, ) -> Result<()> { let mut inner = self.inner.write().await; self.assert_writable(); - self.put_value_locked(&mut inner, key, lsn, buf, ctx).await - } - async fn put_value_locked( - &self, - locked_inner: &mut RwLockWriteGuard<'_, InMemoryLayerInner>, - key: CompactKey, - lsn: Lsn, - buf: &[u8], - ctx: &RequestContext, - ) -> Result<()> { - trace!("put_value key {} at {}/{}", key, self.timeline_id, lsn); - - let off = { - locked_inner + let base_off = { + inner .file - .write_blob( - buf, + .write_raw( + &serialized_batch.raw, &RequestContextBuilder::extend(ctx) .page_content_kind(PageContentKind::InMemoryLayer) .build(), @@ -453,15 +477,23 @@ impl InMemoryLayer { .await? }; - let vec_map = locked_inner.index.entry(key).or_default(); - let old = vec_map.append_or_update_last(lsn, off).unwrap().0; - if old.is_some() { - // We already had an entry for this LSN. That's odd.. - warn!("Key {} at {} already exists", key, lsn); + for SerializedBatchOffset { + key, + lsn, + offset: relative_off, + } in serialized_batch.offsets + { + let off = base_off + relative_off; + let vec_map = inner.index.entry(key).or_default(); + let old = vec_map.append_or_update_last(lsn, off).unwrap().0; + if old.is_some() { + // We already had an entry for this LSN. That's odd.. 
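For readers new to the format reproduced by SerializedBatch: each value in `raw` is a length header followed by the Value's serialization, and each SerializedBatchOffset points at that header, with put_batch() above adding the file's base offset before indexing. Lengths below 0x80 take one byte; larger lengths take four big-endian bytes with the top bit set, mirroring write_blob_length() above. A minimal decoder sketch (illustrative, not part of the patch):

    /// Decode one length header at `off` within a SerializedBatch buffer.
    /// Returns (payload length, header size in bytes).
    fn read_blob_length(buf: &[u8], off: usize) -> (usize, usize) {
        if buf[off] & 0x80 == 0 {
            (buf[off] as usize, 1) // short form: a single length byte
        } else {
            let mut len_bytes = [buf[off], buf[off + 1], buf[off + 2], buf[off + 3]];
            len_bytes[0] &= 0x7f; // clear the "long form" flag bit
            (u32::from_be_bytes(len_bytes) as usize, 4)
        }
    }
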
+ warn!("Key {} at {} already exists", key, lsn); + } } - let size = locked_inner.file.len(); - locked_inner.resource_units.maybe_publish_size(size); + let size = inner.file.len(); + inner.resource_units.maybe_publish_size(size); Ok(()) } @@ -536,7 +568,6 @@ impl InMemoryLayer { use l0_flush::Inner; let _concurrency_permit = match l0_flush_global_state { - Inner::PageCached => None, Inner::Direct { semaphore, .. } => Some(semaphore.acquire().await), }; @@ -568,34 +599,6 @@ impl InMemoryLayer { .await?; match l0_flush_global_state { - l0_flush::Inner::PageCached => { - let ctx = RequestContextBuilder::extend(ctx) - .page_content_kind(PageContentKind::InMemoryLayer) - .build(); - - let mut buf = Vec::new(); - - let cursor = inner.file.block_cursor(); - - for (key, vec_map) in inner.index.iter() { - // Write all page versions - for (lsn, pos) in vec_map.as_slice() { - cursor.read_blob_into_buf(*pos, &mut buf, &ctx).await?; - let will_init = Value::des(&buf)?.will_init(); - let (tmp, res) = delta_layer_writer - .put_value_bytes( - Key::from_compact(*key), - *lsn, - buf.slice_len(), - will_init, - &ctx, - ) - .await; - res?; - buf = tmp.into_raw_slice().into_inner(); - } - } - } l0_flush::Inner::Direct { .. } => { let file_contents: Vec = inner.file.load_to_vec(ctx).await?; assert_eq!( diff --git a/pageserver/src/tenant/storage_layer/layer.rs b/pageserver/src/tenant/storage_layer/layer.rs index 774f97e1d9..53bb66b95e 100644 --- a/pageserver/src/tenant/storage_layer/layer.rs +++ b/pageserver/src/tenant/storage_layer/layer.rs @@ -35,6 +35,8 @@ mod tests; #[cfg(test)] mod failpoints; +pub const S3_UPLOAD_LIMIT: u64 = 4_500_000_000; + /// A Layer contains all data in a "rectangle" consisting of a range of keys and /// range of LSNs. /// @@ -1296,7 +1298,10 @@ impl LayerInner { lsn_end: lsn_range.end, remote: !resident, access_stats, - l0: crate::tenant::layer_map::LayerMap::is_l0(&self.layer_desc().key_range), + l0: crate::tenant::layer_map::LayerMap::is_l0( + &self.layer_desc().key_range, + self.layer_desc().is_delta, + ), } } else { let lsn = self.desc.image_layer_lsn(); diff --git a/pageserver/src/tenant/storage_layer/layer_name.rs b/pageserver/src/tenant/storage_layer/layer_name.rs index f33ca076ab..47ae556279 100644 --- a/pageserver/src/tenant/storage_layer/layer_name.rs +++ b/pageserver/src/tenant/storage_layer/layer_name.rs @@ -256,6 +256,10 @@ impl LayerName { LayerName::Delta(layer) => &layer.key_range, } } + + pub fn is_delta(&self) -> bool { + matches!(self, LayerName::Delta(_)) + } } impl fmt::Display for LayerName { diff --git a/pageserver/src/tenant/storage_layer/merge_iterator.rs b/pageserver/src/tenant/storage_layer/merge_iterator.rs index b4bd976033..d2c341e5ce 100644 --- a/pageserver/src/tenant/storage_layer/merge_iterator.rs +++ b/pageserver/src/tenant/storage_layer/merge_iterator.rs @@ -3,6 +3,7 @@ use std::{ collections::{binary_heap, BinaryHeap}, }; +use anyhow::bail; use pageserver_api::key::Key; use utils::lsn::Lsn; @@ -26,6 +27,13 @@ impl<'a> LayerRef<'a> { Self::Delta(x) => LayerIterRef::Delta(x.iter(ctx)), } } + + fn layer_dbg_info(&self) -> String { + match self { + Self::Image(x) => x.layer_dbg_info(), + Self::Delta(x) => x.layer_dbg_info(), + } + } } enum LayerIterRef<'a> { @@ -40,6 +48,13 @@ impl LayerIterRef<'_> { Self::Image(x) => x.next().await, } } + + fn layer_dbg_info(&self) -> String { + match self { + Self::Image(x) => x.layer_dbg_info(), + Self::Delta(x) => x.layer_dbg_info(), + } + } } /// This type plays several roles at once @@ -75,6 +90,11 @@ impl<'a> 
PeekableLayerIterRef<'a> { async fn next(&mut self) -> anyhow::Result> { let result = self.peeked.take(); self.peeked = self.iter.next().await?; + if let (Some((k1, l1, _)), Some((k2, l2, _))) = (&self.peeked, &result) { + if (k1, l1) < (k2, l2) { + bail!("iterator is not ordered: {}", self.iter.layer_dbg_info()); + } + } Ok(result) } } @@ -178,7 +198,12 @@ impl<'a> IteratorWrapper<'a> { let iter = PeekableLayerIterRef::create(iter).await?; if let Some((k1, l1, _)) = iter.peek() { let (k2, l2) = first_key_lower_bound; - debug_assert!((k1, l1) >= (k2, l2)); + if (k1, l1) < (k2, l2) { + bail!( + "layer key range did not include the first key in the layer: {}", + layer.layer_dbg_info() + ); + } } *self = Self::Loaded { iter }; Ok(()) diff --git a/pageserver/src/tenant/storage_layer/split_writer.rs b/pageserver/src/tenant/storage_layer/split_writer.rs index d7bfe48c60..df910b5ad9 100644 --- a/pageserver/src/tenant/storage_layer/split_writer.rs +++ b/pageserver/src/tenant/storage_layer/split_writer.rs @@ -1,4 +1,4 @@ -use std::{ops::Range, sync::Arc}; +use std::{future::Future, ops::Range, sync::Arc}; use bytes::Bytes; use pageserver_api::key::{Key, KEY_SIZE}; @@ -7,7 +7,32 @@ use utils::{id::TimelineId, lsn::Lsn, shard::TenantShardId}; use crate::tenant::storage_layer::Layer; use crate::{config::PageServerConf, context::RequestContext, repository::Value, tenant::Timeline}; -use super::{DeltaLayerWriter, ImageLayerWriter, ResidentLayer}; +use super::layer::S3_UPLOAD_LIMIT; +use super::{ + DeltaLayerWriter, ImageLayerWriter, PersistentLayerDesc, PersistentLayerKey, ResidentLayer, +}; + +pub(crate) enum SplitWriterResult { + Produced(ResidentLayer), + Discarded(PersistentLayerKey), +} + +#[cfg(test)] +impl SplitWriterResult { + fn into_resident_layer(self) -> ResidentLayer { + match self { + SplitWriterResult::Produced(layer) => layer, + SplitWriterResult::Discarded(_) => panic!("unexpected discarded layer"), + } + } + + fn into_discarded_layer(self) -> PersistentLayerKey { + match self { + SplitWriterResult::Produced(_) => panic!("unexpected produced layer"), + SplitWriterResult::Discarded(layer) => layer, + } + } +} /// An image writer that takes images and produces multiple image layers. The interface does not /// guarantee atomicity (i.e., if the image layer generation fails, there might be leftover files @@ -16,11 +41,12 @@ use super::{DeltaLayerWriter, ImageLayerWriter, ResidentLayer}; pub struct SplitImageLayerWriter { inner: ImageLayerWriter, target_layer_size: u64, - generated_layers: Vec, + generated_layers: Vec, conf: &'static PageServerConf, timeline_id: TimelineId, tenant_shard_id: TenantShardId, lsn: Lsn, + start_key: Key, } impl SplitImageLayerWriter { @@ -49,16 +75,22 @@ impl SplitImageLayerWriter { timeline_id, tenant_shard_id, lsn, + start_key, }) } - pub async fn put_image( + pub async fn put_image_with_discard_fn( &mut self, key: Key, img: Bytes, tline: &Arc, ctx: &RequestContext, - ) -> anyhow::Result<()> { + discard: D, + ) -> anyhow::Result<()> + where + D: FnOnce(&PersistentLayerKey) -> F, + F: Future, + { // The current estimation is an upper bound of the space that the key/image could take // because we did not consider compression in this estimation. The resulting image layer // could be smaller than the target size. 
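The hunk below makes put_image roll over to a fresh ImageLayerWriter once the running size estimate crosses the target, and consults a caller-supplied discard callback before finishing the previous writer. A minimal synchronous sketch of that control flow, with toy types standing in for the real async writers and layer keys:

// Toy split writer illustrating the roll-over + discard-callback pattern.
// ToySplitWriter and Chunk are hypothetical; the real writers are async and
// produce actual layer files.
struct Chunk { start: u32, end: u32 }

enum ChunkResult {
    Produced(Chunk),
    Discarded(Chunk),
}

struct ToySplitWriter {
    target_size: usize,
    current_size: usize,
    start_key: u32,
    results: Vec<ChunkResult>,
}

impl ToySplitWriter {
    fn put(&mut self, key: u32, value_size: usize, discard: impl Fn(&Chunk) -> bool) {
        // Roll over *before* accepting the new value, like the real writer does.
        if self.current_size > 0 && self.current_size + value_size >= self.target_size {
            let chunk = Chunk { start: self.start_key, end: key };
            self.results.push(if discard(&chunk) {
                ChunkResult::Discarded(chunk)
            } else {
                ChunkResult::Produced(chunk)
            });
            self.start_key = key;
            self.current_size = 0;
        }
        self.current_size += value_size;
    }
}

fn main() {
    let mut w = ToySplitWriter { target_size: 100, current_size: 0, start_key: 0, results: Vec::new() };
    for key in 0..10u32 {
        // Discard nothing in this run; a caller could return true to skip
        // rewriting a layer that already exists.
        w.put(key, 40, |_| false);
    }
    assert_eq!(w.results.len(), 4); // 10 keys of 40 bytes against a 100-byte target
}

In the real writers the callback is async and is only invoked when a layer is about to be finished rather than for every key, which is why it is documented below as safe to grab a lock inside it.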
@@ -76,33 +108,87 @@ impl SplitImageLayerWriter { ) .await?; let prev_image_writer = std::mem::replace(&mut self.inner, next_image_writer); - self.generated_layers.push( - prev_image_writer - .finish_with_end_key(tline, key, ctx) - .await?, - ); + let layer_key = PersistentLayerKey { + key_range: self.start_key..key, + lsn_range: PersistentLayerDesc::image_layer_lsn_range(self.lsn), + is_delta: false, + }; + self.start_key = key; + + if discard(&layer_key).await { + drop(prev_image_writer); + self.generated_layers + .push(SplitWriterResult::Discarded(layer_key)); + } else { + self.generated_layers.push(SplitWriterResult::Produced( + prev_image_writer + .finish_with_end_key(tline, key, ctx) + .await?, + )); + } } self.inner.put_image(key, img, ctx).await } - pub(crate) async fn finish( + #[cfg(test)] + pub async fn put_image( + &mut self, + key: Key, + img: Bytes, + tline: &Arc, + ctx: &RequestContext, + ) -> anyhow::Result<()> { + self.put_image_with_discard_fn(key, img, tline, ctx, |_| async { false }) + .await + } + + pub(crate) async fn finish_with_discard_fn( self, tline: &Arc, ctx: &RequestContext, end_key: Key, - ) -> anyhow::Result> { + discard: D, + ) -> anyhow::Result> + where + D: FnOnce(&PersistentLayerKey) -> F, + F: Future, + { let Self { mut generated_layers, inner, .. } = self; - generated_layers.push(inner.finish_with_end_key(tline, end_key, ctx).await?); + if inner.num_keys() == 0 { + return Ok(generated_layers); + } + let layer_key = PersistentLayerKey { + key_range: self.start_key..end_key, + lsn_range: PersistentLayerDesc::image_layer_lsn_range(self.lsn), + is_delta: false, + }; + if discard(&layer_key).await { + generated_layers.push(SplitWriterResult::Discarded(layer_key)); + } else { + generated_layers.push(SplitWriterResult::Produced( + inner.finish_with_end_key(tline, end_key, ctx).await?, + )); + } Ok(generated_layers) } + #[cfg(test)] + pub(crate) async fn finish( + self, + tline: &Arc, + ctx: &RequestContext, + end_key: Key, + ) -> anyhow::Result> { + self.finish_with_discard_fn(tline, ctx, end_key, |_| async { false }) + .await + } + /// When split writer fails, the caller should call this function and handle partially generated layers. - #[allow(dead_code)] - pub(crate) async fn take(self) -> anyhow::Result<(Vec, ImageLayerWriter)> { + pub(crate) fn take(self) -> anyhow::Result<(Vec, ImageLayerWriter)> { Ok((self.generated_layers, self.inner)) } } @@ -110,15 +196,21 @@ impl SplitImageLayerWriter { /// A delta writer that takes key-lsn-values and produces multiple delta layers. The interface does not /// guarantee atomicity (i.e., if the delta layer generation fails, there might be leftover files /// to be cleaned up). +/// +/// Note that if updates of a single key exceed the target size limit, all of the updates will be batched +/// into a single file. This behavior might change in the future. For reference, the legacy compaction algorithm +/// will split them into multiple files based on size. #[must_use] pub struct SplitDeltaLayerWriter { inner: DeltaLayerWriter, target_layer_size: u64, - generated_layers: Vec, + generated_layers: Vec, conf: &'static PageServerConf, timeline_id: TimelineId, tenant_shard_id: TenantShardId, lsn_range: Range, + last_key_written: Key, + start_key: Key, } impl SplitDeltaLayerWriter { @@ -147,9 +239,74 @@ impl SplitDeltaLayerWriter { timeline_id, tenant_shard_id, lsn_range, + last_key_written: Key::MIN, + start_key, }) } + /// Put value into the layer writer. 
In the case the writer decides to produce a layer, and the discard fn returns true, no layer will be written in the end. + pub async fn put_value_with_discard_fn( + &mut self, + key: Key, + lsn: Lsn, + val: Value, + tline: &Arc, + ctx: &RequestContext, + discard: D, + ) -> anyhow::Result<()> + where + D: FnOnce(&PersistentLayerKey) -> F, + F: Future, + { + // The current estimation is key size plus LSN size plus value size estimation. This is not an accurate + // number, and therefore the final layer size could be a little bit larger or smaller than the target. + // + // Also, keep all updates of a single key in a single file. TODO: split them using the legacy compaction + // strategy. https://github.com/neondatabase/neon/issues/8837 + let addition_size_estimation = KEY_SIZE as u64 + 8 /* LSN u64 size */ + 80 /* value size estimation */; + if self.inner.num_keys() >= 1 + && self.inner.estimated_size() + addition_size_estimation >= self.target_layer_size + { + if key != self.last_key_written { + let next_delta_writer = DeltaLayerWriter::new( + self.conf, + self.timeline_id, + self.tenant_shard_id, + key, + self.lsn_range.clone(), + ctx, + ) + .await?; + let prev_delta_writer = std::mem::replace(&mut self.inner, next_delta_writer); + let layer_key = PersistentLayerKey { + key_range: self.start_key..key, + lsn_range: self.lsn_range.clone(), + is_delta: true, + }; + self.start_key = key; + if discard(&layer_key).await { + drop(prev_delta_writer); + self.generated_layers + .push(SplitWriterResult::Discarded(layer_key)); + } else { + let (desc, path) = prev_delta_writer.finish(key, ctx).await?; + let delta_layer = Layer::finish_creating(self.conf, tline, desc, &path)?; + self.generated_layers + .push(SplitWriterResult::Produced(delta_layer)); + } + } else if self.inner.estimated_size() >= S3_UPLOAD_LIMIT { + // We have to produce a very large file b/c a key is updated too often. + anyhow::bail!( + "a single key is updated too often: key={}, estimated_size={}, and the layer file cannot be produced", + key, + self.inner.estimated_size() + ); + } + } + self.last_key_written = key; + self.inner.put_value(key, lsn, val, ctx).await + } + pub async fn put_value( &mut self, key: Key, @@ -158,56 +315,66 @@ impl SplitDeltaLayerWriter { tline: &Arc, ctx: &RequestContext, ) -> anyhow::Result<()> { - // The current estimation is key size plus LSN size plus value size estimation. This is not an accurate - // number, and therefore the final layer size could be a little bit larger or smaller than the target. 
- let addition_size_estimation = KEY_SIZE as u64 + 8 /* LSN u64 size */ + 80 /* value size estimation */; - if self.inner.num_keys() >= 1 - && self.inner.estimated_size() + addition_size_estimation >= self.target_layer_size - { - let next_delta_writer = DeltaLayerWriter::new( - self.conf, - self.timeline_id, - self.tenant_shard_id, - key, - self.lsn_range.clone(), - ctx, - ) - .await?; - let prev_delta_writer = std::mem::replace(&mut self.inner, next_delta_writer); - let (desc, path) = prev_delta_writer.finish(key, ctx).await?; - let delta_layer = Layer::finish_creating(self.conf, tline, desc, &path)?; - self.generated_layers.push(delta_layer); - } - self.inner.put_value(key, lsn, val, ctx).await + self.put_value_with_discard_fn(key, lsn, val, tline, ctx, |_| async { false }) + .await } - pub(crate) async fn finish( + pub(crate) async fn finish_with_discard_fn( self, tline: &Arc, ctx: &RequestContext, end_key: Key, - ) -> anyhow::Result> { + discard: D, + ) -> anyhow::Result> + where + D: FnOnce(&PersistentLayerKey) -> F, + F: Future, + { let Self { mut generated_layers, inner, .. } = self; - - let (desc, path) = inner.finish(end_key, ctx).await?; - let delta_layer = Layer::finish_creating(self.conf, tline, desc, &path)?; - generated_layers.push(delta_layer); + if inner.num_keys() == 0 { + return Ok(generated_layers); + } + let layer_key = PersistentLayerKey { + key_range: self.start_key..end_key, + lsn_range: self.lsn_range.clone(), + is_delta: true, + }; + if discard(&layer_key).await { + generated_layers.push(SplitWriterResult::Discarded(layer_key)); + } else { + let (desc, path) = inner.finish(end_key, ctx).await?; + let delta_layer = Layer::finish_creating(self.conf, tline, desc, &path)?; + generated_layers.push(SplitWriterResult::Produced(delta_layer)); + } Ok(generated_layers) } - /// When split writer fails, the caller should call this function and handle partially generated layers. #[allow(dead_code)] - pub(crate) async fn take(self) -> anyhow::Result<(Vec, DeltaLayerWriter)> { + pub(crate) async fn finish( + self, + tline: &Arc, + ctx: &RequestContext, + end_key: Key, + ) -> anyhow::Result> { + self.finish_with_discard_fn(tline, ctx, end_key, |_| async { false }) + .await + } + + /// When split writer fails, the caller should call this function and handle partially generated layers. 
+ pub(crate) fn take(self) -> anyhow::Result<(Vec, DeltaLayerWriter)> { Ok((self.generated_layers, self.inner)) } } #[cfg(test)] mod tests { + use itertools::Itertools; + use rand::{RngCore, SeedableRng}; + use crate::{ tenant::{ harness::{TenantHarness, TIMELINE_ID}, @@ -229,7 +396,10 @@ mod tests { } fn get_large_img() -> Bytes { - vec![0; 8192].into() + let mut rng = rand::rngs::SmallRng::seed_from_u64(42); + let mut data = vec![0; 8192]; + rng.fill_bytes(&mut data); + data.into() } #[tokio::test] @@ -297,9 +467,16 @@ mod tests { #[tokio::test] async fn write_split() { - let harness = TenantHarness::create("split_writer_write_split") - .await - .unwrap(); + write_split_helper("split_writer_write_split", false).await; + } + + #[tokio::test] + async fn write_split_discard() { + write_split_helper("split_writer_write_split_discard", false).await; + } + + async fn write_split_helper(harness_name: &'static str, discard: bool) { + let harness = TenantHarness::create(harness_name).await.unwrap(); let (tenant, ctx) = harness.load().await; let tline = tenant @@ -333,16 +510,19 @@ mod tests { for i in 0..N { let i = i as u32; image_writer - .put_image(get_key(i), get_large_img(), &tline, &ctx) + .put_image_with_discard_fn(get_key(i), get_large_img(), &tline, &ctx, |_| async { + discard + }) .await .unwrap(); delta_writer - .put_value( + .put_value_with_discard_fn( get_key(i), Lsn(0x20), Value::Image(get_large_img()), &tline, &ctx, + |_| async { discard }, ) .await .unwrap(); @@ -355,22 +535,39 @@ mod tests { .finish(&tline, &ctx, get_key(N as u32)) .await .unwrap(); - assert_eq!(image_layers.len(), N / 512 + 1); - assert_eq!(delta_layers.len(), N / 512 + 1); - for idx in 0..image_layers.len() { - assert_ne!(image_layers[idx].layer_desc().key_range.start, Key::MIN); - assert_ne!(image_layers[idx].layer_desc().key_range.end, Key::MAX); - assert_ne!(delta_layers[idx].layer_desc().key_range.start, Key::MIN); - assert_ne!(delta_layers[idx].layer_desc().key_range.end, Key::MAX); - if idx > 0 { - assert_eq!( - image_layers[idx - 1].layer_desc().key_range.end, - image_layers[idx].layer_desc().key_range.start - ); - assert_eq!( - delta_layers[idx - 1].layer_desc().key_range.end, - delta_layers[idx].layer_desc().key_range.start - ); + if discard { + for layer in image_layers { + layer.into_discarded_layer(); + } + for layer in delta_layers { + layer.into_discarded_layer(); + } + } else { + let image_layers = image_layers + .into_iter() + .map(|x| x.into_resident_layer()) + .collect_vec(); + let delta_layers = delta_layers + .into_iter() + .map(|x| x.into_resident_layer()) + .collect_vec(); + assert_eq!(image_layers.len(), N / 512 + 1); + assert_eq!(delta_layers.len(), N / 512 + 1); + for idx in 0..image_layers.len() { + assert_ne!(image_layers[idx].layer_desc().key_range.start, Key::MIN); + assert_ne!(image_layers[idx].layer_desc().key_range.end, Key::MAX); + assert_ne!(delta_layers[idx].layer_desc().key_range.start, Key::MIN); + assert_ne!(delta_layers[idx].layer_desc().key_range.end, Key::MAX); + if idx > 0 { + assert_eq!( + image_layers[idx - 1].layer_desc().key_range.end, + image_layers[idx].layer_desc().key_range.start + ); + assert_eq!( + delta_layers[idx - 1].layer_desc().key_range.end, + delta_layers[idx].layer_desc().key_range.start + ); + } } } } @@ -451,4 +648,49 @@ mod tests { .unwrap(); assert_eq!(layers.len(), 2); } + + #[tokio::test] + async fn write_split_single_key() { + let harness = TenantHarness::create("split_writer_write_split_single_key") + .await + .unwrap(); + let (tenant, ctx) = 
harness.load().await; + + let tline = tenant + .create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx) + .await + .unwrap(); + + const N: usize = 2000; + let mut delta_writer = SplitDeltaLayerWriter::new( + tenant.conf, + tline.timeline_id, + tenant.tenant_shard_id, + get_key(0), + Lsn(0x10)..Lsn(N as u64 * 16 + 0x10), + 4 * 1024 * 1024, + &ctx, + ) + .await + .unwrap(); + + for i in 0..N { + let i = i as u32; + delta_writer + .put_value( + get_key(0), + Lsn(i as u64 * 16 + 0x10), + Value::Image(get_large_img()), + &tline, + &ctx, + ) + .await + .unwrap(); + } + let delta_layers = delta_writer + .finish(&tline, &ctx, get_key(N as u32)) + .await + .unwrap(); + assert_eq!(delta_layers.len(), 1); + } } diff --git a/pageserver/src/tenant/tasks.rs b/pageserver/src/tenant/tasks.rs index dbcd704b4e..12f080f3c1 100644 --- a/pageserver/src/tenant/tasks.rs +++ b/pageserver/src/tenant/tasks.rs @@ -61,21 +61,12 @@ impl BackgroundLoopKind { } } -static PERMIT_GAUGES: once_cell::sync::Lazy< - enum_map::EnumMap, -> = once_cell::sync::Lazy::new(|| { - enum_map::EnumMap::from_array(std::array::from_fn(|i| { - let kind = ::from_usize(i); - crate::metrics::BACKGROUND_LOOP_SEMAPHORE_WAIT_GAUGE.with_label_values(&[kind.into()]) - })) -}); - /// Cancellation safe. pub(crate) async fn concurrent_background_tasks_rate_limit_permit( loop_kind: BackgroundLoopKind, _ctx: &RequestContext, ) -> tokio::sync::SemaphorePermit<'static> { - let _guard = PERMIT_GAUGES[loop_kind].guard(); + let _guard = crate::metrics::BACKGROUND_LOOP_SEMAPHORE.measure_acquisition(loop_kind); pausable_failpoint!( "initial-size-calculation-permit-pause", @@ -98,7 +89,7 @@ pub fn start_background_loops( task_mgr::spawn( BACKGROUND_RUNTIME.handle(), TaskKind::Compaction, - Some(tenant_shard_id), + tenant_shard_id, None, &format!("compactor for tenant {tenant_shard_id}"), { @@ -121,7 +112,7 @@ pub fn start_background_loops( task_mgr::spawn( BACKGROUND_RUNTIME.handle(), TaskKind::GarbageCollector, - Some(tenant_shard_id), + tenant_shard_id, None, &format!("garbage collector for tenant {tenant_shard_id}"), { @@ -144,7 +135,7 @@ pub fn start_background_loops( task_mgr::spawn( BACKGROUND_RUNTIME.handle(), TaskKind::IngestHousekeeping, - Some(tenant_shard_id), + tenant_shard_id, None, &format!("ingest housekeeping for tenant {tenant_shard_id}"), { diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 26dc87c373..098c196ee8 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -22,8 +22,8 @@ use handle::ShardTimelineId; use once_cell::sync::Lazy; use pageserver_api::{ key::{ - KEY_SIZE, METADATA_KEY_BEGIN_PREFIX, METADATA_KEY_END_PREFIX, NON_INHERITED_RANGE, - NON_INHERITED_SPARSE_RANGE, + CompactKey, KEY_SIZE, METADATA_KEY_BEGIN_PREFIX, METADATA_KEY_END_PREFIX, + NON_INHERITED_RANGE, NON_INHERITED_SPARSE_RANGE, }, keyspace::{KeySpaceAccum, KeySpaceRandomAccum, SparseKeyPartitioning}, models::{ @@ -44,10 +44,8 @@ use tokio::{ use tokio_util::sync::CancellationToken; use tracing::*; use utils::{ - bin_ser::BeSer, fs_ext, pausable_failpoint, sync::gate::{Gate, GateGuard}, - vec_map::VecMap, }; use std::pin::pin; @@ -137,7 +135,10 @@ use self::layer_manager::LayerManager; use self::logical_size::LogicalSize; use self::walreceiver::{WalReceiver, WalReceiverConf}; -use super::{config::TenantConf, storage_layer::LayerVisibilityHint, upload_queue::NotInitialized}; +use super::{ + config::TenantConf, storage_layer::inmemory_layer, storage_layer::LayerVisibilityHint, + 
upload_queue::NotInitialized, +}; use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf}; use super::{remote_timeline_client::index::IndexPart, storage_layer::LayerFringe}; use super::{ @@ -2233,6 +2234,11 @@ impl Timeline { handles: Default::default(), }; + + if aux_file_policy == Some(AuxFilePolicy::V1) { + warn!("this timeline is using deprecated aux file policy V1"); + } + result.repartition_threshold = result.get_checkpoint_distance() / REPARTITION_FREQ_IN_CHECKPOINT_DISTANCE; @@ -2281,7 +2287,7 @@ impl Timeline { task_mgr::spawn( task_mgr::BACKGROUND_RUNTIME.handle(), task_mgr::TaskKind::LayerFlushTask, - Some(self.tenant_shard_id), + self.tenant_shard_id, Some(self.timeline_id), "layer flush task", async move { @@ -2635,7 +2641,7 @@ impl Timeline { task_mgr::spawn( task_mgr::BACKGROUND_RUNTIME.handle(), task_mgr::TaskKind::InitialLogicalSizeCalculation, - Some(self.tenant_shard_id), + self.tenant_shard_id, Some(self.timeline_id), "initial size calculation", // NB: don't log errors here, task_mgr will do that. @@ -2803,7 +2809,7 @@ impl Timeline { task_mgr::spawn( task_mgr::BACKGROUND_RUNTIME.handle(), task_mgr::TaskKind::OndemandLogicalSizeCalculation, - Some(self.tenant_shard_id), + self.tenant_shard_id, Some(self.timeline_id), "ondemand logical size calculation", async move { @@ -2996,7 +3002,10 @@ impl Timeline { // - For L1 & image layers, download most recent LSNs first: the older the LSN, the sooner // the layer is likely to be covered by an image layer during compaction. layers.sort_by_key(|(desc, _meta, _atime)| { - std::cmp::Reverse((!LayerMap::is_l0(&desc.key_range), desc.lsn_range.end)) + std::cmp::Reverse(( + !LayerMap::is_l0(&desc.key_range, desc.is_delta), + desc.lsn_range.end, + )) }); let layers = layers @@ -3589,34 +3598,6 @@ impl Timeline { return Err(FlushLayerError::Cancelled); } - // FIXME(auxfilesv2): support multiple metadata key partitions might need initdb support as well? - // This code path will not be hit during regression tests. After #7099 we have a single partition - // with two key ranges. If someone wants to fix initdb optimization in the future, this might need - // to be fixed. - - // For metadata, always create delta layers. - let delta_layer = if !metadata_partition.parts.is_empty() { - assert_eq!( - metadata_partition.parts.len(), - 1, - "currently sparse keyspace should only contain a single metadata keyspace" - ); - let metadata_keyspace = &metadata_partition.parts[0]; - self.create_delta_layer( - &frozen_layer, - Some( - metadata_keyspace.0.ranges.first().unwrap().start - ..metadata_keyspace.0.ranges.last().unwrap().end, - ), - ctx, - ) - .await - .map_err(|e| FlushLayerError::from_anyhow(self, e))? - } else { - None - }; - - // For image layers, we add them immediately into the layer map. 
let mut layers_to_upload = Vec::new(); layers_to_upload.extend( self.create_image_layers( @@ -3627,13 +3608,27 @@ impl Timeline { ) .await?, ); - - if let Some(delta_layer) = delta_layer { - layers_to_upload.push(delta_layer.clone()); - (layers_to_upload, Some(delta_layer)) - } else { - (layers_to_upload, None) + if !metadata_partition.parts.is_empty() { + assert_eq!( + metadata_partition.parts.len(), + 1, + "currently sparse keyspace should only contain a single metadata keyspace" + ); + layers_to_upload.extend( + self.create_image_layers( + // Safety: create_image_layers treat sparse keyspaces differently that it does not scan + // every single key within the keyspace, and therefore, it's safe to force converting it + // into a dense keyspace before calling this function. + &metadata_partition.into_dense(), + self.initdb_lsn, + ImageLayerCreationMode::Initial, + ctx, + ) + .await?, + ); } + + (layers_to_upload, None) } else { // Normal case, write out a L0 delta layer file. // `create_delta_layer` will not modify the layer map. @@ -4043,8 +4038,6 @@ impl Timeline { mode: ImageLayerCreationMode, start: Key, ) -> Result { - assert!(!matches!(mode, ImageLayerCreationMode::Initial)); - // Metadata keys image layer creation. let mut reconstruct_state = ValuesReconstructState::default(); let data = self @@ -4210,15 +4203,13 @@ impl Timeline { "metadata keys must be partitioned separately" ); } - if mode == ImageLayerCreationMode::Initial { - return Err(CreateImageLayersError::Other(anyhow::anyhow!("no image layer should be created for metadata keys when flushing frozen layers"))); - } if mode == ImageLayerCreationMode::Try && !check_for_image_layers { // Skip compaction if there are not enough updates. Metadata compaction will do a scan and // might mess up with evictions. start = img_range.end; continue; } + // For initial and force modes, we always generate image layers for metadata keys. } else if let ImageLayerCreationMode::Try = mode { // check_for_image_layers = false -> skip // check_for_image_layers = true -> check time_for_new_image_layer -> skip/generate @@ -4226,7 +4217,8 @@ impl Timeline { start = img_range.end; continue; } - } else if let ImageLayerCreationMode::Force = mode { + } + if let ImageLayerCreationMode::Force = mode { // When forced to create image layers, we might try and create them where they already // exist. This mode is only used in tests/debug. let layers = self.layers.read().await; @@ -4240,6 +4232,7 @@ impl Timeline { img_range.start, img_range.end ); + start = img_range.end; continue; } } @@ -4595,7 +4588,7 @@ impl Timeline { // for compact_level0_phase1 creating an L0, which does not happen in practice // because we have not implemented L0 => L0 compaction. 
duplicated_layers.insert(l.layer_desc().key()); - } else if LayerMap::is_l0(&l.layer_desc().key_range) { + } else if LayerMap::is_l0(&l.layer_desc().key_range, l.layer_desc().is_delta) { return Err(CompactionError::Other(anyhow::anyhow!("compaction generates a L0 layer file as output, which will cause infinite compaction."))); } else { insert_layers.push(l.clone()); @@ -5162,7 +5155,7 @@ impl Timeline { let task_id = task_mgr::spawn( task_mgr::BACKGROUND_RUNTIME.handle(), task_mgr::TaskKind::DownloadAllRemoteLayers, - Some(self.tenant_shard_id), + self.tenant_shard_id, Some(self.timeline_id), "download all remote layers task", async move { @@ -5451,12 +5444,17 @@ impl Timeline { !(a.end <= b.start || b.end <= a.start) } - let guard = self.layers.read().await; - for layer in guard.layer_map()?.iter_historic_layers() { - if layer.is_delta() - && overlaps_with(&layer.lsn_range, &deltas.lsn_range) - && layer.lsn_range != deltas.lsn_range - { + if deltas.key_range.start.next() != deltas.key_range.end { + let guard = self.layers.read().await; + let mut invalid_layers = + guard.layer_map()?.iter_historic_layers().filter(|layer| { + layer.is_delta() + && overlaps_with(&layer.lsn_range, &deltas.lsn_range) + && layer.lsn_range != deltas.lsn_range + // skip single-key layer files + && layer.key_range.start.next() != layer.key_range.end + }); + if let Some(layer) = invalid_layers.next() { // If a delta layer overlaps with another delta layer AND their LSN range is not the same, panic panic!( "inserted layer violates delta layer LSN invariant: current_lsn_range={}..{}, conflict_lsn_range={}..{}", @@ -5590,44 +5588,6 @@ enum OpenLayerAction { } impl<'a> TimelineWriter<'a> { - /// Put a new page version that can be constructed from a WAL record - /// - /// This will implicitly extend the relation, if the page is beyond the - /// current end-of-file. - pub(crate) async fn put( - &mut self, - key: Key, - lsn: Lsn, - value: &Value, - ctx: &RequestContext, - ) -> anyhow::Result<()> { - // Avoid doing allocations for "small" values. - // In the regression test suite, the limit of 256 avoided allocations in 95% of cases: - // https://github.com/neondatabase/neon/pull/5056#discussion_r1301975061 - let mut buf = smallvec::SmallVec::<[u8; 256]>::new(); - value.ser_into(&mut buf)?; - let buf_size: u64 = buf.len().try_into().expect("oversized value buf"); - - let action = self.get_open_layer_action(lsn, buf_size); - let layer = self.handle_open_layer_action(lsn, action, ctx).await?; - let res = layer.put_value(key.to_compact(), lsn, &buf, ctx).await; - - if res.is_ok() { - // Update the current size only when the entire write was ok. - // In case of failures, we may have had partial writes which - // render the size tracking out of sync. That's ok because - // the checkpoint distance should be significantly smaller - // than the S3 single shot upload limit of 5GiB. - let state = self.write_guard.as_mut().unwrap(); - - state.current_size += buf_size; - state.prev_lsn = Some(lsn); - state.max_lsn = std::cmp::max(state.max_lsn, Some(lsn)); - } - - res - } - async fn handle_open_layer_action( &mut self, at: Lsn, @@ -5733,18 +5693,58 @@ impl<'a> TimelineWriter<'a> { } /// Put a batch of keys at the specified Lsns. - /// - /// The batch is sorted by Lsn (enforced by usage of [`utils::vec_map::VecMap`]. 
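The new put_batch signature below takes a flat Vec of (key, lsn, serialized_size, value) tuples instead of a VecMap keyed by LSN; carrying the serialized size lets SerializedBatch::from_values pre-allocate a single buffer for the whole batch (sum of sizes plus up to 4 header bytes per value). A rough sketch of how a caller might assemble such a batch, with simplified stand-ins for CompactKey, Lsn and Value:

// ToyKey, ToyLsn and ToyValue are simplified stand-ins for the real types.
type ToyKey = u64;
type ToyLsn = u64;

enum ToyValue {
    Image(Vec<u8>),
}

impl ToyValue {
    fn serialized_size(&self) -> usize {
        match self {
            ToyValue::Image(img) => img.len(),
        }
    }
}

fn main() {
    let values = vec![
        (1 as ToyKey, 0x10 as ToyLsn, ToyValue::Image(vec![0u8; 64])),
        (2, 0x20, ToyValue::Image(vec![0u8; 200])),
    ];

    // Precompute each value's serialized size so the batch serializer can
    // allocate one buffer up front, mirroring the estimate in
    // SerializedBatch::from_values (sizes + 4 bytes of length header per value).
    let batch: Vec<(ToyKey, ToyLsn, usize, ToyValue)> = values
        .into_iter()
        .map(|(key, lsn, val)| {
            let size = val.serialized_size();
            (key, lsn, size, val)
        })
        .collect();

    let buffer_size: usize = batch.iter().map(|i| i.2).sum::<usize>() + 4 * batch.len();
    assert_eq!(buffer_size, 64 + 200 + 8);
}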
pub(crate) async fn put_batch( &mut self, - batch: VecMap, + batch: Vec<(CompactKey, Lsn, usize, Value)>, ctx: &RequestContext, ) -> anyhow::Result<()> { - for (lsn, (key, val)) in batch { - self.put(key, lsn, &val, ctx).await? + if batch.is_empty() { + return Ok(()); } - Ok(()) + let serialized_batch = inmemory_layer::SerializedBatch::from_values(batch); + let batch_max_lsn = serialized_batch.max_lsn; + let buf_size: u64 = serialized_batch.raw.len() as u64; + + let action = self.get_open_layer_action(batch_max_lsn, buf_size); + let layer = self + .handle_open_layer_action(batch_max_lsn, action, ctx) + .await?; + + let res = layer.put_batch(serialized_batch, ctx).await; + + if res.is_ok() { + // Update the current size only when the entire write was ok. + // In case of failures, we may have had partial writes which + // render the size tracking out of sync. That's ok because + // the checkpoint distance should be significantly smaller + // than the S3 single shot upload limit of 5GiB. + let state = self.write_guard.as_mut().unwrap(); + + state.current_size += buf_size; + state.prev_lsn = Some(batch_max_lsn); + state.max_lsn = std::cmp::max(state.max_lsn, Some(batch_max_lsn)); + } + + res + } + + #[cfg(test)] + /// Test helper, for tests that would like to poke individual values without composing a batch + pub(crate) async fn put( + &mut self, + key: Key, + lsn: Lsn, + value: &Value, + ctx: &RequestContext, + ) -> anyhow::Result<()> { + use utils::bin_ser::BeSer; + let val_ser_size = value.serialized_size().unwrap() as usize; + self.put_batch( + vec![(key.to_compact(), lsn, val_ser_size, value.clone())], + ctx, + ) + .await } pub(crate) async fn delete_batch( @@ -5885,7 +5885,7 @@ mod tests { }; // Apart from L0s, newest Layers should come first - if !LayerMap::is_l0(layer.name.key_range()) { + if !LayerMap::is_l0(layer.name.key_range(), layer.name.is_delta()) { assert!(layer_lsn <= last_lsn); last_lsn = layer_lsn; } diff --git a/pageserver/src/tenant/timeline/compaction.rs b/pageserver/src/tenant/timeline/compaction.rs index 7370ec1386..aad75ac59c 100644 --- a/pageserver/src/tenant/timeline/compaction.rs +++ b/pageserver/src/tenant/timeline/compaction.rs @@ -14,7 +14,7 @@ use super::{ RecordedDuration, Timeline, }; -use anyhow::{anyhow, Context}; +use anyhow::{anyhow, bail, Context}; use bytes::Bytes; use enumset::EnumSet; use fail::fail_point; @@ -32,6 +32,9 @@ use crate::page_cache; use crate::tenant::config::defaults::{DEFAULT_CHECKPOINT_DISTANCE, DEFAULT_COMPACTION_THRESHOLD}; use crate::tenant::remote_timeline_client::WaitCompletionError; use crate::tenant::storage_layer::merge_iterator::MergeIterator; +use crate::tenant::storage_layer::split_writer::{ + SplitDeltaLayerWriter, SplitImageLayerWriter, SplitWriterResult, +}; use crate::tenant::storage_layer::{ AsLayerDesc, PersistentLayerDesc, PersistentLayerKey, ValueReconstructState, }; @@ -71,15 +74,60 @@ pub(crate) struct KeyHistoryRetention { } impl KeyHistoryRetention { + /// Hack: skip delta layer if we need to produce a layer of a same key-lsn. + /// + /// This can happen if we have removed some deltas in "the middle" of some existing layer's key-lsn-range. + /// For example, consider the case where a single delta with range [0x10,0x50) exists. + /// And we have branches at LSN 0x10, 0x20, 0x30. + /// Then we delete branch @ 0x20. + /// Bottom-most compaction may now delete the delta [0x20,0x30). + /// And that wouldnt' change the shape of the layer. 
+ /// + /// Note that bottom-most-gc-compaction never _adds_ new data in that case, only removes. + /// + /// `discard_key` will only be called when the writer reaches its target (instead of for every key), so it's fine to grab a lock inside. + async fn discard_key(key: &PersistentLayerKey, tline: &Arc, dry_run: bool) -> bool { + if dry_run { + return true; + } + let guard = tline.layers.read().await; + if !guard.contains_key(key) { + return false; + } + let layer_generation = guard.get_from_key(key).metadata().generation; + drop(guard); + if layer_generation == tline.generation { + info!( + key=%key, + ?layer_generation, + "discard layer due to duplicated layer key in the same generation", + ); + true + } else { + false + } + } + + /// Pipe a history of a single key to the writers. + /// + /// If `image_writer` is none, the images will be placed into the delta layers. + /// The delta writer will contain all images and deltas (below and above the horizon) except the bottom-most images. + #[allow(clippy::too_many_arguments)] async fn pipe_to( self, key: Key, - delta_writer: &mut Vec<(Key, Lsn, Value)>, - mut image_writer: Option<&mut ImageLayerWriter>, + tline: &Arc, + delta_writer: &mut SplitDeltaLayerWriter, + mut image_writer: Option<&mut SplitImageLayerWriter>, stat: &mut CompactionStatistics, + dry_run: bool, ctx: &RequestContext, ) -> anyhow::Result<()> { let mut first_batch = true; + let discard = |key: &PersistentLayerKey| { + let key = key.clone(); + async move { Self::discard_key(&key, tline, dry_run).await } + }; for (cutoff_lsn, KeyLogAtLsn(logs)) in self.below_horizon { if first_batch { if logs.len() == 1 && logs[0].1.is_image() { @@ -88,28 +136,45 @@ impl KeyHistoryRetention { }; stat.produce_image_key(img); if let Some(image_writer) = image_writer.as_mut() { - image_writer.put_image(key, img.clone(), ctx).await?; + image_writer + .put_image_with_discard_fn(key, img.clone(), tline, ctx, discard) + .await?; } else { - delta_writer.push((key, cutoff_lsn, Value::Image(img.clone()))); + delta_writer + .put_value_with_discard_fn( + key, + cutoff_lsn, + Value::Image(img.clone()), + tline, + ctx, + discard, + ) + .await?; } } else { for (lsn, val) in logs { stat.produce_key(&val); - delta_writer.push((key, lsn, val)); + delta_writer + .put_value_with_discard_fn(key, lsn, val, tline, ctx, discard) + .await?; } } first_batch = false; } else { for (lsn, val) in logs { stat.produce_key(&val); - delta_writer.push((key, lsn, val)); + delta_writer + .put_value_with_discard_fn(key, lsn, val, tline, ctx, discard) + .await?; } } } let KeyLogAtLsn(above_horizon_logs) = self.above_horizon; for (lsn, val) in above_horizon_logs { stat.produce_key(&val); - delta_writer.push((key, lsn, val)); + delta_writer + .put_value_with_discard_fn(key, lsn, val, tline, ctx, discard) + .await?; } Ok(()) } @@ -1814,11 +1879,27 @@ impl Timeline { } let mut selected_layers = Vec::new(); drop(gc_info); + // Pick all the layers intersect or below the gc_cutoff, get the largest LSN in the selected layers. + let Some(max_layer_lsn) = layers + .iter_historic_layers() + .filter(|desc| desc.get_lsn_range().start <= gc_cutoff) + .map(|desc| desc.get_lsn_range().end) + .max() + else { + info!("no layers to compact with gc"); + return Ok(()); + }; + // Then, pick all the layers that are below the max_layer_lsn. This is to ensure we can pick all single-key + // layers to compact. 
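The selection in the loop below works in two steps: first take the highest lsn_range.end among layers whose range starts at or below the GC cutoff, then pick every layer that ends at or below that LSN. A self-contained sketch of the same logic over plain u64 ranges, ignoring the delta/image distinction and the single-key-layer exclusion:

use std::ops::Range;

// Plain u64 LSN ranges stand in for layer descriptors here.
fn select_layers(layers: &[Range<u64>], gc_cutoff: u64) -> Vec<Range<u64>> {
    let Some(max_layer_lsn) = layers
        .iter()
        .filter(|lsn_range| lsn_range.start <= gc_cutoff)
        .map(|lsn_range| lsn_range.end)
        .max()
    else {
        return Vec::new(); // nothing to compact
    };

    layers
        .iter()
        .filter(|lsn_range| lsn_range.end <= max_layer_lsn)
        .cloned()
        .collect()
}

fn main() {
    // Three delta layers; the GC cutoff falls inside the second one.
    let layers = vec![0x10..0x20, 0x20..0x40, 0x40..0x50];
    let selected = select_layers(&layers, 0x30);
    // The first two layers are picked; the third starts above the cutoff and
    // ends above max_layer_lsn (0x40), so it is left alone.
    assert_eq!(selected, vec![0x10..0x20, 0x20..0x40]);
}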
for desc in layers.iter_historic_layers() { - if desc.get_lsn_range().start <= gc_cutoff { + if desc.get_lsn_range().end <= max_layer_lsn { selected_layers.push(guard.get_from_desc(&desc)); } } + if selected_layers.is_empty() { + info!("no layers to compact with gc"); + return Ok(()); + } retain_lsns_below_horizon.sort(); (selected_layers, gc_cutoff, retain_lsns_below_horizon) }; @@ -1848,27 +1929,53 @@ impl Timeline { lowest_retain_lsn ); // Step 1: (In the future) construct a k-merge iterator over all layers. For now, simply collect all keys + LSNs. - // Also, collect the layer information to decide when to split the new delta layers. - let mut downloaded_layers = Vec::new(); - let mut delta_split_points = BTreeSet::new(); + // Also, verify if the layer map can be split by drawing a horizontal line at every LSN start/end split point. + let mut lsn_split_point = BTreeSet::new(); // TODO: use a better data structure (range tree / range set?) for layer in &layer_selection { - let resident_layer = layer.download_and_keep_resident().await?; - downloaded_layers.push(resident_layer); - let desc = layer.layer_desc(); if desc.is_delta() { - // TODO: is it correct to only record split points for deltas intersecting with the GC horizon? (exclude those below/above the horizon) - // so that we can avoid having too many small delta layers. - let key_range = desc.get_key_range(); - delta_split_points.insert(key_range.start); - delta_split_points.insert(key_range.end); + // ignore single-key layer files + if desc.key_range.start.next() != desc.key_range.end { + let lsn_range = &desc.lsn_range; + lsn_split_point.insert(lsn_range.start); + lsn_split_point.insert(lsn_range.end); + } stat.visit_delta_layer(desc.file_size()); } else { stat.visit_image_layer(desc.file_size()); } } + for layer in &layer_selection { + let desc = layer.layer_desc(); + let key_range = &desc.key_range; + if desc.is_delta() && key_range.start.next() != key_range.end { + let lsn_range = desc.lsn_range.clone(); + let intersects = lsn_split_point.range(lsn_range).collect_vec(); + if intersects.len() > 1 { + bail!( + "cannot run gc-compaction because it violates the layer map LSN split assumption: layer {} intersects with LSN [{}]", + desc.key(), + intersects.into_iter().map(|lsn| lsn.to_string()).join(", ") + ); + } + } + } + // The maximum LSN we are processing in this compaction loop + let end_lsn = layer_selection + .iter() + .map(|l| l.layer_desc().lsn_range.end) + .max() + .unwrap(); + // We don't want any of the produced layers to cover the full key range (i.e., MIN..MAX) b/c it will then be recognized + // as an L0 layer. 
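Key::NON_L0_MAX and the extra is_delta argument exist because a layer is classified as L0 when it is a delta covering the whole key space, and image layers spanning MIN..MAX used to be misclassified. The predicate is roughly the following; the actual LayerMap::is_l0 body is not part of this diff, so treat this as an approximation with u64 standing in for Key:

use std::ops::Range;

const KEY_MIN: u64 = 0;
const KEY_MAX: u64 = u64::MAX;

// Approximation: only a *delta* layer covering the whole key space is L0.
fn is_l0(key_range: &Range<u64>, is_delta: bool) -> bool {
    is_delta && key_range.start == KEY_MIN && key_range.end == KEY_MAX
}

fn main() {
    // A full-range image layer is no longer misclassified as L0 ...
    assert!(!is_l0(&(KEY_MIN..KEY_MAX), false));
    // ... while a full-range delta layer still is.
    assert!(is_l0(&(KEY_MIN..KEY_MAX), true));
    // And a delta layer ending just short of KEY_MAX (the NON_L0_MAX trick)
    // is not considered L0 either.
    assert!(!is_l0(&(KEY_MIN..KEY_MAX - 1), true));
}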
+ let hack_end_key = Key::NON_L0_MAX; let mut delta_layers = Vec::new(); let mut image_layers = Vec::new(); + let mut downloaded_layers = Vec::new(); + for layer in &layer_selection { + let resident_layer = layer.download_and_keep_resident().await?; + downloaded_layers.push(resident_layer); + } for resident_layer in &downloaded_layers { if resident_layer.layer_desc().is_delta() { let layer = resident_layer.get_as_delta(ctx).await?; @@ -1884,138 +1991,17 @@ impl Timeline { let mut accumulated_values = Vec::new(); let mut last_key: Option = None; - enum FlushDeltaResult { - /// Create a new resident layer - CreateResidentLayer(ResidentLayer), - /// Keep an original delta layer - KeepLayer(PersistentLayerKey), - } - - #[allow(clippy::too_many_arguments)] - async fn flush_deltas( - deltas: &mut Vec<(Key, Lsn, crate::repository::Value)>, - last_key: Key, - delta_split_points: &[Key], - current_delta_split_point: &mut usize, - tline: &Arc, - lowest_retain_lsn: Lsn, - ctx: &RequestContext, - stats: &mut CompactionStatistics, - dry_run: bool, - last_batch: bool, - ) -> anyhow::Result> { - // Check if we need to split the delta layer. We split at the original delta layer boundary to avoid - // overlapping layers. - // - // If we have a structure like this: - // - // | Delta 1 | | Delta 4 | - // |---------| Delta 2 |---------| - // | Delta 3 | | Delta 5 | - // - // And we choose to compact delta 2+3+5. We will get an overlapping delta layer with delta 1+4. - // A simple solution here is to split the delta layers using the original boundary, while this - // might produce a lot of small layers. This should be improved and fixed in the future. - let mut need_split = false; - while *current_delta_split_point < delta_split_points.len() - && last_key >= delta_split_points[*current_delta_split_point] - { - *current_delta_split_point += 1; - need_split = true; - } - if !need_split && !last_batch { - return Ok(None); - } - let deltas: Vec<(Key, Lsn, Value)> = std::mem::take(deltas); - if deltas.is_empty() { - return Ok(None); - } - let end_lsn = deltas.iter().map(|(_, lsn, _)| lsn).max().copied().unwrap() + 1; - let delta_key = PersistentLayerKey { - key_range: { - let key_start = deltas.first().unwrap().0; - let key_end = deltas.last().unwrap().0.next(); - key_start..key_end - }, - lsn_range: lowest_retain_lsn..end_lsn, - is_delta: true, - }; - { - // Hack: skip delta layer if we need to produce a layer of a same key-lsn. - // - // This can happen if we have removed some deltas in "the middle" of some existing layer's key-lsn-range. - // For example, consider the case where a single delta with range [0x10,0x50) exists. - // And we have branches at LSN 0x10, 0x20, 0x30. - // Then we delete branch @ 0x20. - // Bottom-most compaction may now delete the delta [0x20,0x30). - // And that wouldnt' change the shape of the layer. - // - // Note that bottom-most-gc-compaction never _adds_ new data in that case, only removes. - // That's why it's safe to skip. - let guard = tline.layers.read().await; - - if guard.contains_key(&delta_key) { - let layer_generation = guard.get_from_key(&delta_key).metadata().generation; - drop(guard); - if layer_generation == tline.generation { - stats.discard_delta_layer(); - // TODO: depending on whether we design this compaction process to run along with - // other compactions, there could be layer map modifications after we drop the - // layer guard, and in case it creates duplicated layer key, we will still error - // in the end. 
- info!( - key=%delta_key, - ?layer_generation, - "discard delta layer due to duplicated layer in the same generation" - ); - return Ok(Some(FlushDeltaResult::KeepLayer(delta_key))); - } - } - } - - let mut delta_layer_writer = DeltaLayerWriter::new( - tline.conf, - tline.timeline_id, - tline.tenant_shard_id, - delta_key.key_range.start, - lowest_retain_lsn..end_lsn, - ctx, - ) - .await?; - for (key, lsn, val) in deltas { - delta_layer_writer.put_value(key, lsn, val, ctx).await?; - } - - stats.produce_delta_layer(delta_layer_writer.size()); - if dry_run { - return Ok(None); - } - - let (desc, path) = delta_layer_writer - .finish(delta_key.key_range.end, ctx) - .await?; - let delta_layer = Layer::finish_creating(tline.conf, tline, desc, &path)?; - Ok(Some(FlushDeltaResult::CreateResidentLayer(delta_layer))) - } - - // Hack the key range to be min..(max-1). Otherwise, the image layer will be - // interpreted as an L0 delta layer. - let hack_image_layer_range = { - let mut end_key = Key::MAX; - end_key.field6 -= 1; - Key::MIN..end_key - }; - // Only create image layers when there is no ancestor branches. TODO: create covering image layer // when some condition meet. let mut image_layer_writer = if self.ancestor_timeline.is_none() { Some( - ImageLayerWriter::new( + SplitImageLayerWriter::new( self.conf, self.timeline_id, self.tenant_shard_id, - &hack_image_layer_range, // covers the full key range + Key::MIN, lowest_retain_lsn, + self.get_compaction_target_size(), ctx, ) .await?, @@ -2024,6 +2010,17 @@ impl Timeline { None }; + let mut delta_layer_writer = SplitDeltaLayerWriter::new( + self.conf, + self.timeline_id, + self.tenant_shard_id, + Key::MIN, + lowest_retain_lsn..end_lsn, + self.get_compaction_target_size(), + ctx, + ) + .await?; + /// Returns None if there is no ancestor branch. Throw an error when the key is not found. /// /// Currently, we always get the ancestor image for each key in the child branch no matter whether the image @@ -2044,47 +2041,11 @@ impl Timeline { let img = tline.get(key, tline.ancestor_lsn, ctx).await?; Ok(Some((key, tline.ancestor_lsn, img))) } - let image_layer_key = PersistentLayerKey { - key_range: hack_image_layer_range, - lsn_range: PersistentLayerDesc::image_layer_lsn_range(lowest_retain_lsn), - is_delta: false, - }; - - // Like with delta layers, it can happen that we re-produce an already existing image layer. - // This could happen when a user triggers force compaction and image generation. In this case, - // it's always safe to rewrite the layer. - let discard_image_layer = { - let guard = self.layers.read().await; - if guard.contains_key(&image_layer_key) { - let layer_generation = guard.get_from_key(&image_layer_key).metadata().generation; - drop(guard); - if layer_generation == self.generation { - // TODO: depending on whether we design this compaction process to run along with - // other compactions, there could be layer map modifications after we drop the - // layer guard, and in case it creates duplicated layer key, we will still error - // in the end. - info!( - key=%image_layer_key, - ?layer_generation, - "discard image layer due to duplicated layer key in the same generation", - ); - true - } else { - false - } - } else { - false - } - }; // Actually, we can decide not to write to the image layer at all at this point because // the key and LSN range are determined. However, to keep things simple here, we still // create this writer, and discard the writer in the end. 
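The discard closures built below all delegate to KeyHistoryRetention::discard_key, which boils down to: dry runs write nothing, and a layer key that was already produced in the current generation is kept rather than rewritten. A simplified, synchronous sketch of that decision with toy types (the real check reads the layer map under a lock):

use std::collections::HashMap;

type LayerKey = &'static str;

// existing: layer key -> generation it was written in.
fn discard_key(
    key: LayerKey,
    existing: &HashMap<LayerKey, u32>,
    current_generation: u32,
    dry_run: bool,
) -> bool {
    if dry_run {
        return true;
    }
    match existing.get(key) {
        // Same key already produced by this generation: skip the rewrite.
        Some(generation) => *generation == current_generation,
        None => false,
    }
}

fn main() {
    let mut existing = HashMap::new();
    existing.insert("delta-0010-0040", 7u32);

    assert!(discard_key("delta-0010-0040", &existing, 7, false)); // duplicate, skip
    assert!(!discard_key("delta-0010-0040", &existing, 8, false)); // older generation, rewrite
    assert!(!discard_key("image-0040", &existing, 7, false)); // new key, write it
    assert!(discard_key("image-0040", &existing, 7, true)); // dry run discards everything
}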
- let mut delta_values = Vec::new(); - let delta_split_points = delta_split_points.into_iter().collect_vec(); - let mut current_delta_split_point = 0; - let mut delta_layers = Vec::new(); while let Some((key, lsn, val)) = merge_iter.next().await? { if cancel.is_cancelled() { return Err(anyhow!("cancelled")); // TODO: refactor to CompactionError and pass cancel error @@ -2115,27 +2076,14 @@ impl Timeline { retention .pipe_to( *last_key, - &mut delta_values, + self, + &mut delta_layer_writer, image_layer_writer.as_mut(), &mut stat, + dry_run, ctx, ) .await?; - delta_layers.extend( - flush_deltas( - &mut delta_values, - *last_key, - &delta_split_points, - &mut current_delta_split_point, - self, - lowest_retain_lsn, - ctx, - &mut stat, - dry_run, - false, - ) - .await?, - ); accumulated_values.clear(); *last_key = key; accumulated_values.push((key, lsn, val)); @@ -2159,43 +2107,75 @@ impl Timeline { retention .pipe_to( last_key, - &mut delta_values, + self, + &mut delta_layer_writer, image_layer_writer.as_mut(), &mut stat, + dry_run, ctx, ) .await?; - delta_layers.extend( - flush_deltas( - &mut delta_values, - last_key, - &delta_split_points, - &mut current_delta_split_point, - self, - lowest_retain_lsn, - ctx, - &mut stat, - dry_run, - true, - ) - .await?, - ); - assert!(delta_values.is_empty(), "unprocessed keys"); - let image_layer = if discard_image_layer { - stat.discard_image_layer(); - None - } else if let Some(writer) = image_layer_writer { - stat.produce_image_layer(writer.size()); + let discard = |key: &PersistentLayerKey| { + let key = key.clone(); + async move { KeyHistoryRetention::discard_key(&key, self, dry_run).await } + }; + + let produced_image_layers = if let Some(writer) = image_layer_writer { if !dry_run { - Some(writer.finish(self, ctx).await?) + writer + .finish_with_discard_fn(self, ctx, hack_end_key, discard) + .await? } else { - None + let (layers, _) = writer.take()?; + assert!(layers.is_empty(), "image layers produced in dry run mode?"); + Vec::new() } } else { - None + Vec::new() }; + let produced_delta_layers = if !dry_run { + delta_layer_writer + .finish_with_discard_fn(self, ctx, hack_end_key, discard) + .await? + } else { + let (layers, _) = delta_layer_writer.take()?; + assert!(layers.is_empty(), "delta layers produced in dry run mode?"); + Vec::new() + }; + + let mut compact_to = Vec::new(); + let mut keep_layers = HashSet::new(); + let produced_delta_layers_len = produced_delta_layers.len(); + let produced_image_layers_len = produced_image_layers.len(); + for action in produced_delta_layers { + match action { + SplitWriterResult::Produced(layer) => { + stat.produce_delta_layer(layer.layer_desc().file_size()); + compact_to.push(layer); + } + SplitWriterResult::Discarded(l) => { + keep_layers.insert(l); + stat.discard_delta_layer(); + } + } + } + for action in produced_image_layers { + match action { + SplitWriterResult::Produced(layer) => { + stat.produce_image_layer(layer.layer_desc().file_size()); + compact_to.push(layer); + } + SplitWriterResult::Discarded(l) => { + keep_layers.insert(l); + stat.discard_image_layer(); + } + } + } + let mut layer_selection = layer_selection; + layer_selection.retain(|x| !keep_layers.contains(&x.layer_desc().key())); + info!( "gc-compaction statistics: {}", serde_json::to_string(&stat)? 
@@ -2206,28 +2186,11 @@ impl Timeline { } info!( - "produced {} delta layers and {} image layers", - delta_layers.len(), - if image_layer.is_some() { 1 } else { 0 } + "produced {} delta layers and {} image layers, {} layers are kept", + produced_delta_layers_len, + produced_image_layers_len, + layer_selection.len() ); - let mut compact_to = Vec::new(); - let mut keep_layers = HashSet::new(); - for action in delta_layers { - match action { - FlushDeltaResult::CreateResidentLayer(layer) => { - compact_to.push(layer); - } - FlushDeltaResult::KeepLayer(l) => { - keep_layers.insert(l); - } - } - } - if discard_image_layer { - keep_layers.insert(image_layer_key); - } - let mut layer_selection = layer_selection; - layer_selection.retain(|x| !keep_layers.contains(&x.layer_desc().key())); - compact_to.extend(image_layer); // Step 3: Place back to the layer map. { diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index b03dbb092e..dc4118bb4a 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -395,7 +395,7 @@ impl DeleteTimelineFlow { task_mgr::spawn( task_mgr::BACKGROUND_RUNTIME.handle(), TaskKind::TimelineDeletionWorker, - Some(tenant_shard_id), + tenant_shard_id, Some(timeline_id), "timeline_delete", async move { diff --git a/pageserver/src/tenant/timeline/eviction_task.rs b/pageserver/src/tenant/timeline/eviction_task.rs index eaa9c0ff62..2f6cb4d73a 100644 --- a/pageserver/src/tenant/timeline/eviction_task.rs +++ b/pageserver/src/tenant/timeline/eviction_task.rs @@ -60,7 +60,7 @@ impl Timeline { task_mgr::spawn( BACKGROUND_RUNTIME.handle(), TaskKind::Eviction, - Some(self.tenant_shard_id), + self.tenant_shard_id, Some(self.timeline_id), &format!( "layer eviction for {}/{}", diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index b5c577af72..0114473eda 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -27,8 +27,8 @@ use super::TaskStateUpdate; use crate::{ context::RequestContext, metrics::{LIVE_CONNECTIONS, WALRECEIVER_STARTED_CONNECTIONS, WAL_INGEST}, - task_mgr::TaskKind, - task_mgr::WALRECEIVER_RUNTIME, + pgdatadir_mapping::DatadirModification, + task_mgr::{TaskKind, WALRECEIVER_RUNTIME}, tenant::{debug_assert_current_span_has_tenant_and_timeline_id, Timeline, WalReceiverInfo}, walingest::WalIngest, walrecord::DecodedWALRecord, @@ -345,7 +345,10 @@ pub(super) async fn handle_walreceiver_connection( // Commit every ingest_batch_size records. Even if we filtered out // all records, we still need to call commit to advance the LSN. 
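With this change the ingest loop commits either after ingest_batch_size records or as soon as the buffered modification exceeds DatadirModification::MAX_PENDING_BYTES, whichever comes first. A small sketch of the combined trigger; the byte limit and ToyModification type are assumed placeholders, not the real constant's value:

struct ToyModification {
    pending_bytes: usize,
}

impl ToyModification {
    // Assumed value, for illustration only.
    const MAX_PENDING_BYTES: usize = 8 * 1024 * 1024;

    fn approx_pending_bytes(&self) -> usize {
        self.pending_bytes
    }
}

fn should_commit(uncommitted_records: u64, ingest_batch_size: u64, modification: &ToyModification) -> bool {
    uncommitted_records >= ingest_batch_size
        || modification.approx_pending_bytes() > ToyModification::MAX_PENDING_BYTES
}

fn main() {
    let small = ToyModification { pending_bytes: 1024 };
    let large = ToyModification { pending_bytes: 32 * 1024 * 1024 };

    // Not enough records and not enough bytes: keep batching.
    assert!(!should_commit(10, 100, &small));
    // Few records but a lot of buffered bytes: commit early.
    assert!(should_commit(10, 100, &large));
    // Enough records: commit regardless of size.
    assert!(should_commit(100, 100, &small));
}

Capping by bytes as well as by record count keeps a single put_batch call, and therefore a single SerializedBatch buffer, bounded in size.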
uncommitted_records += 1; - if uncommitted_records >= ingest_batch_size { + if uncommitted_records >= ingest_batch_size + || modification.approx_pending_bytes() + > DatadirModification::MAX_PENDING_BYTES + { WAL_INGEST .records_committed .inc_by(uncommitted_records - filtered_records); diff --git a/pageserver/src/utilization.rs b/pageserver/src/utilization.rs index 3c48c84598..a0223f3bce 100644 --- a/pageserver/src/utilization.rs +++ b/pageserver/src/utilization.rs @@ -9,7 +9,7 @@ use utils::serde_percent::Percent; use pageserver_api::models::PageserverUtilization; -use crate::{config::PageServerConf, tenant::mgr::TenantManager}; +use crate::{config::PageServerConf, metrics::NODE_UTILIZATION_SCORE, tenant::mgr::TenantManager}; pub(crate) fn regenerate( conf: &PageServerConf, @@ -58,13 +58,13 @@ pub(crate) fn regenerate( disk_usable_pct, shard_count, max_shard_count: MAX_SHARDS, - utilization_score: 0, + utilization_score: None, captured_at: utils::serde_system_time::SystemTime(captured_at), }; - doc.refresh_score(); - - // TODO: make utilization_score into a metric + // Initialize `PageserverUtilization::utilization_score` + let score = doc.cached_score(); + NODE_UTILIZATION_SCORE.set(score); Ok(doc) } diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index b4695e5f40..c0017280fd 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -756,11 +756,23 @@ impl VirtualFile { }) } + /// The function aborts the process if the error is fatal. async fn write_at( &self, buf: FullSlice, offset: u64, _ctx: &RequestContext, /* TODO: use for metrics: https://github.com/neondatabase/neon/issues/6107 */ + ) -> (FullSlice, Result) { + let (slice, result) = self.write_at_inner(buf, offset, _ctx).await; + let result = result.maybe_fatal_err("write_at"); + (slice, result) + } + + async fn write_at_inner( + &self, + buf: FullSlice, + offset: u64, + _ctx: &RequestContext, /* TODO: use for metrics: https://github.com/neondatabase/neon/issues/6107 */ ) -> (FullSlice, Result) { let file_guard = match self.lock_file().await { Ok(file_guard) => file_guard, diff --git a/patches/pg_hintplan.patch b/patches/pg_hint_plan.patch similarity index 55% rename from patches/pg_hintplan.patch rename to patches/pg_hint_plan.patch index 61a5ecbb90..4039a036df 100644 --- a/patches/pg_hintplan.patch +++ b/patches/pg_hint_plan.patch @@ -1,13 +1,7 @@ -commit f7925d4d1406c0f0229e3c691c94b69e381899b1 (HEAD -> master) -Author: Alexey Masterov -Date: Thu Jun 6 08:02:42 2024 +0000 - - Patch expected files to consider Neon's log messages - -diff --git a/ext-src/pg_hint_plan-src/expected/ut-A.out b/ext-src/pg_hint_plan-src/expected/ut-A.out -index da723b8..f8d0102 100644 ---- a/ext-src/pg_hint_plan-src/expected/ut-A.out -+++ b/ext-src/pg_hint_plan-src/expected/ut-A.out +diff --git a/expected/ut-A.out b/expected/ut-A.out +index da723b8..5328114 100644 +--- a/expected/ut-A.out ++++ b/expected/ut-A.out @@ -9,13 +9,16 @@ SET search_path TO public; ---- -- No.A-1-1-3 @@ -25,10 +19,18 @@ index da723b8..f8d0102 100644 DROP SCHEMA other_schema; ---- ---- No. 
A-5-1 comment pattern -diff --git a/ext-src/pg_hint_plan-src/expected/ut-fdw.out b/ext-src/pg_hint_plan-src/expected/ut-fdw.out +@@ -3175,6 +3178,7 @@ SELECT s.query, s.calls + FROM public.pg_stat_statements s + JOIN pg_catalog.pg_database d + ON (s.dbid = d.oid) ++ WHERE s.query LIKE 'SELECT * FROM s1.t1%' OR s.query LIKE '%pg_stat_statements_reset%' + ORDER BY 1; + query | calls + --------------------------------------+------- +diff --git a/expected/ut-fdw.out b/expected/ut-fdw.out index d372459..6282afe 100644 ---- a/ext-src/pg_hint_plan-src/expected/ut-fdw.out -+++ b/ext-src/pg_hint_plan-src/expected/ut-fdw.out +--- a/expected/ut-fdw.out ++++ b/expected/ut-fdw.out @@ -7,6 +7,7 @@ SET pg_hint_plan.debug_print TO on; SET client_min_messages TO LOG; SET pg_hint_plan.enable_hint TO on; @@ -37,3 +39,15 @@ index d372459..6282afe 100644 CREATE SERVER file_server FOREIGN DATA WRAPPER file_fdw; CREATE USER MAPPING FOR PUBLIC SERVER file_server; CREATE FOREIGN TABLE ft1 (id int, val int) SERVER file_server OPTIONS (format 'csv', filename :'filename'); +diff --git a/sql/ut-A.sql b/sql/ut-A.sql +index 7c7d58a..4fd1a07 100644 +--- a/sql/ut-A.sql ++++ b/sql/ut-A.sql +@@ -963,6 +963,7 @@ SELECT s.query, s.calls + FROM public.pg_stat_statements s + JOIN pg_catalog.pg_database d + ON (s.dbid = d.oid) ++ WHERE s.query LIKE 'SELECT * FROM s1.t1%' OR s.query LIKE '%pg_stat_statements_reset%' + ORDER BY 1; + + ---- diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 1894e8c72a..479209a537 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -41,6 +41,8 @@ #include "hll.h" +#define CriticalAssert(cond) do if (!(cond)) elog(PANIC, "Assertion %s failed at %s:%d: ", #cond, __FILE__, __LINE__); while (0) + /* * Local file cache is used to temporary store relations pages in local file system. * All blocks of all relations are stored inside one file and addressed using shared hash map. @@ -51,19 +53,43 @@ * * Cache is always reconstructed at node startup, so we do not need to save mapping somewhere and worry about * its consistency. + + * + * ## Holes + * + * The LFC can be resized on the fly, up to a maximum size that's determined + * at server startup (neon.max_file_cache_size). After server startup, we + * expand the underlying file when needed, until it reaches the soft limit + * (neon.file_cache_size_limit). If the soft limit is later reduced, we shrink + * the LFC by punching holes in the underlying file with a + * fallocate(FALLOC_FL_PUNCH_HOLE) call. The nominal size of the file doesn't + * shrink, but the disk space it uses does. + * + * Each hole is tracked by a dummy FileCacheEntry, which are kept in the + * 'holes' linked list. They are entered into the chunk hash table, with a + * special key where the blockNumber is used to store the 'offset' of the + * hole, and all other fields are zero. Holes are never looked up in the hash + * table, we only enter them there to have a FileCacheEntry that we can keep + * in the linked list. If the soft limit is raised again, we reuse the holes + * before extending the nominal size of the file. */ /* Local file storage allocation chunk. - * Should be power of two and not less than 32. Using larger than page chunks can + * Should be power of two. Using larger than page chunks can * 1. Reduce hash-map memory footprint: 8TB database contains billion pages * and size of hash entry is 40 bytes, so we need 40Gb just for hash map. * 1Mb chunks can reduce hash map size to 320Mb. * 2. 
Improve access locality, subsequent pages will be allocated together improving seqscan speed */ #define BLOCKS_PER_CHUNK 128 /* 1Mb chunk */ +/* + * Smaller chunk seems to be better for OLTP workload + */ +// #define BLOCKS_PER_CHUNK 8 /* 64kb chunk */ #define MB ((uint64)1024*1024) #define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ / BLOCKS_PER_CHUNK)) +#define CHUNK_BITMAP_SIZE ((BLOCKS_PER_CHUNK + 31) / 32) typedef struct FileCacheEntry { @@ -71,8 +97,8 @@ typedef struct FileCacheEntry uint32 hash; uint32 offset; uint32 access_count; - uint32 bitmap[BLOCKS_PER_CHUNK / 32]; - dlist_node lru_node; /* LRU list node */ + uint32 bitmap[CHUNK_BITMAP_SIZE]; + dlist_node list_node; /* LRU/holes list node */ } FileCacheEntry; typedef struct FileCacheControl @@ -87,6 +113,7 @@ typedef struct FileCacheControl uint64 writes; dlist_head lru; /* double linked list for LRU replacement * algorithm */ + dlist_head holes; /* double linked list of punched holes */ HyperLogLogState wss_estimation; /* estimation of working set size */ } FileCacheControl; @@ -135,6 +162,7 @@ lfc_disable(char const *op) lfc_ctl->used = 0; lfc_ctl->limit = 0; dlist_init(&lfc_ctl->lru); + dlist_init(&lfc_ctl->holes); if (lfc_desc > 0) { @@ -214,18 +242,18 @@ lfc_shmem_startup(void) if (!found) { int fd; - uint32 lfc_size = SIZE_MB_TO_CHUNKS(lfc_max_size); + uint32 n_chunks = SIZE_MB_TO_CHUNKS(lfc_max_size); lfc_lock = (LWLockId) GetNamedLWLockTranche("lfc_lock"); info.keysize = sizeof(BufferTag); info.entrysize = sizeof(FileCacheEntry); /* - * lfc_size+1 because we add new element to hash table before eviction + * n_chunks+1 because we add new element to hash table before eviction * of victim */ lfc_hash = ShmemInitHash("lfc_hash", - lfc_size + 1, lfc_size + 1, + n_chunks + 1, n_chunks + 1, &info, HASH_ELEM | HASH_BLOBS); lfc_ctl->generation = 0; @@ -235,6 +263,7 @@ lfc_shmem_startup(void) lfc_ctl->misses = 0; lfc_ctl->writes = 0; dlist_init(&lfc_ctl->lru); + dlist_init(&lfc_ctl->holes); /* Initialize hyper-log-log structure for estimating working set size */ initSHLL(&lfc_ctl->wss_estimation); @@ -310,14 +339,31 @@ lfc_change_limit_hook(int newval, void *extra) * Shrink cache by throwing away least recently accessed chunks and * returning their space to file system */ - FileCacheEntry *victim = dlist_container(FileCacheEntry, lru_node, dlist_pop_head_node(&lfc_ctl->lru)); + FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->lru)); + FileCacheEntry *hole; + uint32 offset = victim->offset; + uint32 hash; + bool found; + BufferTag holetag; - Assert(victim->access_count == 0); + CriticalAssert(victim->access_count == 0); #ifdef FALLOC_FL_PUNCH_HOLE if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * BLOCKS_PER_CHUNK * BLCKSZ, BLOCKS_PER_CHUNK * BLCKSZ) < 0) neon_log(LOG, "Failed to punch hole in file: %m"); #endif + /* We remove the old entry, and re-enter a hole to the hash table */ hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL); + + memset(&holetag, 0, sizeof(holetag)); + holetag.blockNum = offset; + hash = get_hash_value(lfc_hash, &holetag); + hole = hash_search_with_hash_value(lfc_hash, &holetag, hash, HASH_ENTER, &found); + hole->hash = hash; + hole->offset = offset; + hole->access_count = 0; + CriticalAssert(!found); + dlist_push_tail(&lfc_ctl->holes, &hole->list_node); + lfc_ctl->used -= 1; } lfc_ctl->limit = new_size; @@ -409,6 +455,8 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber 
forkNum, BlockNumber blkno) CopyNRelFileInfoToBufTag(tag, rinfo); tag.forkNum = forkNum; tag.blockNum = blkno & ~(BLOCKS_PER_CHUNK - 1); + + CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber); hash = get_hash_value(lfc_hash, &tag); LWLockAcquire(lfc_lock, LW_SHARED); @@ -440,6 +488,7 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno) tag.forkNum = forkNum; tag.blockNum = (blkno & ~(BLOCKS_PER_CHUNK - 1)); + CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber); hash = get_hash_value(lfc_hash, &tag); LWLockAcquire(lfc_lock, LW_EXCLUSIVE); @@ -470,7 +519,7 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno) { bool has_remaining_pages; - for (int i = 0; i < (BLOCKS_PER_CHUNK / 32); i++) + for (int i = 0; i < CHUNK_BITMAP_SIZE; i++) { if (entry->bitmap[i] != 0) { @@ -485,8 +534,8 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno) */ if (!has_remaining_pages) { - dlist_delete(&entry->lru_node); - dlist_push_head(&lfc_ctl->lru, &entry->lru_node); + dlist_delete(&entry->list_node); + dlist_push_head(&lfc_ctl->lru, &entry->list_node); } } @@ -525,6 +574,8 @@ lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, CopyNRelFileInfoToBufTag(tag, rinfo); tag.forkNum = forkNum; tag.blockNum = blkno & ~(BLOCKS_PER_CHUNK - 1); + + CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber); hash = get_hash_value(lfc_hash, &tag); LWLockAcquire(lfc_lock, LW_EXCLUSIVE); @@ -551,7 +602,7 @@ lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, } /* Unlink entry from LRU list to pin it for the duration of IO operation */ if (entry->access_count++ == 0) - dlist_delete(&entry->lru_node); + dlist_delete(&entry->list_node); generation = lfc_ctl->generation; entry_offset = entry->offset; @@ -569,12 +620,12 @@ lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, if (lfc_ctl->generation == generation) { - Assert(LFC_ENABLED()); + CriticalAssert(LFC_ENABLED()); lfc_ctl->hits += 1; pgBufferUsage.file_cache.hits += 1; - Assert(entry->access_count > 0); + CriticalAssert(entry->access_count > 0); if (--entry->access_count == 0) - dlist_push_tail(&lfc_ctl->lru, &entry->lru_node); + dlist_push_tail(&lfc_ctl->lru, &entry->list_node); } else result = false; @@ -613,6 +664,8 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, const void tag.forkNum = forkNum; tag.blockNum = blkno & ~(BLOCKS_PER_CHUNK - 1); CopyNRelFileInfoToBufTag(tag, rinfo); + + CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber); hash = get_hash_value(lfc_hash, &tag); LWLockAcquire(lfc_lock, LW_EXCLUSIVE); @@ -632,7 +685,7 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, const void * operation */ if (entry->access_count++ == 0) - dlist_delete(&entry->lru_node); + dlist_delete(&entry->list_node); } else { @@ -655,13 +708,26 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, const void if (lfc_ctl->used >= lfc_ctl->limit && !dlist_is_empty(&lfc_ctl->lru)) { /* Cache overflow: evict least recently used chunk */ - FileCacheEntry *victim = dlist_container(FileCacheEntry, lru_node, dlist_pop_head_node(&lfc_ctl->lru)); + FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->lru)); - Assert(victim->access_count == 0); + CriticalAssert(victim->access_count == 0); entry->offset = victim->offset; /* grab victim's chunk */ hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL); 
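The `## Holes` comment and the eviction/allocation hunks above describe how the LFC shrinks by punching holes and later reuses their offsets instead of growing the file. As an illustration only, here is the same bookkeeping modelled in Rust; the real implementation is C in `pgxn/neon/file_cache.c`, keeps holes as dummy `FileCacheEntry` records on a dlist, and frees disk space with `fallocate(FALLOC_FL_PUNCH_HOLE)`:

```rust
// Illustrative model of the LFC hole accounting only; no real file I/O.
use std::collections::VecDeque;

struct LfcModel {
    used_chunks: u32,
    file_chunks: u32,     // nominal file size in chunks; never shrinks
    holes: VecDeque<u32>, // offsets of punched-out chunks, reusable later
}

impl LfcModel {
    /// Shrinking the soft limit: give back one chunk's disk space and
    /// remember its offset so it can be reused later.
    fn punch_hole(&mut self, offset: u32) {
        self.holes.push_back(offset);
        self.used_chunks -= 1;
    }

    /// Allocating a chunk for a new entry: prefer a previously punched hole,
    /// otherwise extend the nominal size of the file.
    fn allocate_chunk(&mut self) -> u32 {
        self.used_chunks += 1;
        if let Some(offset) = self.holes.pop_front() {
            offset
        } else {
            let offset = self.file_chunks;
            self.file_chunks += 1;
            offset
        }
    }
}

fn main() {
    let mut lfc = LfcModel { used_chunks: 4, file_chunks: 4, holes: VecDeque::new() };
    lfc.punch_hole(1);                   // soft limit lowered: offset 1 becomes a hole
    assert_eq!(lfc.allocate_chunk(), 1); // hole reused before the file grows
    assert_eq!(lfc.allocate_chunk(), 4); // no holes left: extend the file
    assert_eq!(lfc.file_chunks, 5);
}
```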
neon_log(DEBUG2, "Swap file cache page"); } + else if (!dlist_is_empty(&lfc_ctl->holes)) + { + /* We can reuse a hole that was left behind when the LFC was shrunk previously */ + FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->holes)); + uint32 offset = hole->offset; + bool found; + + hash_search_with_hash_value(lfc_hash, &hole->key, hole->hash, HASH_REMOVE, &found); + CriticalAssert(found); + + lfc_ctl->used += 1; + entry->offset = offset; /* reuse the hole */ + } else { lfc_ctl->used += 1; @@ -689,11 +755,11 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, const void if (lfc_ctl->generation == generation) { - Assert(LFC_ENABLED()); + CriticalAssert(LFC_ENABLED()); /* Place entry to the head of LRU list */ - Assert(entry->access_count > 0); + CriticalAssert(entry->access_count > 0); if (--entry->access_count == 0) - dlist_push_tail(&lfc_ctl->lru, &entry->lru_node); + dlist_push_tail(&lfc_ctl->lru, &entry->list_node); entry->bitmap[chunk_offs >> 5] |= (1 << (chunk_offs & 31)); } @@ -708,7 +774,6 @@ typedef struct } NeonGetStatsCtx; #define NUM_NEON_GET_STATS_COLS 2 -#define NUM_NEON_GET_STATS_ROWS 3 PG_FUNCTION_INFO_V1(neon_get_lfc_stats); Datum @@ -744,7 +809,6 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS) INT8OID, -1, 0); fctx->tupdesc = BlessTupleDesc(tupledesc); - funcctx->max_calls = NUM_NEON_GET_STATS_ROWS; funcctx->user_fctx = fctx; /* Return to original context when allocating transient memory */ @@ -778,6 +842,11 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS) if (lfc_ctl) value = lfc_ctl->writes; break; + case 4: + key = "file_cache_size"; + if (lfc_ctl) + value = lfc_ctl->size; + break; default: SRF_RETURN_DONE(funcctx); } @@ -901,7 +970,7 @@ local_cache_pages(PG_FUNCTION_ARGS) hash_seq_init(&status, lfc_hash); while ((entry = hash_seq_search(&status)) != NULL) { - for (int i = 0; i < BLOCKS_PER_CHUNK / 32; i++) + for (int i = 0; i < CHUNK_BITMAP_SIZE; i++) n_pages += pg_popcount32(entry->bitmap[i]); } } diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c index 73a001b6ba..5126c26c5d 100644 --- a/pgxn/neon/libpagestore.c +++ b/pgxn/neon/libpagestore.c @@ -550,9 +550,6 @@ pageserver_connect(shardno_t shard_no, int elevel) case 2: pagestream_query = psprintf("pagestream_v2 %s %s", neon_tenant, neon_timeline); break; - case 1: - pagestream_query = psprintf("pagestream %s %s", neon_tenant, neon_timeline); - break; default: elog(ERROR, "unexpected neon_protocol_version %d", neon_protocol_version); } @@ -1063,7 +1060,7 @@ pg_init_libpagestore(void) NULL, &neon_protocol_version, 2, /* use protocol version 2 */ - 1, /* min */ + 2, /* min */ 2, /* max */ PGC_SU_BACKEND, 0, /* no flags required */ diff --git a/pgxn/neon/neon_pgversioncompat.h b/pgxn/neon/neon_pgversioncompat.h index f19732cbbb..addb6ccce6 100644 --- a/pgxn/neon/neon_pgversioncompat.h +++ b/pgxn/neon/neon_pgversioncompat.h @@ -54,6 +54,10 @@ #define BufTagGetNRelFileInfo(tag) tag.rnode +#define BufTagGetRelNumber(tagp) ((tagp)->rnode.relNode) + +#define InvalidRelFileNumber InvalidOid + #define SMgrRelGetRelInfo(reln) \ (reln->smgr_rnode.node) diff --git a/pgxn/neon/pagestore_client.h b/pgxn/neon/pagestore_client.h index 8951e6607b..1f196d016c 100644 --- a/pgxn/neon/pagestore_client.h +++ b/pgxn/neon/pagestore_client.h @@ -87,9 +87,8 @@ typedef enum { * can skip traversing through recent layers which we know to not contain any * versions for the requested page. * - * These structs describe the V2 of these requests. 
The old V1 protocol contained - * just one LSN and a boolean 'latest' flag. If the neon_protocol_version GUC is - * set to 1, we will convert these to the V1 requests before sending. + * These structs describe the V2 of these requests. (The old now-defunct V1 + * protocol contained just one LSN and a boolean 'latest' flag.) */ typedef struct { diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index 8edaf65639..7f39c7d026 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -1001,51 +1001,10 @@ nm_pack_request(NeonRequest *msg) initStringInfo(&s); - if (neon_protocol_version >= 2) - { - pq_sendbyte(&s, msg->tag); - pq_sendint64(&s, msg->lsn); - pq_sendint64(&s, msg->not_modified_since); - } - else - { - bool latest; - XLogRecPtr lsn; + pq_sendbyte(&s, msg->tag); + pq_sendint64(&s, msg->lsn); + pq_sendint64(&s, msg->not_modified_since); - /* - * In primary, we always request the latest page version. - */ - if (!RecoveryInProgress()) - { - latest = true; - lsn = msg->not_modified_since; - } - else - { - /* - * In the protocol V1, we cannot represent that we want to read - * page at LSN X, and we know that it hasn't been modified since - * Y. We can either use 'not_modified_lsn' as the request LSN, and - * risk getting an error if that LSN is too old and has already - * fallen out of the pageserver's GC horizon, or we can send - * 'request_lsn', causing the pageserver to possibly wait for the - * recent WAL to arrive unnecessarily. Or something in between. We - * choose to use the old LSN and risk GC errors, because that's - * what we've done historically. - */ - latest = false; - lsn = msg->not_modified_since; - } - - pq_sendbyte(&s, msg->tag); - pq_sendbyte(&s, latest); - pq_sendint64(&s, lsn); - } - - /* - * The rest of the request messages are the same between protocol V1 and - * V2 - */ switch (messageTag(msg)) { /* pagestore_client -> pagestore */ diff --git a/pgxn/neon/relsize_cache.c b/pgxn/neon/relsize_cache.c index cc7ac2c394..2a4c2dc799 100644 --- a/pgxn/neon/relsize_cache.c +++ b/pgxn/neon/relsize_cache.c @@ -110,7 +110,8 @@ get_cached_relsize(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber *size) tag.rinfo = rinfo; tag.forknum = forknum; - LWLockAcquire(relsize_lock, LW_SHARED); + /* We need exclusive lock here because of LRU list manipulation */ + LWLockAcquire(relsize_lock, LW_EXCLUSIVE); entry = hash_search(relsize_hash, &tag, HASH_FIND, NULL); if (entry != NULL) { diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index f3ddc64061..65ef588ba5 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -220,6 +220,64 @@ nwp_register_gucs(void) NULL, NULL, NULL); } + +static int +split_safekeepers_list(char *safekeepers_list, char *safekeepers[]) +{ + int n_safekeepers = 0; + char *curr_sk = safekeepers_list; + + for (char *coma = safekeepers_list; coma != NULL && *coma != '\0'; curr_sk = coma) + { + if (++n_safekeepers >= MAX_SAFEKEEPERS) { + wpg_log(FATAL, "too many safekeepers"); + } + + coma = strchr(coma, ','); + safekeepers[n_safekeepers-1] = curr_sk; + + if (coma != NULL) { + *coma++ = '\0'; + } + } + + return n_safekeepers; +} + +/* + * Accept two coma-separated strings with list of safekeeper host:port addresses. + * Split them into arrays and return false if two sets do not match, ignoring the order. 
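`split_safekeepers_list`/`safekeepers_cmp` above make the `neon.safekeepers` GUC comparison order-insensitive, so a mere reordering of the list no longer restarts the walproposer. A compact Rust rendering of the same sort-and-compare check (function name hypothetical):

```rust
/// Order-insensitive comparison of two comma-separated safekeeper lists,
/// mirroring the split/qsort/strcmp logic in the C hunk above.
fn safekeepers_eq(old: &str, new: &str) -> bool {
    let mut old_list: Vec<&str> = old.split(',').filter(|s| !s.is_empty()).collect();
    let mut new_list: Vec<&str> = new.split(',').filter(|s| !s.is_empty()).collect();
    old_list.sort_unstable();
    new_list.sort_unstable();
    old_list == new_list
}

fn main() {
    // Reordering alone should not force a walproposer restart.
    assert!(safekeepers_eq("sk1:5454,sk2:5454", "sk2:5454,sk1:5454"));
    // A genuinely different set should.
    assert!(!safekeepers_eq("sk1:5454,sk2:5454", "sk1:5454,sk3:5454"));
}
```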
+ */ +static bool +safekeepers_cmp(char *old, char *new) +{ + char *safekeepers_old[MAX_SAFEKEEPERS]; + char *safekeepers_new[MAX_SAFEKEEPERS]; + int len_old = 0; + int len_new = 0; + + len_old = split_safekeepers_list(old, safekeepers_old); + len_new = split_safekeepers_list(new, safekeepers_new); + + if (len_old != len_new) + { + return false; + } + + qsort(&safekeepers_old, len_old, sizeof(char *), pg_qsort_strcmp); + qsort(&safekeepers_new, len_new, sizeof(char *), pg_qsort_strcmp); + + for (int i = 0; i < len_new; i++) + { + if (strcmp(safekeepers_old[i], safekeepers_new[i]) != 0) + { + return false; + } + } + + return true; +} + /* * GUC assign_hook for neon.safekeepers. Restarts walproposer through FATAL if * the list changed. @@ -235,19 +293,26 @@ assign_neon_safekeepers(const char *newval, void *extra) wpg_log(FATAL, "neon.safekeepers is empty"); } + /* Copy values because we will modify them in split_safekeepers_list() */ + char *newval_copy = pstrdup(newval); + char *oldval = pstrdup(wal_acceptors_list); + /* * TODO: restarting through FATAL is stupid and introduces 1s delay before * next bgw start. We should refactor walproposer to allow graceful exit and * thus remove this delay. + * XXX: If you change anything here, sync with test_safekeepers_reconfigure_reorder. */ - if (strcmp(wal_acceptors_list, newval) != 0) + if (!safekeepers_cmp(oldval, newval_copy)) { wpg_log(FATAL, "restarting walproposer to change safekeeper list from %s to %s", wal_acceptors_list, newval); } + pfree(newval_copy); + pfree(oldval); } -/* Check if we need to suspend inserts because of lagging replication. */ +/* Check if we need to suspend inserts because of lagging replication. */ static uint64 backpressure_lag_impl(void) { diff --git a/pre-commit.py b/pre-commit.py index c5ed63ac44..ae432e8225 100755 --- a/pre-commit.py +++ b/pre-commit.py @@ -2,6 +2,7 @@ import argparse import enum +import os import subprocess import sys from typing import List @@ -93,7 +94,7 @@ if __name__ == "__main__": "--no-color", action="store_true", help="disable colored output", - default=not sys.stdout.isatty(), + default=not sys.stdout.isatty() or os.getenv("TERM") == "dumb", ) args = parser.parse_args() diff --git a/proxy/README.md b/proxy/README.md index d1f2e3f27b..8d850737be 100644 --- a/proxy/README.md +++ b/proxy/README.md @@ -6,7 +6,7 @@ Proxy binary accepts `--auth-backend` CLI option, which determines auth scheme a new SCRAM-based console API; uses SNI info to select the destination project (endpoint soon) * postgres uses postgres to select auth secrets of existing roles. Useful for local testing -* link +* web (or link) sends login link for all usernames Also proxy can expose following services to the external world: @@ -36,7 +36,7 @@ To play with it locally one may start proxy over a local postgres installation ``` If both postgres and proxy are running you may send a SQL query: -```json +```console curl -k -X POST 'https://proxy.localtest.me:4444/sql' \ -H 'Neon-Connection-String: postgres://stas:pass@proxy.localtest.me:4444/postgres' \ -H 'Content-Type: application/json' \ @@ -44,7 +44,8 @@ curl -k -X POST 'https://proxy.localtest.me:4444/sql' \ "query":"SELECT $1::int[] as arr, $2::jsonb as obj, 42 as num", "params":[ "{{1,2},{\"3\",4}}", {"key":"val", "ikey":4242}] }' | jq - +``` +```json { "command": "SELECT", "fields": [ diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 8c44823c98..7c408f817c 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,20 +1,20 @@ //! 
Client authentication mechanisms. pub mod backend; -pub use backend::BackendType; +pub use backend::Backend; mod credentials; -pub use credentials::{ +pub(crate) use credentials::{ check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, ComputeUserInfoParseError, IpPattern, }; mod password_hack; -pub use password_hack::parse_endpoint_param; +pub(crate) use password_hack::parse_endpoint_param; use password_hack::PasswordHackPayload; mod flow; -pub use flow::*; +pub(crate) use flow::*; use tokio::time::error::Elapsed; use crate::{ @@ -25,13 +25,13 @@ use std::{io, net::IpAddr}; use thiserror::Error; /// Convenience wrapper for the authentication error. -pub type Result = std::result::Result; +pub(crate) type Result = std::result::Result; /// Common authentication error. #[derive(Debug, Error)] -pub enum AuthErrorImpl { +pub(crate) enum AuthErrorImpl { #[error(transparent)] - Link(#[from] backend::LinkAuthError), + Web(#[from] backend::WebAuthError), #[error(transparent)] GetAuthInfo(#[from] console::errors::GetAuthInfoError), @@ -77,30 +77,30 @@ pub enum AuthErrorImpl { #[derive(Debug, Error)] #[error(transparent)] -pub struct AuthError(Box); +pub(crate) struct AuthError(Box); impl AuthError { - pub fn bad_auth_method(name: impl Into>) -> Self { + pub(crate) fn bad_auth_method(name: impl Into>) -> Self { AuthErrorImpl::BadAuthMethod(name.into()).into() } - pub fn auth_failed(user: impl Into>) -> Self { + pub(crate) fn auth_failed(user: impl Into>) -> Self { AuthErrorImpl::AuthFailed(user.into()).into() } - pub fn ip_address_not_allowed(ip: IpAddr) -> Self { + pub(crate) fn ip_address_not_allowed(ip: IpAddr) -> Self { AuthErrorImpl::IpAddressNotAllowed(ip).into() } - pub fn too_many_connections() -> Self { + pub(crate) fn too_many_connections() -> Self { AuthErrorImpl::TooManyConnections.into() } - pub fn is_auth_failed(&self) -> bool { + pub(crate) fn is_auth_failed(&self) -> bool { matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_)) } - pub fn user_timeout(elapsed: Elapsed) -> Self { + pub(crate) fn user_timeout(elapsed: Elapsed) -> Self { AuthErrorImpl::UserTimeout(elapsed).into() } } @@ -113,38 +113,36 @@ impl> From for AuthError { impl UserFacingError for AuthError { fn to_string_client(&self) -> String { - use AuthErrorImpl::*; match self.0.as_ref() { - Link(e) => e.to_string_client(), - GetAuthInfo(e) => e.to_string_client(), - Sasl(e) => e.to_string_client(), - AuthFailed(_) => self.to_string(), - BadAuthMethod(_) => self.to_string(), - MalformedPassword(_) => self.to_string(), - MissingEndpointName => self.to_string(), - Io(_) => "Internal error".to_string(), - IpAddressNotAllowed(_) => self.to_string(), - TooManyConnections => self.to_string(), - UserTimeout(_) => self.to_string(), + AuthErrorImpl::Web(e) => e.to_string_client(), + AuthErrorImpl::GetAuthInfo(e) => e.to_string_client(), + AuthErrorImpl::Sasl(e) => e.to_string_client(), + AuthErrorImpl::AuthFailed(_) => self.to_string(), + AuthErrorImpl::BadAuthMethod(_) => self.to_string(), + AuthErrorImpl::MalformedPassword(_) => self.to_string(), + AuthErrorImpl::MissingEndpointName => self.to_string(), + AuthErrorImpl::Io(_) => "Internal error".to_string(), + AuthErrorImpl::IpAddressNotAllowed(_) => self.to_string(), + AuthErrorImpl::TooManyConnections => self.to_string(), + AuthErrorImpl::UserTimeout(_) => self.to_string(), } } } impl ReportableError for AuthError { fn get_error_kind(&self) -> crate::error::ErrorKind { - use AuthErrorImpl::*; match self.0.as_ref() { - Link(e) => e.get_error_kind(), - 
GetAuthInfo(e) => e.get_error_kind(), - Sasl(e) => e.get_error_kind(), - AuthFailed(_) => crate::error::ErrorKind::User, - BadAuthMethod(_) => crate::error::ErrorKind::User, - MalformedPassword(_) => crate::error::ErrorKind::User, - MissingEndpointName => crate::error::ErrorKind::User, - Io(_) => crate::error::ErrorKind::ClientDisconnect, - IpAddressNotAllowed(_) => crate::error::ErrorKind::User, - TooManyConnections => crate::error::ErrorKind::RateLimit, - UserTimeout(_) => crate::error::ErrorKind::User, + AuthErrorImpl::Web(e) => e.get_error_kind(), + AuthErrorImpl::GetAuthInfo(e) => e.get_error_kind(), + AuthErrorImpl::Sasl(e) => e.get_error_kind(), + AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User, + AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User, + AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User, + AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User, + AuthErrorImpl::Io(_) => crate::error::ErrorKind::ClientDisconnect, + AuthErrorImpl::IpAddressNotAllowed(_) => crate::error::ErrorKind::User, + AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit, + AuthErrorImpl::UserTimeout(_) => crate::error::ErrorKind::User, } } } diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index c6a0b2af5a..1d28c6df31 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -1,17 +1,19 @@ mod classic; mod hacks; pub mod jwt; -mod link; +pub mod local; +mod web; use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; use ipnet::{Ipv4Net, Ipv6Net}; -pub use link::LinkAuthError; +use local::LocalBackend; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_postgres::config::AuthKeys; use tracing::{info, warn}; +pub(crate) use web::WebAuthError; use crate::auth::credentials::check_peer_addr_is_in_list; use crate::auth::{validate_password_and_exchange, AuthError}; @@ -63,26 +65,27 @@ impl std::ops::Deref for MaybeOwned<'_, T> { /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. -pub enum BackendType<'a, T, D> { +pub enum Backend<'a, T, D> { /// Cloud API (V2). Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. 
- Link(MaybeOwned<'a, url::ApiUrl>, D), + Web(MaybeOwned<'a, url::ApiUrl>, D), + /// Local proxy uses configured auth credentials and does not wake compute + Local(MaybeOwned<'a, LocalBackend>), } -pub trait TestBackend: Send + Sync + 'static { +#[cfg(test)] +pub(crate) trait TestBackend: Send + Sync + 'static { fn wake_compute(&self) -> Result; fn get_allowed_ips_and_secret( &self, ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError>; - fn get_role_secret(&self) -> Result; } -impl std::fmt::Display for BackendType<'_, (), ()> { +impl std::fmt::Display for Backend<'_, (), ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use BackendType::*; match self { - Console(api, _) => match &**api { + Self::Console(api, ()) => match &**api { ConsoleBackend::Console(endpoint) => { fmt.debug_tuple("Console").field(&endpoint.url()).finish() } @@ -93,74 +96,76 @@ impl std::fmt::Display for BackendType<'_, (), ()> { #[cfg(test)] ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, - Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), + Self::Web(url, ()) => fmt.debug_tuple("Web").field(&url.as_str()).finish(), + Self::Local(_) => fmt.debug_tuple("Local").finish(), } } } -impl BackendType<'_, T, D> { +impl Backend<'_, T, D> { /// Very similar to [`std::option::Option::as_ref`]. /// This helps us pass structured config to async tasks. - pub fn as_ref(&self) -> BackendType<'_, &T, &D> { - use BackendType::*; + pub(crate) fn as_ref(&self) -> Backend<'_, &T, &D> { match self { - Console(c, x) => Console(MaybeOwned::Borrowed(c), x), - Link(c, x) => Link(MaybeOwned::Borrowed(c), x), + Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x), + Self::Web(c, x) => Backend::Web(MaybeOwned::Borrowed(c), x), + Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)), } } } -impl<'a, T, D> BackendType<'a, T, D> { +impl<'a, T, D> Backend<'a, T, D> { /// Very similar to [`std::option::Option::map`]. - /// Maps [`BackendType`] to [`BackendType`] by applying + /// Maps [`Backend`] to [`Backend`] by applying /// a function to a contained value. - pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> { - use BackendType::*; + pub(crate) fn map(self, f: impl FnOnce(T) -> R) -> Backend<'a, R, D> { match self { - Console(c, x) => Console(c, f(x)), - Link(c, x) => Link(c, x), + Self::Console(c, x) => Backend::Console(c, f(x)), + Self::Web(c, x) => Backend::Web(c, x), + Self::Local(l) => Backend::Local(l), } } } -impl<'a, T, D, E> BackendType<'a, Result, D> { +impl<'a, T, D, E> Backend<'a, Result, D> { /// Very similar to [`std::option::Option::transpose`]. /// This is most useful for error handling. 
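`Backend::map` and `Backend::transpose` mirror `Option`'s combinators: `map` rewrites the per-variant payload, and `transpose` hoists a `Result` payload out so `?` works at the call site. A stripped-down sketch of the same two combinators on a hypothetical two-variant enum:

```rust
// Minimal analogue of the Backend<'_, T, D> map/transpose helpers; the enum
// and its variant names are hypothetical.
enum Carrier<T> {
    WithPayload(T),
    Empty,
}

impl<T> Carrier<T> {
    fn map<R>(self, f: impl FnOnce(T) -> R) -> Carrier<R> {
        match self {
            Carrier::WithPayload(x) => Carrier::WithPayload(f(x)),
            Carrier::Empty => Carrier::Empty,
        }
    }
}

impl<T, E> Carrier<Result<T, E>> {
    fn transpose(self) -> Result<Carrier<T>, E> {
        match self {
            Carrier::WithPayload(x) => x.map(Carrier::WithPayload),
            Carrier::Empty => Ok(Carrier::Empty),
        }
    }
}

fn main() {
    let parsed: Carrier<Result<u32, String>> =
        Carrier::WithPayload("42".parse::<u32>().map_err(|e| e.to_string()));
    // transpose hoists the Result out; map then transforms the payload.
    let doubled = parsed.transpose().map(|c| c.map(|n| n * 2));
    assert!(matches!(doubled, Ok(Carrier::WithPayload(84))));
}
```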
- pub fn transpose(self) -> Result, E> { - use BackendType::*; + pub(crate) fn transpose(self) -> Result, E> { match self { - Console(c, x) => x.map(|x| Console(c, x)), - Link(c, x) => Ok(Link(c, x)), + Self::Console(c, x) => x.map(|x| Backend::Console(c, x)), + Self::Web(c, x) => Ok(Backend::Web(c, x)), + Self::Local(l) => Ok(Backend::Local(l)), } } } -pub struct ComputeCredentials { - pub info: ComputeUserInfo, - pub keys: ComputeCredentialKeys, +pub(crate) struct ComputeCredentials { + pub(crate) info: ComputeUserInfo, + pub(crate) keys: ComputeCredentialKeys, } #[derive(Debug, Clone)] -pub struct ComputeUserInfoNoEndpoint { - pub user: RoleName, - pub options: NeonOptions, +pub(crate) struct ComputeUserInfoNoEndpoint { + pub(crate) user: RoleName, + pub(crate) options: NeonOptions, } #[derive(Debug, Clone)] -pub struct ComputeUserInfo { - pub endpoint: EndpointId, - pub user: RoleName, - pub options: NeonOptions, +pub(crate) struct ComputeUserInfo { + pub(crate) endpoint: EndpointId, + pub(crate) user: RoleName, + pub(crate) options: NeonOptions, } impl ComputeUserInfo { - pub fn endpoint_cache_key(&self) -> EndpointCacheKey { + pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey { self.options.get_cache_key(&self.endpoint) } } -pub enum ComputeCredentialKeys { +pub(crate) enum ComputeCredentialKeys { Password(Vec), AuthKeys(AuthKeys), + None, } impl TryFrom for ComputeUserInfo { @@ -217,7 +222,7 @@ impl RateBucketInfo { } impl AuthenticationConfig { - pub fn check_rate_limit( + pub(crate) fn check_rate_limit( &self, ctx: &RequestMonitoring, config: &AuthenticationConfig, @@ -293,7 +298,9 @@ async fn auth_quirks( ctx.set_endpoint_id(res.info.endpoint.clone()); let password = match res.keys { ComputeCredentialKeys::Password(p) => p, - _ => unreachable!("password hack should return a password"), + ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => { + unreachable!("password hack should return a password") + } }; (res.info, Some(password)) } @@ -317,21 +324,20 @@ async fn auth_quirks( }; let (cached_entry, secret) = cached_secret.take_value(); - let secret = match secret { - Some(secret) => config.check_rate_limit( + let secret = if let Some(secret) = secret { + config.check_rate_limit( ctx, config, secret, &info.endpoint, unauthenticated_password.is_some() || allow_cleartext, - )?, - None => { - // If we don't have an authentication secret, we mock one to - // prevent malicious probing (possible due to missing protocol steps). - // This mocked secret will never lead to successful authentication. - info!("authentication info not found, mocking it"); - AuthSecret::Scram(scram::ServerSecret::mock(rand::random())) - } + )? + } else { + // If we don't have an authentication secret, we mock one to + // prevent malicious probing (possible due to missing protocol steps). + // This mocked secret will never lead to successful authentication. + info!("authentication info not found, mocking it"); + AuthSecret::Scram(scram::ServerSecret::mock(rand::random())) }; match authenticate_with_secret( @@ -397,41 +403,28 @@ async fn authenticate_with_secret( classic::authenticate(ctx, info, client, config, secret).await } -impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { - /// Get compute endpoint name from the credentials. 
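When no authentication secret is stored for the role, `auth_quirks` above substitutes a mocked SCRAM secret instead of failing early, so the exchange always runs and always fails, and a missing role is indistinguishable from a wrong password. A minimal sketch of that branch, with simplified stand-ins for `AuthSecret`, the rate-limit path and `rand::random()`:

```rust
// Stand-in types only; the real code uses AuthSecret::Scram and
// scram::ServerSecret::mock(rand::random()).
#[derive(Debug)]
enum AuthSecret {
    Scram(String),
}

fn mock_secret(seed: u64) -> AuthSecret {
    // Derived from random data; can never verify successfully.
    AuthSecret::Scram(format!("mock-{seed:016x}"))
}

fn choose_secret(stored: Option<AuthSecret>) -> AuthSecret {
    match stored {
        // Real secret: would normally pass through the rate-limit check first.
        Some(secret) => secret,
        // No authentication info: mock it to prevent malicious probing.
        None => mock_secret(rand_seed()),
    }
}

fn rand_seed() -> u64 {
    // Hypothetical stand-in for rand::random(), to keep the sketch std-only.
    std::time::SystemTime::now()
        .duration_since(std::time::SystemTime::UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0)
}

fn main() {
    let secret = choose_secret(None);
    assert!(matches!(secret, AuthSecret::Scram(_)));
}
```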
- pub fn get_endpoint(&self) -> Option { - use BackendType::*; - - match self { - Console(_, user_info) => user_info.endpoint_id.clone(), - Link(_, _) => Some("link".into()), - } - } - +impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> { /// Get username from the credentials. - pub fn get_user(&self) -> &str { - use BackendType::*; - + pub(crate) fn get_user(&self) -> &str { match self { - Console(_, user_info) => &user_info.user, - Link(_, _) => "link", + Self::Console(_, user_info) => &user_info.user, + Self::Web(_, ()) => "web", + Self::Local(_) => "local", } } /// Authenticate the client via the requested backend, possibly using credentials. #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)] - pub async fn authenticate( + pub(crate) async fn authenticate( self, ctx: &RequestMonitoring, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, - ) -> auth::Result> { - use BackendType::*; - + ) -> auth::Result> { let res = match self { - Console(api, user_info) => { + Self::Console(api, user_info) => { info!( user = &*user_info.user, project = user_info.endpoint(), @@ -448,15 +441,18 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { endpoint_rate_limiter, ) .await?; - BackendType::Console(api, credentials) + Backend::Console(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Link(url, _) => { - info!("performing link authentication"); + Self::Web(url, ()) => { + info!("performing web authentication"); - let info = link::authenticate(ctx, &url, client).await?; + let info = web::authenticate(ctx, &url, client).await?; - BackendType::Link(url, info) + Backend::Web(url, info) + } + Self::Local(_) => { + return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) } }; @@ -465,70 +461,72 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { } } -impl BackendType<'_, ComputeUserInfo, &()> { - pub async fn get_role_secret( +impl Backend<'_, ComputeUserInfo, &()> { + pub(crate) async fn get_role_secret( &self, ctx: &RequestMonitoring, ) -> Result { - use BackendType::*; match self { - Console(api, user_info) => api.get_role_secret(ctx, user_info).await, - Link(_, _) => Ok(Cached::new_uncached(None)), + Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await, + Self::Web(_, ()) => Ok(Cached::new_uncached(None)), + Self::Local(_) => Ok(Cached::new_uncached(None)), } } - pub async fn get_allowed_ips_and_secret( + pub(crate) async fn get_allowed_ips_and_secret( &self, ctx: &RequestMonitoring, ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { - use BackendType::*; match self { - Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, - Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, + Self::Web(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), } } } #[async_trait::async_trait] -impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { +impl ComputeConnectBackend for Backend<'_, ComputeCredentials, NodeInfo> { async fn wake_compute( &self, ctx: &RequestMonitoring, ) -> Result { - use BackendType::*; - match self { - Console(api, creds) => api.wake_compute(ctx, &creds.info).await, - Link(_, info) => Ok(Cached::new_uncached(info.clone())), + Self::Console(api, creds) => 
api.wake_compute(ctx, &creds.info).await, + Self::Web(_, info) => Ok(Cached::new_uncached(info.clone())), + Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } - fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + fn get_keys(&self) -> &ComputeCredentialKeys { match self { - BackendType::Console(_, creds) => Some(&creds.keys), - BackendType::Link(_, _) => None, + Self::Console(_, creds) => &creds.keys, + Self::Web(_, _) => &ComputeCredentialKeys::None, + Self::Local(_) => &ComputeCredentialKeys::None, } } } #[async_trait::async_trait] -impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { +impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> { async fn wake_compute( &self, ctx: &RequestMonitoring, ) -> Result { - use BackendType::*; - match self { - Console(api, creds) => api.wake_compute(ctx, &creds.info).await, - Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"), + Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Self::Web(_, ()) => { + unreachable!("web auth flow doesn't support waking the compute") + } + Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } - fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + fn get_keys(&self) -> &ComputeCredentialKeys { match self { - BackendType::Console(_, creds) => Some(&creds.keys), - BackendType::Link(_, _) => None, + Self::Console(_, creds) => &creds.keys, + Self::Web(_, ()) => &ComputeCredentialKeys::None, + Self::Local(_) => &ComputeCredentialKeys::None, } } } diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 56921dd949..e9019ce2cf 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -17,7 +17,7 @@ use tracing::{info, warn}; /// one round trip and *expensive* computations (>= 4096 HMAC iterations). /// These properties are benefical for serverless JS workers, so we /// use this mechanism for websocket connections. -pub async fn authenticate_cleartext( +pub(crate) async fn authenticate_cleartext( ctx: &RequestMonitoring, info: ComputeUserInfo, client: &mut stream::PqStream>, @@ -59,7 +59,7 @@ pub async fn authenticate_cleartext( /// Workaround for clients which don't provide an endpoint (project) name. /// Similar to [`authenticate_cleartext`], but there's a specific password format, /// and passwords are not yet validated (we don't know how to validate them!) -pub async fn password_hack_no_authentication( +pub(crate) async fn password_hack_no_authentication( ctx: &RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 0c2ca8fb97..e98da82053 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -1,49 +1,81 @@ -use std::{future::Future, sync::Arc, time::Duration}; +use std::{ + future::Future, + sync::Arc, + time::{Duration, SystemTime}, +}; use anyhow::{bail, ensure, Context}; use arc_swap::ArcSwapOption; use dashmap::DashMap; use jose_jwk::crypto::KeyInfo; +use serde::{Deserialize, Deserializer}; use signature::Verifier; use tokio::time::Instant; -use crate::{http::parse_json_body_with_limit, intern::EndpointIdInt}; +use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName}; // TODO(conrad): make these configurable. 
+const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30); const MIN_RENEW: Duration = Duration::from_secs(30); const AUTO_RENEW: Duration = Duration::from_secs(300); const MAX_RENEW: Duration = Duration::from_secs(3600); const MAX_JWK_BODY_SIZE: usize = 64 * 1024; /// How to get the JWT auth rules -pub trait FetchAuthRules: Clone + Send + Sync + 'static { - fn fetch_auth_rules(&self) -> impl Future> + Send; +pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static { + fn fetch_auth_rules( + &self, + role_name: RoleName, + ) -> impl Future>> + Send; } -#[derive(Clone)] -struct FetchAuthRulesFromCplane { - #[allow(dead_code)] - endpoint: EndpointIdInt, -} - -impl FetchAuthRules for FetchAuthRulesFromCplane { - async fn fetch_auth_rules(&self) -> anyhow::Result { - Err(anyhow::anyhow!("not yet implemented")) - } -} - -pub struct AuthRules { - jwks_urls: Vec, +pub(crate) struct AuthRule { + pub(crate) id: String, + pub(crate) jwks_url: url::Url, + pub(crate) audience: Option, } #[derive(Default)] -pub struct JwkCache { +pub(crate) struct JwkCache { client: reqwest::Client, - map: DashMap>, + map: DashMap<(EndpointId, RoleName), Arc>, } -pub struct JwkCacheEntryLock { +pub(crate) struct JwkCacheEntry { + /// Should refetch at least every hour to verify when old keys have been removed. + /// Should refetch when new key IDs are seen only every 5 minutes or so + last_retrieved: Instant, + + /// cplane will return multiple JWKs urls that we need to scrape. + key_sets: ahash::HashMap, +} + +impl JwkCacheEntry { + fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> { + self.key_sets.values().find_map(|key_set| { + key_set + .find_key(key_id) + .map(|jwk| (jwk, key_set.audience.as_deref())) + }) + } +} + +struct KeySet { + jwks: jose_jwk::JwkSet, + audience: Option, +} + +impl KeySet { + fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> { + self.jwks + .keys + .iter() + .find(|jwk| jwk.prm.kid.as_deref() == Some(key_id)) + } +} + +pub(crate) struct JwkCacheEntryLock { cached: ArcSwapOption, lookup: tokio::sync::Semaphore, } @@ -57,15 +89,6 @@ impl Default for JwkCacheEntryLock { } } -pub struct JwkCacheEntry { - /// Should refetch at least every hour to verify when old keys have been removed. - /// Should refetch when new key IDs are seen only every 5 minutes or so - last_retrieved: Instant, - - /// cplane will return multiple JWKs urls that we need to scrape. - key_sets: ahash::HashMap, -} - impl JwkCacheEntryLock { async fn acquire_permit<'a>(self: &'a Arc) -> JwkRenewalPermit<'a> { JwkRenewalPermit::acquire_permit(self).await @@ -79,6 +102,7 @@ impl JwkCacheEntryLock { &self, _permit: JwkRenewalPermit<'_>, client: &reqwest::Client, + role_name: RoleName, auth_rules: &F, ) -> anyhow::Result> { // double check that no one beat us to updating the cache. 
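`JwkCacheEntryLock` serializes JWKS refetches behind a one-permit semaphore: the winner re-checks the cache after acquiring the permit (the "double check" above) and everyone else reuses whatever is already cached. A sketch of that single-flight shape, assuming the `tokio` crate and using a plain `RwLock<Option<Arc<_>>>` in place of the real `ArcSwapOption`:

```rust
// Single-flight refresh sketch: at most one concurrent refetch, and the
// winner re-checks the cache before doing any work. Assumes tokio with the
// "rt" and "macros" features; the cached value stands in for the JWK set.
use std::sync::{Arc, RwLock};
use tokio::sync::Semaphore;

struct CacheEntry<T> {
    cached: RwLock<Option<Arc<T>>>,
    lookup: Semaphore,
}

impl<T> CacheEntry<T> {
    fn new() -> Self {
        Self { cached: RwLock::new(None), lookup: Semaphore::new(1) }
    }

    async fn get_or_renew(&self, fetch: impl std::future::Future<Output = T>) -> Arc<T> {
        if let Some(v) = self.cached.read().unwrap().clone() {
            return v; // fast path: already populated
        }
        // Slow path: only one task at a time gets to refetch.
        let _permit = self.lookup.acquire().await.expect("semaphore closed");
        // Double check that no one beat us to updating the cache.
        if let Some(v) = self.cached.read().unwrap().clone() {
            return v;
        }
        let fresh = Arc::new(fetch.await);
        *self.cached.write().unwrap() = Some(fresh.clone());
        fresh
    }
}

#[tokio::main]
async fn main() {
    let entry = CacheEntry::new();
    assert_eq!(*entry.get_or_renew(async { 42u32 }).await, 42);
    // Second call hits the fast path; the new fetch future is never awaited.
    assert_eq!(*entry.get_or_renew(async { 7u32 }).await, 42);
}
```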
@@ -91,20 +115,19 @@ impl JwkCacheEntryLock { } } - let rules = auth_rules.fetch_auth_rules().await?; - let mut key_sets = ahash::HashMap::with_capacity_and_hasher( - rules.jwks_urls.len(), - ahash::RandomState::new(), - ); + let rules = auth_rules.fetch_auth_rules(role_name).await?; + let mut key_sets = + ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new()); // TODO(conrad): run concurrently // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284) - for url in rules.jwks_urls { - let req = client.get(url.clone()); + for rule in rules { + let req = client.get(rule.jwks_url.clone()); // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`. + // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only. match req.send().await.and_then(|r| r.error_for_status()) { // todo: should we re-insert JWKs if we want to keep this JWKs URL? // I expect these failures would be quite sparse. - Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"), + Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"), Ok(r) => { let resp: http::Response = r.into(); match parse_json_body_with_limit::( @@ -113,9 +136,17 @@ impl JwkCacheEntryLock { ) .await { - Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"), + Err(e) => { + tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); + } Ok(jwks) => { - key_sets.insert(url, jwks); + key_sets.insert( + rule.id, + KeySet { + jwks, + audience: rule.audience, + }, + ); } } } @@ -133,7 +164,9 @@ impl JwkCacheEntryLock { async fn get_or_update_jwk_cache( self: &Arc, + ctx: &RequestMonitoring, client: &reqwest::Client, + role_name: RoleName, fetch: &F, ) -> Result, anyhow::Error> { let now = Instant::now(); @@ -141,18 +174,20 @@ impl JwkCacheEntryLock { // if we have no cached JWKs, try and get some let Some(cached) = guard else { + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let permit = self.acquire_permit().await; - return self.renew_jwks(permit, client, fetch).await; + return self.renew_jwks(permit, client, role_name, fetch).await; }; let last_update = now.duration_since(cached.last_retrieved); // check if the cached JWKs need updating. if last_update > MAX_RENEW { + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let permit = self.acquire_permit().await; // it's been too long since we checked the keys. wait for them to update. - return self.renew_jwks(permit, client, fetch).await; + return self.renew_jwks(permit, client, role_name, fetch).await; } // every 5 minutes we should spawn a job to eagerly update the token. @@ -164,7 +199,7 @@ impl JwkCacheEntryLock { let client = client.clone(); let fetch = fetch.clone(); tokio::spawn(async move { - if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await { + if let Err(e) = entry.renew_jwks(permit, &client, role_name, &fetch).await { tracing::warn!(error=?e, "could not fetch JWKs in background job"); } }); @@ -178,8 +213,10 @@ impl JwkCacheEntryLock { async fn check_jwt( self: &Arc, - jwt: String, + ctx: &RequestMonitoring, + jwt: &str, client: &reqwest::Client, + role_name: RoleName, fetch: &F, ) -> Result<(), anyhow::Error> { // JWT compact form is defined to be @@ -187,38 +224,38 @@ impl JwkCacheEntryLock { // where Signature = alg( || . 
|| ); let (header_payload, signature) = jwt - .rsplit_once(".") + .rsplit_once('.') .context("Provided authentication token is not a valid JWT encoding")?; - let (header, _payload) = header_payload - .split_once(".") + let (header, payload) = header_payload + .split_once('.') .context("Provided authentication token is not a valid JWT encoding")?; let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD) .context("Provided authentication token is not a valid JWT encoding")?; - let header = serde_json::from_slice::(&header) + let header = serde_json::from_slice::>(&header) .context("Provided authentication token is not a valid JWT encoding")?; let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD) .context("Provided authentication token is not a valid JWT encoding")?; ensure!(header.typ == "JWT"); - let kid = header.kid.context("missing key id")?; + let kid = header.key_id.context("missing key id")?; - let mut guard = self.get_or_update_jwk_cache(client, fetch).await?; + let mut guard = self + .get_or_update_jwk_cache(ctx, client, role_name.clone(), fetch) + .await?; // get the key from the JWKs if possible. If not, wait for the keys to update. - let jwk = loop { - let jwk = guard - .key_sets - .values() - .flat_map(|jwks| &jwks.keys) - .find(|jwk| jwk.prm.kid.as_deref() == Some(kid)); - - match jwk { + let (jwk, expected_audience) = loop { + match guard.find_jwk_and_audience(kid) { Some(jwk) => break jwk, None if guard.last_retrieved.elapsed() > MIN_RENEW => { + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let permit = self.acquire_permit().await; - guard = self.renew_jwks(permit, client, fetch).await?; + guard = self + .renew_jwks(permit, client, role_name.clone(), fetch) + .await?; } _ => { bail!("jwk not found"); @@ -227,7 +264,7 @@ impl JwkCacheEntryLock { }; ensure!( - jwk.is_supported(&header.alg), + jwk.is_supported(&header.algorithm), "signature algorithm not supported" ); @@ -241,31 +278,57 @@ impl JwkCacheEntryLock { key => bail!("unsupported key type {key:?}"), }; - // TODO(conrad): verify iss, exp, nbf, etc... + let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD) + .context("Provided authentication token is not a valid JWT encoding")?; + let payload = serde_json::from_slice::>(&payload) + .context("Provided authentication token is not a valid JWT encoding")?; + + tracing::debug!(?payload, "JWT signature valid with claims"); + + match (expected_audience, payload.audience) { + // check the audience matches + (Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"), + // the audience is expected but is missing + (Some(_), None) => bail!("invalid JWT token audience"), + // we don't care for the audience field + (None, _) => {} + } + + let now = SystemTime::now(); + + if let Some(exp) = payload.expiration { + ensure!(now < exp + CLOCK_SKEW_LEEWAY); + } + + if let Some(nbf) = payload.not_before { + ensure!(nbf < now + CLOCK_SKEW_LEEWAY); + } Ok(()) } } impl JwkCache { - pub async fn check_jwt( + pub(crate) async fn check_jwt( &self, - endpoint: EndpointIdInt, - jwt: String, + ctx: &RequestMonitoring, + endpoint: EndpointId, + role_name: RoleName, + fetch: &F, + jwt: &str, ) -> Result<(), anyhow::Error> { // try with just a read lock first - let entry = self.map.get(&endpoint).as_deref().map(Arc::clone); - let entry = match entry { - Some(entry) => entry, - None => { - // acquire a write lock after to insert. 
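`check_jwt` validates the standard time claims with a 30-second leeway for clock skew: a token is rejected only once `now` reaches `exp + CLOCK_SKEW_LEEWAY`, and is accepted slightly before `nbf`. The same checks in isolation, with placeholder error strings:

```rust
// Sketch of the exp/nbf checks with clock-skew leeway (std types only).
use std::time::{Duration, SystemTime};

const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);

fn check_token_times(
    now: SystemTime,
    expiration: Option<SystemTime>, // `exp` claim
    not_before: Option<SystemTime>, // `nbf` claim
) -> Result<(), &'static str> {
    if let Some(exp) = expiration {
        // Accept a token slightly past its expiry to tolerate skew.
        if now >= exp + CLOCK_SKEW_LEEWAY {
            return Err("token expired");
        }
    }
    if let Some(nbf) = not_before {
        // Likewise, accept a token slightly before its nbf time.
        if nbf >= now + CLOCK_SKEW_LEEWAY {
            return Err("token not yet valid");
        }
    }
    Ok(())
}

fn main() {
    let now = SystemTime::now();
    let exp = now + Duration::from_secs(10);
    assert!(check_token_times(now, Some(exp), None).is_ok());
    assert!(check_token_times(now + Duration::from_secs(120), Some(exp), None).is_err());
}
```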
- let entry = self.map.entry(endpoint).or_default(); - Arc::clone(&*entry) - } - }; + let key = (endpoint, role_name.clone()); + let entry = self.map.get(&key).as_deref().map(Arc::clone); + let entry = entry.unwrap_or_else(|| { + // acquire a write lock after to insert. + let entry = self.map.entry(key).or_default(); + Arc::clone(&*entry) + }); - let fetch = FetchAuthRulesFromCplane { endpoint }; - entry.check_jwt(jwt, &self.client, &fetch).await + entry + .check_jwt(ctx, jwt, &self.client, role_name, fetch) + .await } } @@ -315,13 +378,49 @@ fn verify_rsa_signature( /// #[derive(serde::Deserialize, serde::Serialize)] -struct JWTHeader<'a> { +struct JwtHeader<'a> { /// must be "JWT" + #[serde(rename = "typ")] typ: &'a str, /// must be a supported alg - alg: jose_jwa::Algorithm, + #[serde(rename = "alg")] + algorithm: jose_jwa::Algorithm, /// key id, must be provided for our usecase - kid: Option<&'a str>, + #[serde(rename = "kid")] + key_id: Option<&'a str>, +} + +/// +#[derive(serde::Deserialize, serde::Serialize, Debug)] +struct JwtPayload<'a> { + /// Audience - Recipient for which the JWT is intended + #[serde(rename = "aud")] + audience: Option<&'a str>, + /// Expiration - Time after which the JWT expires + #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)] + expiration: Option, + /// Not before - Time after which the JWT expires + #[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)] + not_before: Option, + + // the following entries are only extracted for the sake of debug logging. + /// Issuer of the JWT + #[serde(rename = "iss")] + issuer: Option<&'a str>, + /// Subject of the JWT (the user) + #[serde(rename = "sub")] + subject: Option<&'a str>, + /// Unique token identifier + #[serde(rename = "jti")] + jwt_id: Option<&'a str>, + /// Unique session identifier + #[serde(rename = "sid")] + session_id: Option<&'a str>, +} + +fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { + let d = >::deserialize(d)?; + Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n))) } struct JwkRenewalPermit<'a> { @@ -340,7 +439,7 @@ impl JwkRenewalPermit<'_> { } } - async fn acquire_permit(from: &Arc) -> JwkRenewalPermit { + async fn acquire_permit(from: &Arc) -> JwkRenewalPermit<'_> { match from.lookup.acquire().await { Ok(permit) => { permit.forget(); @@ -352,7 +451,7 @@ impl JwkRenewalPermit<'_> { } } - fn try_acquire_permit(from: &Arc) -> Option { + fn try_acquire_permit(from: &Arc) -> Option> { match from.lookup.try_acquire() { Ok(permit) => { permit.forget(); @@ -388,6 +487,8 @@ impl Drop for JwkRenewalPermit<'_> { #[cfg(test)] mod tests { + use crate::RoleName; + use super::*; use std::{future::IntoFuture, net::SocketAddr, time::SystemTime}; @@ -431,10 +532,10 @@ mod tests { } fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String { - let header = JWTHeader { + let header = JwtHeader { typ: "JWT", - alg: jose_jwa::Algorithm::Signing(sig), - kid: Some(&kid), + algorithm: jose_jwa::Algorithm::Signing(sig), + key_id: Some(&kid), }; let body = typed_json::json! 
{{ "exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600, @@ -524,33 +625,40 @@ mod tests { struct Fetch(SocketAddr); impl FetchAuthRules for Fetch { - async fn fetch_auth_rules(&self) -> anyhow::Result { - Ok(AuthRules { - jwks_urls: vec![ - format!("http://{}/foo", self.0).parse().unwrap(), - format!("http://{}/bar", self.0).parse().unwrap(), - ], - }) + async fn fetch_auth_rules( + &self, + _role_name: RoleName, + ) -> anyhow::Result> { + Ok(vec![ + AuthRule { + id: "foo".to_owned(), + jwks_url: format!("http://{}/foo", self.0).parse().unwrap(), + audience: None, + }, + AuthRule { + id: "bar".to_owned(), + jwks_url: format!("http://{}/bar", self.0).parse().unwrap(), + audience: None, + }, + ]) } } + let role_name = RoleName::from("user"); + let jwk_cache = Arc::new(JwkCacheEntryLock::default()); - jwk_cache - .check_jwt(jwt1, &client, &Fetch(addr)) - .await - .unwrap(); - jwk_cache - .check_jwt(jwt2, &client, &Fetch(addr)) - .await - .unwrap(); - jwk_cache - .check_jwt(jwt3, &client, &Fetch(addr)) - .await - .unwrap(); - jwk_cache - .check_jwt(jwt4, &client, &Fetch(addr)) - .await - .unwrap(); + for token in [jwt1, jwt2, jwt3, jwt4] { + jwk_cache + .check_jwt( + &RequestMonitoring::test(), + &token, + &client, + role_name.clone(), + &Fetch(addr), + ) + .await + .unwrap(); + } } } diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs new file mode 100644 index 0000000000..8124f568cf --- /dev/null +++ b/proxy/src/auth/backend/local.rs @@ -0,0 +1,77 @@ +use std::{collections::HashMap, net::SocketAddr}; + +use anyhow::Context; +use arc_swap::ArcSwapOption; + +use crate::{ + compute::ConnCfg, + console::{ + messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo}, + NodeInfo, + }, + intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag}, + RoleName, +}; + +use super::jwt::{AuthRule, FetchAuthRules, JwkCache}; + +pub struct LocalBackend { + pub(crate) jwks_cache: JwkCache, + pub(crate) node_info: NodeInfo, +} + +impl LocalBackend { + pub fn new(postgres_addr: SocketAddr) -> Self { + LocalBackend { + jwks_cache: JwkCache::default(), + node_info: NodeInfo { + config: { + let mut cfg = ConnCfg::new(); + cfg.host(&postgres_addr.ip().to_string()); + cfg.port(postgres_addr.port()); + cfg + }, + // TODO(conrad): make this better reflect compute info rather than endpoint info. 
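`JwtPayload` stores `exp`/`nbf` as `SystemTime` by running the RFC 7519 NumericDate (seconds since the Unix epoch) through `numeric_date_opt`. A small end-to-end usage example of the same converter, assuming the `serde` derive and `serde_json` crates:

```rust
// NumericDate handling as used by JwtPayload above: JWT `exp`/`nbf` are
// seconds since the Unix epoch, converted into SystemTime on deserialize.
use serde::{Deserialize, Deserializer};
use std::time::{Duration, SystemTime};

fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
    let secs = Option::<u64>::deserialize(d)?;
    Ok(secs.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
}

#[derive(Deserialize, Debug)]
struct Claims {
    #[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
    expiration: Option<SystemTime>,
}

fn main() {
    let claims: Claims = serde_json::from_str(r#"{"exp": 1735689600}"#).unwrap();
    assert_eq!(
        claims.expiration,
        Some(SystemTime::UNIX_EPOCH + Duration::from_secs(1_735_689_600))
    );
    // A token without `exp` still deserializes thanks to `default`.
    let no_exp: Claims = serde_json::from_str("{}").unwrap();
    assert!(no_exp.expiration.is_none());
}
```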
+ aux: MetricsAuxInfo { + endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"), + project_id: ProjectIdTag::get_interner().get_or_intern("local"), + branch_id: BranchIdTag::get_interner().get_or_intern("local"), + cold_start_info: ColdStartInfo::WarmCached, + }, + allow_self_signed_compute: false, + }, + } + } +} + +#[derive(Clone, Copy)] +pub(crate) struct StaticAuthRules; + +pub static JWKS_ROLE_MAP: ArcSwapOption = ArcSwapOption::const_empty(); + +#[derive(Debug, Clone)] +pub struct JwksRoleSettings { + pub roles: HashMap, + pub project_id: ProjectIdInt, + pub branch_id: BranchIdInt, +} + +impl FetchAuthRules for StaticAuthRules { + async fn fetch_auth_rules(&self, role_name: RoleName) -> anyhow::Result> { + let mappings = JWKS_ROLE_MAP.load(); + let role_mappings = mappings + .as_deref() + .and_then(|m| m.roles.get(&role_name)) + .context("JWKs settings for this role were not configured")?; + let mut rules = vec![]; + for setting in &role_mappings.jwks { + rules.push(AuthRule { + id: setting.id.clone(), + jwks_url: setting.jwks_url.clone(), + audience: setting.jwt_audience.clone(), + }); + } + + Ok(rules) + } +} diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/web.rs similarity index 86% rename from proxy/src/auth/backend/link.rs rename to proxy/src/auth/backend/web.rs index 95f4614736..58a4bef62e 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/web.rs @@ -13,7 +13,7 @@ use tokio_postgres::config::SslMode; use tracing::{info, info_span}; #[derive(Debug, Error)] -pub enum LinkAuthError { +pub(crate) enum WebAuthError { #[error(transparent)] WaiterRegister(#[from] waiters::RegisterError), @@ -24,18 +24,18 @@ pub enum LinkAuthError { Io(#[from] std::io::Error), } -impl UserFacingError for LinkAuthError { +impl UserFacingError for WebAuthError { fn to_string_client(&self) -> String { "Internal error".to_string() } } -impl ReportableError for LinkAuthError { +impl ReportableError for WebAuthError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service, - LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service, - LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect, + Self::WaiterRegister(_) => crate::error::ErrorKind::Service, + Self::WaiterWait(_) => crate::error::ErrorKind::Service, + Self::Io(_) => crate::error::ErrorKind::ClientDisconnect, } } } @@ -52,7 +52,7 @@ fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String { ) } -pub fn new_psql_session_id() -> String { +pub(crate) fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } @@ -74,7 +74,7 @@ pub(super) async fn authenticate( } }; - let span = info_span!("link", psql_session_id = &psql_session_id); + let span = info_span!("web", psql_session_id = &psql_session_id); let greeting = hello_message(link_uri, &psql_session_id); // Give user a URL to spawn a new database. @@ -87,7 +87,7 @@ pub(super) async fn authenticate( // Wait for web console response (see `mgmt`). 
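The local backend publishes its JWKS role mappings through a process-global `ArcSwapOption`, so a configuration reload is one atomic pointer swap and readers never block. A minimal sketch of that publish/read pattern, assuming the `arc_swap` crate and a simplified settings struct:

```rust
// Publish/read pattern analogous to JWKS_ROLE_MAP above; the settings struct
// is a simplified stand-in.
use arc_swap::ArcSwapOption;
use std::sync::Arc;

#[derive(Debug)]
struct RoleSettings {
    jwks_url: String,
}

static ROLE_MAP: ArcSwapOption<RoleSettings> = ArcSwapOption::const_empty();

fn publish(settings: RoleSettings) {
    // Writers swap in a whole new Arc; readers holding the old one are unaffected.
    ROLE_MAP.store(Some(Arc::new(settings)));
}

fn read_jwks_url() -> Option<String> {
    // `load` is lock-free; the guard keeps the Arc alive while we read it.
    ROLE_MAP.load().as_ref().map(|s| s.jwks_url.clone())
}

fn main() {
    assert!(read_jwks_url().is_none());
    publish(RoleSettings { jwks_url: "https://example.com/jwks.json".to_owned() });
    assert_eq!(read_jwks_url().as_deref(), Some("https://example.com/jwks.json"));
}
```

Swapping the whole map on reload avoids fine-grained locking on the hot path where every request consults the role settings.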
info!(parent: &span, "waiting for console's reply..."); - let db_info = waiter.await.map_err(LinkAuthError::from)?; + let db_info = waiter.await.map_err(WebAuthError::from)?; client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 8f4a392131..0e91ae570a 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -16,7 +16,7 @@ use thiserror::Error; use tracing::{info, warn}; #[derive(Debug, Error, PartialEq, Eq, Clone)] -pub enum ComputeUserInfoParseError { +pub(crate) enum ComputeUserInfoParseError { #[error("Parameter '{0}' is missing in startup packet.")] MissingKey(&'static str), @@ -51,20 +51,20 @@ impl ReportableError for ComputeUserInfoParseError { /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ComputeUserInfoMaybeEndpoint { - pub user: RoleName, - pub endpoint_id: Option, - pub options: NeonOptions, +pub(crate) struct ComputeUserInfoMaybeEndpoint { + pub(crate) user: RoleName, + pub(crate) endpoint_id: Option, + pub(crate) options: NeonOptions, } impl ComputeUserInfoMaybeEndpoint { #[inline] - pub fn endpoint(&self) -> Option<&str> { + pub(crate) fn endpoint(&self) -> Option<&str> { self.endpoint_id.as_deref() } } -pub fn endpoint_sni( +pub(crate) fn endpoint_sni( sni: &str, common_names: &HashSet, ) -> Result, ComputeUserInfoParseError> { @@ -83,16 +83,18 @@ pub fn endpoint_sni( } impl ComputeUserInfoMaybeEndpoint { - pub fn parse( + pub(crate) fn parse( ctx: &RequestMonitoring, params: &StartupMessageParams, sni: Option<&str>, common_names: Option<&HashSet>, ) -> Result { - use ComputeUserInfoParseError::*; - // Some parameters are stored in the startup message. - let get_param = |key| params.get(key).ok_or(MissingKey(key)); + let get_param = |key| { + params + .get(key) + .ok_or(ComputeUserInfoParseError::MissingKey(key)) + }; let user: RoleName = get_param("user")?.into(); // Project name might be passed via PG's command-line options. @@ -122,12 +124,18 @@ impl ComputeUserInfoMaybeEndpoint { let endpoint = match (endpoint_option, endpoint_from_domain) { // Invariant: if we have both project name variants, they should match. (Some(option), Some(domain)) if option != domain => { - Some(Err(InconsistentProjectNames { domain, option })) + Some(Err(ComputeUserInfoParseError::InconsistentProjectNames { + domain, + option, + })) } // Invariant: project name may not contain certain characters. 
- (a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) { - false => Err(MalformedProjectName(name)), - true => Ok(name), + (a, b) => a.or(b).map(|name| { + if project_name_valid(name.as_ref()) { + Ok(name) + } else { + Err(ComputeUserInfoParseError::MalformedProjectName(name)) + } }), } .transpose()?; @@ -165,12 +173,12 @@ impl ComputeUserInfoMaybeEndpoint { } } -pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool { +pub(crate) fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool { ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern)) } #[derive(Debug, Clone, Eq, PartialEq)] -pub enum IpPattern { +pub(crate) enum IpPattern { Subnet(ipnet::IpNet), Range(IpAddr, IpAddr), Single(IpAddr), @@ -186,7 +194,7 @@ impl<'de> serde::de::Deserialize<'de> for IpPattern { impl<'de> serde::de::Visitor<'de> for StrVisitor { type Value = IpPattern; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask") } diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index acf7b4f6b6..f7e2b5296e 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -17,17 +17,20 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; /// Every authentication selector is supposed to implement this trait. -pub trait AuthMethod { +pub(crate) trait AuthMethod { /// Any authentication selector should provide initial backend message /// containing auth method name and parameters, e.g. md5 salt. fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; } /// Initial state of [`AuthFlow`]. -pub struct Begin; +pub(crate) struct Begin; /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. -pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a RequestMonitoring); +pub(crate) struct Scram<'a>( + pub(crate) &'a scram::ServerSecret, + pub(crate) &'a RequestMonitoring, +); impl AuthMethod for Scram<'_> { #[inline(always)] @@ -44,7 +47,7 @@ impl AuthMethod for Scram<'_> { /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in /// . -pub struct PasswordHack; +pub(crate) struct PasswordHack; impl AuthMethod for PasswordHack { #[inline(always)] @@ -55,10 +58,10 @@ impl AuthMethod for PasswordHack { /// Use clear-text password auth called `password` in docs /// -pub struct CleartextPassword { - pub pool: Arc, - pub endpoint: EndpointIdInt, - pub secret: AuthSecret, +pub(crate) struct CleartextPassword { + pub(crate) pool: Arc, + pub(crate) endpoint: EndpointIdInt, + pub(crate) secret: AuthSecret, } impl AuthMethod for CleartextPassword { @@ -70,7 +73,7 @@ impl AuthMethod for CleartextPassword { /// This wrapper for [`PqStream`] performs client authentication. #[must_use] -pub struct AuthFlow<'a, S, State> { +pub(crate) struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, /// State might contain ancillary data (see [`Self::begin`]). @@ -81,7 +84,7 @@ pub struct AuthFlow<'a, S, State> { /// Initial state of the stream wrapper. impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { /// Create a new wrapper for client authentication. 
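// Illustrative call sequence for the type-state API above (sketch only,
// assuming a SCRAM `secret` and a `RequestMonitoring` context `ctx` are in
// scope):
//
//     let outcome = AuthFlow::new(&mut stream)
//         .begin(Scram(&secret, &ctx))
//         .await?
//         .authenticate()
//         .await?;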
- pub fn new(stream: &'a mut PqStream>) -> Self { + pub(crate) fn new(stream: &'a mut PqStream>) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); Self { @@ -92,7 +95,7 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { } /// Move to the next step by sending auth method's name & params to client. - pub async fn begin(self, method: M) -> io::Result> { + pub(crate) async fn begin(self, method: M) -> io::Result> { self.stream .write_message(&method.first_message(self.tls_server_end_point.supported())) .await?; @@ -107,7 +110,7 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn get_password(self) -> super::Result { + pub(crate) async fn get_password(self) -> super::Result { let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -126,7 +129,7 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> super::Result> { + pub(crate) async fn authenticate(self) -> super::Result> { let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -151,7 +154,7 @@ impl AuthFlow<'_, S, CleartextPassword> { /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> super::Result> { + pub(crate) async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; // pause the timer while we communicate with the client diff --git a/proxy/src/auth/password_hack.rs b/proxy/src/auth/password_hack.rs index 2ddf46fe25..8585b8ff48 100644 --- a/proxy/src/auth/password_hack.rs +++ b/proxy/src/auth/password_hack.rs @@ -1,5 +1,5 @@ //! Payload for ad hoc authentication method for clients that don't support SNI. -//! See the `impl` for [`super::backend::BackendType`]. +//! See the `impl` for [`super::backend::Backend`]. //! Read more: . //! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified. @@ -7,13 +7,13 @@ use bstr::ByteSlice; use crate::EndpointId; -pub struct PasswordHackPayload { - pub endpoint: EndpointId, - pub password: Vec, +pub(crate) struct PasswordHackPayload { + pub(crate) endpoint: EndpointId, + pub(crate) password: Vec, } impl PasswordHackPayload { - pub fn parse(bytes: &[u8]) -> Option { + pub(crate) fn parse(bytes: &[u8]) -> Option { // The format is `project=;` or `project=$`. 
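// That is: a `project=` (or `endpoint=`) prefix, the endpoint name, a `;` or
// `$` separator, and then the raw password bytes; see `parse_endpoint_param`
// below for the prefix handling.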
let separators = [";", "$"]; for sep in separators { @@ -30,7 +30,7 @@ impl PasswordHackPayload { } } -pub fn parse_endpoint_param(bytes: &str) -> Option<&str> { +pub(crate) fn parse_endpoint_param(bytes: &str) -> Option<&str> { bytes .strip_prefix("project=") .or_else(|| bytes.strip_prefix("endpoint=")) diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs new file mode 100644 index 0000000000..08effeff99 --- /dev/null +++ b/proxy/src/bin/local_proxy.rs @@ -0,0 +1,316 @@ +use std::{ + net::SocketAddr, + path::{Path, PathBuf}, + pin::pin, + sync::Arc, + time::Duration, +}; + +use anyhow::{bail, ensure}; +use dashmap::DashMap; +use futures::{future::Either, FutureExt}; +use proxy::{ + auth::backend::local::{JwksRoleSettings, LocalBackend, JWKS_ROLE_MAP}, + cancellation::CancellationHandlerMain, + config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig}, + console::{locks::ApiLocks, messages::JwksRoleMapping}, + http::health_server::AppMetrics, + metrics::{Metrics, ThreadPoolMetrics}, + rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}, + scram::threadpool::ThreadPool, + serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions}, +}; + +project_git_version!(GIT_VERSION); +project_build_tag!(BUILD_TAG); + +use clap::Parser; +use tokio::{net::TcpListener, task::JoinSet}; +use tokio_util::sync::CancellationToken; +use tracing::{error, info, warn}; +use utils::{project_build_tag, project_git_version, sentry_init::init_sentry}; + +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +/// Neon proxy/router +#[derive(Parser)] +#[command(version = GIT_VERSION, about)] +struct LocalProxyCliArgs { + /// listen for incoming metrics connections on ip:port + #[clap(long, default_value = "127.0.0.1:7001")] + metrics: String, + /// listen for incoming http connections on ip:port + #[clap(long)] + http: String, + /// timeout for the TLS handshake + #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] + handshake_timeout: tokio::time::Duration, + /// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable). + #[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)] + connect_compute_lock: String, + #[clap(flatten)] + sql_over_http: SqlOverHttpArgs, + /// User rate limiter max number of requests per second. + /// + /// Provided in the form `@`. + /// Can be given multiple times for different bucket sizes. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] + user_rps_limit: Vec, + /// Whether the auth rate limiter actually takes effect (for testing) + #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + auth_rate_limit_enabled: bool, + /// Authentication rate limiter max number of hashes per second. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] + auth_rate_limit: Vec, + /// The IP subnet to use when considering whether two IP addresses are considered the same. 
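/// (Presumably a prefix length in bits: peers whose addresses share this
/// prefix are rate-limited as one client; the default of 64 matches a typical
/// IPv6 allocation.)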
+ #[clap(long, default_value_t = 64)] + auth_rate_limit_ip_subnet: u8, + /// Whether to retry the connection to the compute node + #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)] + connect_to_compute_retry: String, + /// Address of the postgres server + #[clap(long, default_value = "127.0.0.1:5432")] + compute: SocketAddr, + /// File address of the local proxy config file + #[clap(long, default_value = "./localproxy.json")] + config_path: PathBuf, +} + +#[derive(clap::Args, Clone, Copy, Debug)] +struct SqlOverHttpArgs { + /// How many connections to pool for each endpoint. Excess connections are discarded + #[clap(long, default_value_t = 200)] + sql_over_http_pool_max_total_conns: usize, + + /// How long pooled connections should remain idle for before closing + #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)] + sql_over_http_idle_timeout: tokio::time::Duration, + + #[clap(long, default_value_t = 100)] + sql_over_http_client_conn_threshold: u64, + + #[clap(long, default_value_t = 16)] + sql_over_http_cancel_set_shards: usize, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let _logging_guard = proxy::logging::init().await?; + let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); + let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); + + Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); + + info!("Version: {GIT_VERSION}"); + info!("Build_tag: {BUILD_TAG}"); + let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo { + revision: GIT_VERSION, + build_tag: BUILD_TAG, + }); + + let jemalloc = match proxy::jemalloc::MetricRecorder::new() { + Ok(t) => Some(t), + Err(e) => { + tracing::error!(error = ?e, "could not start jemalloc metrics loop"); + None + } + }; + + let args = LocalProxyCliArgs::parse(); + let config = build_config(&args)?; + + let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?; + let http_listener = TcpListener::bind(args.http).await?; + let shutdown = CancellationToken::new(); + + // todo: should scale with CU + let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( + LeakyBucketConfig { + rps: 10.0, + max: 100.0, + }, + 16, + )); + + refresh_config(args.config_path.clone()).await; + + let mut maintenance_tasks = JoinSet::new(); + maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), move || { + refresh_config(args.config_path.clone()).map(Ok) + })); + maintenance_tasks.spawn(proxy::http::health_server::task_main( + metrics_listener, + AppMetrics { + jemalloc, + neon_metrics, + proxy: proxy::metrics::Metrics::get(), + }, + )); + + let task = serverless::task_main( + config, + http_listener, + shutdown.clone(), + Arc::new(CancellationHandlerMain::new( + Arc::new(DashMap::new()), + None, + proxy::metrics::CancellationSource::Local, + )), + endpoint_rate_limiter, + ); + + match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await { + // exit immediately on maintenance task completion + Either::Left((Some(res), _)) => match proxy::flatten_err(res)? {}, + // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above) + Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"), + // exit immediately on client task error + Either::Right((res, _)) => res?, + } + + Ok(()) +} + +/// ProxyConfig is created at proxy startup, and lives forever. 
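// For local_proxy this wires up a deliberately minimal configuration, as the
// body below shows: no TLS termination, websockets disabled on the HTTP pool,
// a `LocalBackend` pointing at `--compute`, a zero-sized SCRAM thread pool and
// a disabled password rate limiter; in other words, authentication is
// evidently meant to go through JWTs validated against the JWKS settings
// loaded from `--config-path`.
//
// An illustrative invocation (addresses and paths invented):
//
//     local_proxy --http 127.0.0.1:7432 \
//         --compute 127.0.0.1:5432 \
//         --config-path ./localproxy.json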
+fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { + let config::ConcurrencyLockOptions { + shards, + limiter, + epoch, + timeout, + } = args.connect_compute_lock.parse()?; + info!( + ?limiter, + shards, + ?epoch, + "Using NodeLocks (connect_compute)" + ); + let connect_compute_locks = ApiLocks::new( + "connect_compute_lock", + limiter, + shards, + timeout, + epoch, + &Metrics::get().proxy.connect_compute_lock, + )?; + + let http_config = HttpConfig { + accept_websockets: false, + pool_options: GlobalConnPoolOptions { + gc_epoch: Duration::from_secs(60), + pool_shards: 2, + idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, + opt_in: false, + + max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_total_conns, + max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, + }, + cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards), + client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold, + }; + + Ok(Box::leak(Box::new(ProxyConfig { + tls_config: None, + auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( + LocalBackend::new(args.compute), + )), + metric_collection: None, + allow_self_signed_compute: false, + http_config, + authentication_config: AuthenticationConfig { + thread_pool: ThreadPool::new(0), + scram_protocol_timeout: Duration::from_secs(10), + rate_limiter_enabled: false, + rate_limiter: BucketRateLimiter::new(vec![]), + rate_limit_ip_subnet: 64, + }, + require_client_ip: false, + handshake_timeout: Duration::from_secs(10), + region: "local".into(), + wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, + connect_compute_locks, + connect_to_compute_retry_config: RetryConfig::parse( + RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES, + )?, + }))) +} + +async fn refresh_config(path: PathBuf) { + match refresh_config_inner(&path).await { + Ok(()) => {} + Err(e) => { + error!(error=?e, ?path, "could not read config file"); + } + } +} + +async fn refresh_config_inner(path: &Path) -> anyhow::Result<()> { + let bytes = tokio::fs::read(&path).await?; + let mut data: JwksRoleMapping = serde_json::from_slice(&bytes)?; + + let mut settings = None; + + for mapping in data.roles.values_mut() { + for jwks in &mut mapping.jwks { + ensure!( + jwks.jwks_url.has_authority() + && (jwks.jwks_url.scheme() == "http" || jwks.jwks_url.scheme() == "https"), + "Invalid JWKS url. Must be HTTP", + ); + + ensure!( + jwks.jwks_url + .host() + .is_some_and(|h| h != url::Host::Domain("")), + "Invalid JWKS url. No domain listed", + ); + + // clear username, password and ports + jwks.jwks_url.set_username("").expect( + "url can be a base and has a valid host and is not a file. should not error", + ); + jwks.jwks_url.set_password(None).expect( + "url can be a base and has a valid host and is not a file. should not error", + ); + // local testing is hard if we need to have a specific restricted port + if cfg!(not(feature = "testing")) { + jwks.jwks_url.set_port(None).expect( + "url can be a base and has a valid host and is not a file. should not error", + ); + } + + // clear query params + jwks.jwks_url.set_fragment(None); + jwks.jwks_url.query_pairs_mut().clear().finish(); + + if jwks.jwks_url.scheme() != "https" { + // local testing is hard if we need to set up https support. 
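// For reference, the `localproxy.json` read by this function deserializes
// into `JwksRoleMapping` (see `console::messages`): a map from role name to a
// list of JWKS settings. A minimal file might look roughly like this (field
// names taken from the structs, values invented):
//
//     {
//       "roles": {
//         "authenticated": {
//           "jwks": [
//             {
//               "id": "example",
//               "project_id": "...",
//               "branch_id": "...",
//               "jwks_url": "https://example.com/.well-known/jwks.json",
//               "provider_name": "example",
//               "jwt_audience": null
//             }
//           ]
//         }
//       }
//     }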
+ if cfg!(not(feature = "testing")) { + jwks.jwks_url + .set_scheme("https") + .expect("should not error to set the scheme to https if it was http"); + } else { + warn!(scheme = jwks.jwks_url.scheme(), "JWKS url is not HTTPS"); + } + } + + let (pr, br) = settings.get_or_insert((jwks.project_id, jwks.branch_id)); + ensure!( + *pr == jwks.project_id, + "inconsistent project IDs configured" + ); + ensure!(*br == jwks.branch_id, "inconsistent branch IDs configured"); + } + } + + if let Some((project_id, branch_id)) = settings { + JWKS_ROLE_MAP.store(Some(Arc::new(JwksRoleSettings { + roles: data.roles, + project_id, + branch_id, + }))); + } + + Ok(()) +} diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 1038fa5116..20d2d3df9a 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -133,7 +133,9 @@ async fn main() -> anyhow::Result<()> { proxy_listener, cancellation_token.clone(), )); - let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token)); + let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || async { + Ok(()) + })); // the signal task cant ever succeed. // the main task can error, or can succeed on cancellation. diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index b44e0ddd2f..7706a1f7cd 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -60,11 +60,14 @@ use clap::{Parser, ValueEnum}; static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; #[derive(Clone, Debug, ValueEnum)] -enum AuthBackend { +enum AuthBackendType { Console, #[cfg(feature = "testing")] Postgres, - Link, + // clap only shows the name, not the alias, in usage text. + // TODO: swap name/alias and deprecate "link" + #[value(name("link"), alias("web"))] + Web, } /// Neon proxy/router @@ -77,8 +80,8 @@ struct ProxyCliArgs { /// listen for incoming client connections on ip:port #[clap(short, long, default_value = "127.0.0.1:4432")] proxy: String, - #[clap(value_enum, long, default_value_t = AuthBackend::Link)] - auth_backend: AuthBackend, + #[clap(value_enum, long, default_value_t = AuthBackendType::Web)] + auth_backend: AuthBackendType, /// listen for management callback connection on ip:port #[clap(short, long, default_value = "127.0.0.1:7000")] mgmt: String, @@ -88,7 +91,7 @@ struct ProxyCliArgs { /// listen for incoming wss connections on ip:port #[clap(long)] wss: Option, - /// redirect unauthenticated users to the given uri in case of link auth + /// redirect unauthenticated users to the given uri in case of web auth #[clap(short, long, default_value = "http://localhost:3000/psql_session/")] uri: String, /// cloud API endpoint for authenticating users @@ -148,7 +151,7 @@ struct ProxyCliArgs { disable_dynamic_rate_limiter: bool, /// Endpoint rate limiter max number of requests per second. /// - /// Provided in the form '@'. + /// Provided in the form `@`. /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] endpoint_rps_limit: Vec, @@ -173,9 +176,6 @@ struct ProxyCliArgs { /// cache for `role_secret` (use `size=0` to disable) #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] role_secret_cache: String, - /// disable ip check for http requests. If it is too time consuming, it could be turned off. 
- #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] - disable_ip_check_for_http: bool, /// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections) #[clap(long)] redis_notifications: Option, @@ -450,7 +450,10 @@ async fn main() -> anyhow::Result<()> { // maintenance tasks. these never return unless there's an error let mut maintenance_tasks = JoinSet::new(); - maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone())); + maintenance_tasks.spawn(proxy::handle_signals( + cancellation_token.clone(), + || async { Ok(()) }, + )); maintenance_tasks.spawn(http::health_server::task_main( http_listener, AppMetrics { @@ -470,7 +473,7 @@ async fn main() -> anyhow::Result<()> { )); } - if let auth::BackendType::Console(api, _) = &config.auth_backend { + if let auth::Backend::Console(api, _) = &config.auth_backend { if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { match (redis_notifications_client, regional_redis_client.clone()) { (None, None) => {} @@ -575,7 +578,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } let auth_backend = match &args.auth_backend { - AuthBackend::Console => { + AuthBackendType::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; let project_info_cache_config: ProjectInfoCacheOptions = args.project_info_cache.parse()?; @@ -624,18 +627,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { wake_compute_endpoint_rate_limiter, ); let api = console::provider::ConsoleBackend::Console(api); - auth::BackendType::Console(MaybeOwned::Owned(api), ()) + auth::Backend::Console(MaybeOwned::Owned(api), ()) } #[cfg(feature = "testing")] - AuthBackend::Postgres => { + AuthBackendType::Postgres => { let url = args.auth_endpoint.parse()?; let api = console::provider::mock::Api::new(url); let api = console::provider::ConsoleBackend::Postgres(api); - auth::BackendType::Console(MaybeOwned::Owned(api), ()) + auth::Backend::Console(MaybeOwned::Owned(api), ()) } - AuthBackend::Link => { + AuthBackendType::Web => { let url = args.uri.parse()?; - auth::BackendType::Link(MaybeOwned::Owned(url), ()) + auth::Backend::Web(MaybeOwned::Owned(url), ()) } }; @@ -661,6 +664,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { )?; let http_config = HttpConfig { + accept_websockets: true, pool_options: GlobalConnPoolOptions { max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs index d1d4087241..6c168144a7 100644 --- a/proxy/src/cache.rs +++ b/proxy/src/cache.rs @@ -1,7 +1,7 @@ -pub mod common; -pub mod endpoints; -pub mod project_info; +pub(crate) mod common; +pub(crate) mod endpoints; +pub(crate) mod project_info; mod timed_lru; -pub use common::{Cache, Cached}; -pub use timed_lru::TimedLru; +pub(crate) use common::{Cache, Cached}; +pub(crate) use timed_lru::TimedLru; diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs index 4e393fddb2..b5caf94788 100644 --- a/proxy/src/cache/common.rs +++ b/proxy/src/cache/common.rs @@ -3,7 +3,7 @@ use std::ops::{Deref, DerefMut}; /// A generic trait which exposes types of cache's key and value, /// as well as the notion of cache entry invalidation. /// This is useful for [`Cached`]. 
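// The pattern (a sketch, not verbatim code): lookups return a `Cached` value
// that remembers which entry it came from, so the caller can later evict
// exactly that entry:
//
//     let node = cache.get(&key)?;   // Option<Cached<_>>
//     // ... use `node`; if it turns out to be stale:
//     let value = node.invalidate(); // drops only this entry from the cache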
-pub trait Cache { +pub(crate) trait Cache { /// Entry's key. type Key; @@ -24,26 +24,26 @@ impl Cache for &C { type LookupInfo = C::LookupInfo; fn invalidate(&self, info: &Self::LookupInfo) { - C::invalidate(self, info) + C::invalidate(self, info); } } /// Wrapper for convenient entry invalidation. -pub struct Cached::Value> { +pub(crate) struct Cached::Value> { /// Cache + lookup info. - pub token: Option<(C, C::LookupInfo)>, + pub(crate) token: Option<(C, C::LookupInfo)>, /// The value itself. - pub value: V, + pub(crate) value: V, } impl Cached { /// Place any entry into this wrapper; invalidation will be a no-op. - pub fn new_uncached(value: V) -> Self { + pub(crate) fn new_uncached(value: V) -> Self { Self { token: None, value } } - pub fn take_value(self) -> (Cached, V) { + pub(crate) fn take_value(self) -> (Cached, V) { ( Cached { token: self.token, @@ -53,7 +53,7 @@ impl Cached { ) } - pub fn map(self, f: impl FnOnce(V) -> U) -> Cached { + pub(crate) fn map(self, f: impl FnOnce(V) -> U) -> Cached { Cached { token: self.token, value: f(self.value), @@ -61,7 +61,7 @@ impl Cached { } /// Drop this entry from a cache if it's still there. - pub fn invalidate(self) -> V { + pub(crate) fn invalidate(self) -> V { if let Some((cache, info)) = &self.token { cache.invalidate(info); } @@ -69,7 +69,7 @@ impl Cached { } /// Tell if this entry is actually cached. - pub fn cached(&self) -> bool { + pub(crate) fn cached(&self) -> bool { self.token.is_some() } } diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs index 8c851790c2..f4762232d8 100644 --- a/proxy/src/cache/endpoints.rs +++ b/proxy/src/cache/endpoints.rs @@ -28,7 +28,7 @@ use crate::{ }; #[derive(Deserialize, Debug, Clone)] -pub struct ControlPlaneEventKey { +pub(crate) struct ControlPlaneEventKey { endpoint_created: Option, branch_created: Option, project_created: Option, @@ -56,7 +56,7 @@ pub struct EndpointsCache { } impl EndpointsCache { - pub fn new(config: EndpointCacheConfig) -> Self { + pub(crate) fn new(config: EndpointCacheConfig) -> Self { Self { limiter: Arc::new(Mutex::new(GlobalRateLimiter::new( config.limiter_info.clone(), @@ -68,7 +68,7 @@ impl EndpointsCache { ready: AtomicBool::new(false), } } - pub async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool { + pub(crate) async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool { if !self.ready.load(Ordering::Acquire) { return true; } diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 10cc4ceee1..ceae74a9a0 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -24,7 +24,7 @@ use crate::{ use super::{Cache, Cached}; #[async_trait] -pub trait ProjectInfoCache { +pub(crate) trait ProjectInfoCache { fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); @@ -37,7 +37,7 @@ struct Entry { } impl Entry { - pub fn new(value: T) -> Self { + pub(crate) fn new(value: T) -> Self { Self { created_at: Instant::now(), value, @@ -64,7 +64,7 @@ impl EndpointInfo { Some(t) => t < created_at, } } - pub fn get_role_secret( + pub(crate) fn get_role_secret( &self, role_name: RoleNameInt, valid_since: Instant, @@ -81,7 +81,7 @@ impl EndpointInfo { None } - pub fn get_allowed_ips( + pub(crate) fn get_allowed_ips( &self, valid_since: Instant, ignore_cache_since: Option, @@ -96,10 +96,10 @@ impl EndpointInfo { 
} None } - pub fn invalidate_allowed_ips(&mut self) { + pub(crate) fn invalidate_allowed_ips(&mut self) { self.allowed_ips = None; } - pub fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { + pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { self.secret.remove(&role_name); } } @@ -178,7 +178,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } impl ProjectInfoCacheImpl { - pub fn new(config: ProjectInfoCacheOptions) -> Self { + pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self { Self { cache: DashMap::new(), project2ep: DashMap::new(), @@ -189,7 +189,7 @@ impl ProjectInfoCacheImpl { } } - pub fn get_role_secret( + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, @@ -212,7 +212,7 @@ impl ProjectInfoCacheImpl { } Some(Cached::new_uncached(value)) } - pub fn get_allowed_ips( + pub(crate) fn get_allowed_ips( &self, endpoint_id: &EndpointId, ) -> Option>>> { @@ -230,7 +230,7 @@ impl ProjectInfoCacheImpl { } Some(Cached::new_uncached(value)) } - pub fn insert_role_secret( + pub(crate) fn insert_role_secret( &self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, @@ -247,7 +247,7 @@ impl ProjectInfoCacheImpl { entry.secret.insert(role_name, secret.into()); } } - pub fn insert_allowed_ips( + pub(crate) fn insert_allowed_ips( &self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, @@ -274,13 +274,13 @@ impl ProjectInfoCacheImpl { let ttl_disabled_since_us = self .ttl_disabled_since_us .load(std::sync::atomic::Ordering::Relaxed); - let ignore_cache_since = if ttl_disabled_since_us != u64::MAX { + let ignore_cache_since = if ttl_disabled_since_us == u64::MAX { + None + } else { let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us); // We are fine if entry is not older than ttl or was added before we are getting notifications. valid_since = valid_since.min(ignore_cache_since); Some(ignore_cache_since) - } else { - None }; (valid_since, ignore_cache_since) } @@ -306,7 +306,7 @@ impl ProjectInfoCacheImpl { let mut removed = 0; let shard = self.project2ep.shards()[shard].write(); for (_, endpoints) in shard.iter() { - for endpoint in endpoints.get().iter() { + for endpoint in endpoints.get() { self.cache.remove(endpoint); removed += 1; } @@ -319,7 +319,7 @@ impl ProjectInfoCacheImpl { /// Lookup info for project info cache. /// This is used to invalidate cache entries. -pub struct CachedLookupInfo { +pub(crate) struct CachedLookupInfo { /// Search by this key. endpoint_id: EndpointIdInt, lookup_type: LookupType, diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs index c5c4f6a1ed..8bb482f7c6 100644 --- a/proxy/src/cache/timed_lru.rs +++ b/proxy/src/cache/timed_lru.rs @@ -39,7 +39,7 @@ use super::{common::Cached, *}; /// /// * It's possible for an entry that has not yet expired entry to be evicted /// before expired items. That's a bit wasteful, but probably fine in practice. -pub struct TimedLru { +pub(crate) struct TimedLru { /// Cache's name for tracing. name: &'static str, @@ -58,7 +58,7 @@ impl Cache for TimedLru { type LookupInfo = LookupInfo; fn invalidate(&self, info: &Self::LookupInfo) { - self.invalidate_raw(info) + self.invalidate_raw(info); } } @@ -72,7 +72,7 @@ struct Entry { impl TimedLru { /// Construct a new LRU cache with timed entries. 
- pub fn new( + pub(crate) fn new( name: &'static str, capacity: usize, ttl: Duration, @@ -207,11 +207,11 @@ impl TimedLru { } impl TimedLru { - pub fn insert_ttl(&self, key: K, value: V, ttl: Duration) { + pub(crate) fn insert_ttl(&self, key: K, value: V, ttl: Duration) { self.insert_raw_ttl(key, value, ttl, false); } - pub fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { + pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { let (created_at, old) = self.insert_raw(key.clone(), value); let cached = Cached { @@ -221,22 +221,11 @@ impl TimedLru { (old, cached) } - - pub fn insert(&self, key: K, value: V) -> (Option, Cached<&Self>) { - let (created_at, old) = self.insert_raw(key.clone(), value.clone()); - - let cached = Cached { - token: Some((self, LookupInfo { created_at, key })), - value, - }; - - (old, cached) - } } impl TimedLru { /// Retrieve a cached entry in convenient wrapper. - pub fn get(&self, key: &Q) -> Option> + pub(crate) fn get(&self, key: &Q) -> Option> where K: Borrow + Clone, Q: Hash + Eq + ?Sized, @@ -253,32 +242,10 @@ impl TimedLru { } }) } - - /// Retrieve a cached entry in convenient wrapper, ignoring its TTL. - pub fn get_ignoring_ttl(&self, key: &Q) -> Option> - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let mut cache = self.cache.lock(); - cache - .get(key) - .map(|entry| Cached::new_uncached(entry.value.clone())) - } - - /// Remove an entry from the cache. - pub fn remove(&self, key: &Q) -> Option - where - K: Borrow + Clone, - Q: Hash + Eq + ?Sized, - { - let mut cache = self.cache.lock(); - cache.remove(key).map(|entry| entry.value) - } } /// Lookup information for key invalidation. -pub struct LookupInfo { +pub(crate) struct LookupInfo { /// Time of creation of a cache [`Entry`]. /// We use this during invalidation lookups to prevent eviction of a newer /// entry sharing the same key (it might've been inserted by a different diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 34512e9f5b..71a2a16af8 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -18,7 +18,7 @@ use crate::{ pub type CancelMap = Arc>>; pub type CancellationHandlerMain = CancellationHandler>>>; -pub type CancellationHandlerMainInternal = Option>>; +pub(crate) type CancellationHandlerMainInternal = Option>>; /// Enables serving `CancelRequest`s. /// @@ -32,7 +32,7 @@ pub struct CancellationHandler

{ } #[derive(Debug, Error)] -pub enum CancelError { +pub(crate) enum CancelError { #[error("{0}")] IO(#[from] std::io::Error), #[error("{0}")] @@ -53,7 +53,7 @@ impl ReportableError for CancelError { impl CancellationHandler

{ /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub fn get_session(self: Arc) -> Session

{ + pub(crate) fn get_session(self: Arc) -> Session

{ // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the @@ -81,7 +81,7 @@ impl CancellationHandler

{ } /// Try to cancel a running query for the corresponding connection. /// If the cancellation key is not found, it will be published to Redis. - pub async fn cancel_session( + pub(crate) async fn cancel_session( &self, key: CancelKeyData, session_id: Uuid, @@ -155,14 +155,14 @@ pub struct CancelClosure { } impl CancelClosure { - pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self { + pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self { Self { socket_addr, cancel_token, } } /// Cancels the query running on user's compute node. - pub async fn try_cancel_query(self) -> Result<(), CancelError> { + pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; self.cancel_token.cancel_query_raw(socket, NoTls).await?; info!("query was cancelled"); @@ -171,7 +171,7 @@ impl CancelClosure { } /// Helper for registering query cancellation tokens. -pub struct Session

{ +pub(crate) struct Session

{ /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. @@ -181,7 +181,7 @@ pub struct Session

{ impl

Session

{ /// Store the cancel token for the given session. /// This enables query cancellation in `crate::proxy::prepare_client_connection`. - pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { + pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); self.cancellation_handler .map @@ -220,7 +220,8 @@ mod tests { #[tokio::test] async fn cancel_session_noop_regression() { - let handler = CancellationHandler::<()>::new(Default::default(), CancellationSource::Local); + let handler = + CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local); handler .cancel_session( CancelKeyData { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 18c82fe379..8d3cb8ee3c 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -23,7 +23,7 @@ use tracing::{error, info, warn}; const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; #[derive(Debug, Error)] -pub enum ConnectionError { +pub(crate) enum ConnectionError { /// This error doesn't seem to reveal any secrets; for instance, /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such. #[error("{COULD_NOT_CONNECT}: {0}")] @@ -44,11 +44,10 @@ pub enum ConnectionError { impl UserFacingError for ConnectionError { fn to_string_client(&self) -> String { - use ConnectionError::*; match self { // This helps us drop irrelevant library-specific prefixes. // TODO: propagate severity level and other parameters. - Postgres(err) => match err.as_db_error() { + ConnectionError::Postgres(err) => match err.as_db_error() { Some(err) => { let msg = err.message(); @@ -62,8 +61,8 @@ impl UserFacingError for ConnectionError { } None => err.to_string(), }, - WakeComputeError(err) => err.to_string_client(), - TooManyConnectionAttempts(_) => { + ConnectionError::WakeComputeError(err) => err.to_string_client(), + ConnectionError::TooManyConnectionAttempts(_) => { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() } _ => COULD_NOT_CONNECT.to_owned(), @@ -87,22 +86,22 @@ impl ReportableError for ConnectionError { } /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. -pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; +pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>; /// A config for establishing a connection to compute node. /// Eventually, `tokio_postgres` will be replaced with something better. /// Newtype allows us to implement methods on top of it. #[derive(Clone, Default)] -pub struct ConnCfg(Box); +pub(crate) struct ConnCfg(Box); /// Creation and initialization routines. impl ConnCfg { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self::default() } /// Reuse password or auth keys from the other config. - pub fn reuse_password(&mut self, other: Self) { + pub(crate) fn reuse_password(&mut self, other: Self) { if let Some(password) = other.get_password() { self.password(password); } @@ -112,7 +111,7 @@ impl ConnCfg { } } - pub fn get_host(&self) -> Result { + pub(crate) fn get_host(&self) -> Result { match self.0.get_hosts() { [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()), // we should not have multiple address or unix addresses. @@ -123,15 +122,15 @@ impl ConnCfg { } /// Apply startup message params to the connection config. 
- pub fn set_startup_params(&mut self, params: &StartupMessageParams) { + pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) { // Only set `user` if it's not present in the config. - // Link auth flow takes username from the console's response. + // Web auth flow takes username from the console's response. if let (None, Some(user)) = (self.get_user(), params.get("user")) { self.user(user); } // Only set `dbname` if it's not present in the config. - // Link auth flow takes dbname from the console's response. + // Web auth flow takes dbname from the console's response. if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) { self.dbname(dbname); } @@ -256,25 +255,25 @@ impl ConnCfg { } } -pub struct PostgresConnection { +pub(crate) struct PostgresConnection { /// Socket connected to a compute node. - pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream< + pub(crate) stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream< tokio::net::TcpStream, tokio_postgres_rustls::RustlsStream, >, /// PostgreSQL connection parameters. - pub params: std::collections::HashMap, + pub(crate) params: std::collections::HashMap, /// Query cancellation token. - pub cancel_closure: CancelClosure, + pub(crate) cancel_closure: CancelClosure, /// Labels for proxy's metrics. - pub aux: MetricsAuxInfo, + pub(crate) aux: MetricsAuxInfo, _guage: NumDbConnectionsGuard<'static>, } impl ConnCfg { /// Connect to a corresponding compute node. - pub async fn connect( + pub(crate) async fn connect( &self, ctx: &RequestMonitoring, allow_self_signed_compute: bool, @@ -287,7 +286,7 @@ impl ConnCfg { let client_config = if allow_self_signed_compute { // Allow all certificates for creating the connection - let verifier = Arc::new(AcceptEverythingVerifier) as Arc; + let verifier = Arc::new(AcceptEverythingVerifier); rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(verifier) @@ -366,16 +365,16 @@ static TLS_ROOTS: OnceCell> = OnceCell::new(); struct AcceptEverythingVerifier; impl ServerCertVerifier for AcceptEverythingVerifier { fn supported_verify_schemes(&self) -> Vec { - use rustls::SignatureScheme::*; + use rustls::SignatureScheme; // The schemes for which `SignatureScheme::supported_in_tls13` returns true. 
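// Reminder: this permissive verifier is only installed when
// `allow_self_signed_compute` is set in `connect` above; it accepts any
// server certificate, so it is meant strictly for self-signed compute nodes.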
vec![ - ECDSA_NISTP521_SHA512, - ECDSA_NISTP384_SHA384, - ECDSA_NISTP256_SHA256, - RSA_PSS_SHA512, - RSA_PSS_SHA384, - RSA_PSS_SHA256, - ED25519, + SignatureScheme::ECDSA_NISTP521_SHA512, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::ED25519, ] } fn verify_server_cert( diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 1412095505..d7fc6eee22 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -25,7 +25,7 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::BackendType<'static, (), ()>, + pub auth_backend: auth::Backend<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, @@ -52,6 +52,7 @@ pub struct TlsConfig { } pub struct HttpConfig { + pub accept_websockets: bool, pub pool_options: GlobalConnPoolOptions, pub cancel_set: CancelSet, pub client_conn_threshold: u64, @@ -155,7 +156,7 @@ pub enum TlsServerEndPoint { } impl TlsServerEndPoint { - pub fn new(cert: &CertificateDer) -> anyhow::Result { + pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result { let sha256_oids = [ // I'm explicitly not adding MD5 or SHA1 here... They're bad. oid_registry::OID_SIG_ECDSA_WITH_SHA256, @@ -246,7 +247,7 @@ impl CertResolver { let common_name = pem.subject().to_string(); - // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as + // We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names // and passed None instead, which blows up number of cases downstream code should handle. Proper coding @@ -278,7 +279,7 @@ impl CertResolver { impl rustls::server::ResolvesServerCert for CertResolver { fn resolve( &self, - client_hello: rustls::server::ClientHello, + client_hello: rustls::server::ClientHello<'_>, ) -> Option> { self.resolve(client_hello.server_name()).map(|x| x.0) } @@ -317,7 +318,7 @@ impl CertResolver { // a) Instead of multi-cert approach use single cert with extra // domains listed in Subject Alternative Name (SAN). // b) Deploy separate proxy instances for extra domains. - self.default.as_ref().cloned() + self.default.clone() } } } @@ -559,7 +560,7 @@ impl RetryConfig { match key { "num_retries" => num_retries = Some(value.parse()?), "base_retry_wait_duration" => { - base_retry_wait_duration = Some(humantime::parse_duration(value)?) + base_retry_wait_duration = Some(humantime::parse_duration(value)?); } "retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?), unknown => bail!("unknown key: {unknown}"), diff --git a/proxy/src/console.rs b/proxy/src/console.rs index ea95e83437..87d8e781aa 100644 --- a/proxy/src/console.rs +++ b/proxy/src/console.rs @@ -10,7 +10,7 @@ pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo}; /// Various cache-related types. pub mod caches { - pub use super::provider::{ApiCaches, NodeInfoCache}; + pub use super::provider::ApiCaches; } /// Various cache-related types. 
diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index 9abf24ab7f..a48c7316f6 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -1,37 +1,38 @@ use measured::FixedCardinalityLabel; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::fmt::{self, Display}; use crate::auth::IpPattern; use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt}; use crate::proxy::retry::CouldRetry; +use crate::RoleName; /// Generic error response with human-readable description. /// Note that we can't always present it to user as is. #[derive(Debug, Deserialize, Clone)] -pub struct ConsoleError { - pub error: Box, +pub(crate) struct ConsoleError { + pub(crate) error: Box, #[serde(skip)] - pub http_status_code: http::StatusCode, - pub status: Option, + pub(crate) http_status_code: http::StatusCode, + pub(crate) status: Option, } impl ConsoleError { - pub fn get_reason(&self) -> Reason { + pub(crate) fn get_reason(&self) -> Reason { self.status .as_ref() .and_then(|s| s.details.error_info.as_ref()) - .map(|e| e.reason) - .unwrap_or(Reason::Unknown) + .map_or(Reason::Unknown, |e| e.reason) } - pub fn get_user_facing_message(&self) -> String { + + pub(crate) fn get_user_facing_message(&self) -> String { use super::provider::errors::REQUEST_FAILED; self.status .as_ref() .and_then(|s| s.details.user_facing_message.as_ref()) - .map(|m| m.message.clone().into()) - .unwrap_or_else(|| { + .map_or_else(|| { // Ask @neondatabase/control-plane for review before adding more. match self.http_status_code { http::StatusCode::NOT_FOUND => { @@ -48,19 +49,18 @@ impl ConsoleError { } _ => REQUEST_FAILED.to_owned(), } - }) + }, |m| m.message.clone().into()) } } impl Display for ConsoleError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let msg = self + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let msg: &str = self .status .as_ref() .and_then(|s| s.details.user_facing_message.as_ref()) - .map(|m| m.message.as_ref()) - .unwrap_or_else(|| &self.error); - write!(f, "{}", msg) + .map_or_else(|| self.error.as_ref(), |m| m.message.as_ref()); + write!(f, "{msg}") } } @@ -88,27 +88,28 @@ impl CouldRetry for ConsoleError { } #[derive(Debug, Deserialize, Clone)] -pub struct Status { - pub code: Box, - pub message: Box, - pub details: Details, +#[allow(dead_code)] +pub(crate) struct Status { + pub(crate) code: Box, + pub(crate) message: Box, + pub(crate) details: Details, } #[derive(Debug, Deserialize, Clone)] -pub struct Details { - pub error_info: Option, - pub retry_info: Option, - pub user_facing_message: Option, +pub(crate) struct Details { + pub(crate) error_info: Option, + pub(crate) retry_info: Option, + pub(crate) user_facing_message: Option, } #[derive(Copy, Clone, Debug, Deserialize)] -pub struct ErrorInfo { - pub reason: Reason, +pub(crate) struct ErrorInfo { + pub(crate) reason: Reason, // Schema could also have `metadata` field, but it's not structured. Skip it for now. } #[derive(Clone, Copy, Debug, Deserialize, Default)] -pub enum Reason { +pub(crate) enum Reason { /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles. 
#[serde(rename = "ROLE_PROTECTED")] RoleProtected, @@ -168,7 +169,7 @@ pub enum Reason { } impl Reason { - pub fn is_not_found(&self) -> bool { + pub(crate) fn is_not_found(self) -> bool { matches!( self, Reason::ResourceNotFound @@ -178,7 +179,7 @@ impl Reason { ) } - pub fn can_retry(&self) -> bool { + pub(crate) fn can_retry(self) -> bool { match self { // do not retry role protected errors // not a transitive error @@ -208,22 +209,23 @@ impl Reason { } #[derive(Copy, Clone, Debug, Deserialize)] -pub struct RetryInfo { - pub retry_delay_ms: u64, +#[allow(dead_code)] +pub(crate) struct RetryInfo { + pub(crate) retry_delay_ms: u64, } #[derive(Debug, Deserialize, Clone)] -pub struct UserFacingMessage { - pub message: Box, +pub(crate) struct UserFacingMessage { + pub(crate) message: Box, } /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`]. /// Returned by the `/proxy_get_role_secret` API method. #[derive(Deserialize)] -pub struct GetRoleSecret { - pub role_secret: Box, - pub allowed_ips: Option>, - pub project_id: Option, +pub(crate) struct GetRoleSecret { + pub(crate) role_secret: Box, + pub(crate) allowed_ips: Option>, + pub(crate) project_id: Option, } // Manually implement debug to omit sensitive info. @@ -236,21 +238,21 @@ impl fmt::Debug for GetRoleSecret { /// Response which holds compute node's `host:port` pair. /// Returned by the `/proxy_wake_compute` API method. #[derive(Debug, Deserialize)] -pub struct WakeCompute { - pub address: Box, - pub aux: MetricsAuxInfo, +pub(crate) struct WakeCompute { + pub(crate) address: Box, + pub(crate) aux: MetricsAuxInfo, } -/// Async response which concludes the link auth flow. +/// Async response which concludes the web auth flow. /// Also known as `kickResponse` in the console. #[derive(Debug, Deserialize)] -pub struct KickSession<'a> { +pub(crate) struct KickSession<'a> { /// Session ID is assigned by the proxy. - pub session_id: &'a str, + pub(crate) session_id: &'a str, /// Compute node connection params. #[serde(deserialize_with = "KickSession::parse_db_info")] - pub result: DatabaseInfo, + pub(crate) result: DatabaseInfo, } impl KickSession<'_> { @@ -273,20 +275,20 @@ impl KickSession<'_> { /// Compute node connection params. #[derive(Deserialize)] -pub struct DatabaseInfo { - pub host: Box, - pub port: u16, - pub dbname: Box, - pub user: Box, +pub(crate) struct DatabaseInfo { + pub(crate) host: Box, + pub(crate) port: u16, + pub(crate) dbname: Box, + pub(crate) user: Box, /// Console always provides a password, but it might /// be inconvenient for debug with local PG instance. - pub password: Option>, - pub aux: MetricsAuxInfo, + pub(crate) password: Option>, + pub(crate) aux: MetricsAuxInfo, } // Manually implement debug to omit sensitive info. impl fmt::Debug for DatabaseInfo { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("DatabaseInfo") .field("host", &self.host) .field("port", &self.port) @@ -299,12 +301,12 @@ impl fmt::Debug for DatabaseInfo { /// Various labels for prometheus metrics. /// Also known as `ProxyMetricsAuxInfo` in the console. 
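// These aux labels ride along with console responses (`WakeCompute` and
// `DatabaseInfo` above both carry an `aux` field) and end up attached to the
// proxy's metrics.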
#[derive(Debug, Deserialize, Clone)] -pub struct MetricsAuxInfo { - pub endpoint_id: EndpointIdInt, - pub project_id: ProjectIdInt, - pub branch_id: BranchIdInt, +pub(crate) struct MetricsAuxInfo { + pub(crate) endpoint_id: EndpointIdInt, + pub(crate) project_id: ProjectIdInt, + pub(crate) branch_id: BranchIdInt, #[serde(default)] - pub cold_start_info: ColdStartInfo, + pub(crate) cold_start_info: ColdStartInfo, } #[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)] @@ -331,7 +333,7 @@ pub enum ColdStartInfo { } impl ColdStartInfo { - pub fn as_str(&self) -> &'static str { + pub(crate) fn as_str(self) -> &'static str { match self { ColdStartInfo::Unknown => "unknown", ColdStartInfo::Warm => "warm", @@ -343,6 +345,26 @@ impl ColdStartInfo { } } +#[derive(Debug, Deserialize, Clone)] +pub struct JwksRoleMapping { + pub roles: HashMap, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct EndpointJwksResponse { + pub jwks: Vec, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct JwksSettings { + pub id: String, + pub project_id: ProjectIdInt, + pub branch_id: BranchIdInt, + pub jwks_url: url::Url, + pub provider_name: String, + pub jwt_audience: Option, +} + #[cfg(test)] mod tests { use super::*; @@ -373,7 +395,7 @@ mod tests { } } }); - let _: KickSession = serde_json::from_str(&json.to_string())?; + let _: KickSession<'_> = serde_json::from_str(&json.to_string())?; Ok(()) } diff --git a/proxy/src/console/mgmt.rs b/proxy/src/console/mgmt.rs index befe7d7510..2ed4f5f206 100644 --- a/proxy/src/console/mgmt.rs +++ b/proxy/src/console/mgmt.rs @@ -14,18 +14,18 @@ use tracing::{error, info, info_span, Instrument}; static CPLANE_WAITERS: Lazy> = Lazy::new(Default::default); /// Give caller an opportunity to wait for the cloud's reply. -pub fn get_waiter( +pub(crate) fn get_waiter( psql_session_id: impl Into, ) -> Result, waiters::RegisterError> { CPLANE_WAITERS.register(psql_session_id.into()) } -pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> { +pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> { CPLANE_WAITERS.notify(psql_session_id, msg) } /// Console management API listener task. -/// It spawns console response handlers needed for the link auth. +/// It spawns console response handlers needed for the web auth. pub async fn task_main(listener: TcpListener) -> anyhow::Result { scopeguard::defer! { info!("mgmt has shut down"); @@ -74,7 +74,7 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { } /// A message received by `mgmt` when a compute node is ready. -pub type ComputeReady = DatabaseInfo; +pub(crate) type ComputeReady = DatabaseInfo; // TODO: replace with an http-based protocol. 
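// How the callback works today: the console connects to the mgmt listener
// speaking the postgres wire protocol and issues a query whose text is the
// `KickSession` JSON; `try_process_query` below parses it and notifies the
// waiter that the web auth flow registered under the same `psql_session_id`,
// handing over the `DatabaseInfo`.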
struct MgmtHandler; @@ -93,7 +93,8 @@ impl postgres_backend::Handler for MgmtHandler { } fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> { - let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?; + let resp: KickSession<'_> = + serde_json::from_str(query).context("Failed to parse query as json")?; let span = info_span!("event", session_id = resp.session_id); let _enter = span.enter(); diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 15fc0134b3..12a6e2f12a 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -23,10 +23,10 @@ use std::{hash::Hash, sync::Arc, time::Duration}; use tokio::time::Instant; use tracing::info; -pub mod errors { +pub(crate) mod errors { use crate::{ console::messages::{self, ConsoleError, Reason}, - error::{io_error, ReportableError, UserFacingError}, + error::{io_error, ErrorKind, ReportableError, UserFacingError}, proxy::retry::CouldRetry, }; use thiserror::Error; @@ -34,11 +34,11 @@ pub mod errors { use super::ApiLockError; /// A go-to error message which doesn't leak any detail. - pub const REQUEST_FAILED: &str = "Console request failed"; + pub(crate) const REQUEST_FAILED: &str = "Console request failed"; /// Common console API error. #[derive(Debug, Error)] - pub enum ApiError { + pub(crate) enum ApiError { /// Error returned by the console itself. #[error("{REQUEST_FAILED} with {0}")] Console(ConsoleError), @@ -50,22 +50,20 @@ pub mod errors { impl ApiError { /// Returns HTTP status code if it's the reason for failure. - pub fn get_reason(&self) -> messages::Reason { - use ApiError::*; + pub(crate) fn get_reason(&self) -> messages::Reason { match self { - Console(e) => e.get_reason(), - _ => messages::Reason::Unknown, + ApiError::Console(e) => e.get_reason(), + ApiError::Transport(_) => messages::Reason::Unknown, } } } impl UserFacingError for ApiError { fn to_string_client(&self) -> String { - use ApiError::*; match self { // To minimize risks, only select errors are forwarded to users. - Console(c) => c.get_user_facing_message(), - _ => REQUEST_FAILED.to_owned(), + ApiError::Console(c) => c.get_user_facing_message(), + ApiError::Transport(_) => REQUEST_FAILED.to_owned(), } } } @@ -73,57 +71,53 @@ pub mod errors { impl ReportableError for ApiError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - ApiError::Console(e) => { - use crate::error::ErrorKind::*; - match e.get_reason() { - Reason::RoleProtected => User, - Reason::ResourceNotFound => User, - Reason::ProjectNotFound => User, - Reason::EndpointNotFound => User, - Reason::BranchNotFound => User, - Reason::RateLimitExceeded => ServiceRateLimit, - Reason::NonDefaultBranchComputeTimeExceeded => User, - Reason::ActiveTimeQuotaExceeded => User, - Reason::ComputeTimeQuotaExceeded => User, - Reason::WrittenDataQuotaExceeded => User, - Reason::DataTransferQuotaExceeded => User, - Reason::LogicalSizeQuotaExceeded => User, - Reason::ConcurrencyLimitReached => ControlPlane, - Reason::LockAlreadyTaken => ControlPlane, - Reason::RunningOperations => ControlPlane, - Reason::Unknown => match &e { - ConsoleError { - http_status_code: - http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE, - .. - } => crate::error::ErrorKind::User, - ConsoleError { - http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY, - error, - .. 
- } if error.contains( - "compute time quota of non-primary branches is exceeded", - ) => - { - crate::error::ErrorKind::User - } - ConsoleError { - http_status_code: http::StatusCode::LOCKED, - error, - .. - } if error.contains("quota exceeded") - || error.contains("the limit for current plan reached") => - { - crate::error::ErrorKind::User - } - ConsoleError { - http_status_code: http::StatusCode::TOO_MANY_REQUESTS, - .. - } => crate::error::ErrorKind::ServiceRateLimit, - ConsoleError { .. } => crate::error::ErrorKind::ControlPlane, - }, - } - } + ApiError::Console(e) => match e.get_reason() { + Reason::RoleProtected => ErrorKind::User, + Reason::ResourceNotFound => ErrorKind::User, + Reason::ProjectNotFound => ErrorKind::User, + Reason::EndpointNotFound => ErrorKind::User, + Reason::BranchNotFound => ErrorKind::User, + Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit, + Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User, + Reason::ActiveTimeQuotaExceeded => ErrorKind::User, + Reason::ComputeTimeQuotaExceeded => ErrorKind::User, + Reason::WrittenDataQuotaExceeded => ErrorKind::User, + Reason::DataTransferQuotaExceeded => ErrorKind::User, + Reason::LogicalSizeQuotaExceeded => ErrorKind::User, + Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane, + Reason::LockAlreadyTaken => ErrorKind::ControlPlane, + Reason::RunningOperations => ErrorKind::ControlPlane, + Reason::Unknown => match &e { + ConsoleError { + http_status_code: + http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE, + .. + } => crate::error::ErrorKind::User, + ConsoleError { + http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY, + error, + .. + } if error + .contains("compute time quota of non-primary branches is exceeded") => + { + crate::error::ErrorKind::User + } + ConsoleError { + http_status_code: http::StatusCode::LOCKED, + error, + .. + } if error.contains("quota exceeded") + || error.contains("the limit for current plan reached") => + { + crate::error::ErrorKind::User + } + ConsoleError { + http_status_code: http::StatusCode::TOO_MANY_REQUESTS, + .. + } => crate::error::ErrorKind::ServiceRateLimit, + ConsoleError { .. } => crate::error::ErrorKind::ControlPlane, + }, + }, ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane, } } @@ -152,7 +146,7 @@ pub mod errors { } #[derive(Debug, Error)] - pub enum GetAuthInfoError { + pub(crate) enum GetAuthInfoError { // We shouldn't include the actual secret here. #[error("Console responded with a malformed auth secret")] BadSecret, @@ -170,12 +164,11 @@ pub mod errors { impl UserFacingError for GetAuthInfoError { fn to_string_client(&self) -> String { - use GetAuthInfoError::*; match self { // We absolutely should not leak any secrets! - BadSecret => REQUEST_FAILED.to_owned(), + Self::BadSecret => REQUEST_FAILED.to_owned(), // However, API might return a meaningful error. 
- ApiError(e) => e.to_string_client(), + Self::ApiError(e) => e.to_string_client(), } } } @@ -183,14 +176,14 @@ pub mod errors { impl ReportableError for GetAuthInfoError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane, - GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane, + Self::BadSecret => crate::error::ErrorKind::ControlPlane, + Self::ApiError(_) => crate::error::ErrorKind::ControlPlane, } } } #[derive(Debug, Error)] - pub enum WakeComputeError { + pub(crate) enum WakeComputeError { #[error("Console responded with a malformed compute address: {0}")] BadComputeAddress(Box), @@ -213,17 +206,16 @@ pub mod errors { impl UserFacingError for WakeComputeError { fn to_string_client(&self) -> String { - use WakeComputeError::*; match self { // We shouldn't show user the address even if it's broken. // Besides, user is unlikely to care about this detail. - BadComputeAddress(_) => REQUEST_FAILED.to_owned(), + Self::BadComputeAddress(_) => REQUEST_FAILED.to_owned(), // However, API might return a meaningful error. - ApiError(e) => e.to_string_client(), + Self::ApiError(e) => e.to_string_client(), - TooManyConnections => self.to_string(), + Self::TooManyConnections => self.to_string(), - TooManyConnectionAttempts(_) => { + Self::TooManyConnectionAttempts(_) => { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() } } @@ -233,10 +225,10 @@ pub mod errors { impl ReportableError for WakeComputeError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane, - WakeComputeError::ApiError(e) => e.get_error_kind(), - WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit, - WakeComputeError::TooManyConnectionAttempts(e) => e.get_error_kind(), + Self::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane, + Self::ApiError(e) => e.get_error_kind(), + Self::TooManyConnections => crate::error::ErrorKind::RateLimit, + Self::TooManyConnectionAttempts(e) => e.get_error_kind(), } } } @@ -244,10 +236,10 @@ pub mod errors { impl CouldRetry for WakeComputeError { fn could_retry(&self) -> bool { match self { - WakeComputeError::BadComputeAddress(_) => false, - WakeComputeError::ApiError(e) => e.could_retry(), - WakeComputeError::TooManyConnections => false, - WakeComputeError::TooManyConnectionAttempts(_) => false, + Self::BadComputeAddress(_) => false, + Self::ApiError(e) => e.could_retry(), + Self::TooManyConnections => false, + Self::TooManyConnectionAttempts(_) => false, } } } @@ -255,7 +247,7 @@ pub mod errors { /// Auth secret which is managed by the cloud. #[derive(Clone, Eq, PartialEq, Debug)] -pub enum AuthSecret { +pub(crate) enum AuthSecret { #[cfg(any(test, feature = "testing"))] /// Md5 hash of user's password. Md5([u8; 16]), @@ -265,32 +257,32 @@ pub enum AuthSecret { } #[derive(Default)] -pub struct AuthInfo { - pub secret: Option, +pub(crate) struct AuthInfo { + pub(crate) secret: Option, /// List of IP addresses allowed for the autorization. - pub allowed_ips: Vec, + pub(crate) allowed_ips: Vec, /// Project ID. This is used for cache invalidation. - pub project_id: Option, + pub(crate) project_id: Option, } /// Info for establishing a connection to a compute node. /// This is what we get after auth succeeded, but not before! 
#[derive(Clone)] -pub struct NodeInfo { +pub(crate) struct NodeInfo { /// Compute node connection params. /// It's sad that we have to clone this, but this will improve /// once we migrate to a bespoke connection logic. - pub config: compute::ConnCfg, + pub(crate) config: compute::ConnCfg, /// Labels for proxy's metrics. - pub aux: MetricsAuxInfo, + pub(crate) aux: MetricsAuxInfo, /// Whether we should accept self-signed certificates (for testing) - pub allow_self_signed_compute: bool, + pub(crate) allow_self_signed_compute: bool, } impl NodeInfo { - pub async fn connect( + pub(crate) async fn connect( &self, ctx: &RequestMonitoring, timeout: Duration, @@ -304,23 +296,24 @@ impl NodeInfo { ) .await } - pub fn reuse_settings(&mut self, other: Self) { + pub(crate) fn reuse_settings(&mut self, other: Self) { self.allow_self_signed_compute = other.allow_self_signed_compute; self.config.reuse_password(other.config); } - pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) { + pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) { match keys { ComputeCredentialKeys::Password(password) => self.config.password(password), ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), + ComputeCredentialKeys::None => &mut self.config, }; } } -pub type NodeInfoCache = TimedLru>>; -pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; -pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; -pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; +pub(crate) type NodeInfoCache = TimedLru>>; +pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; +pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; +pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. 
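
// Illustrative sketch with stand-in types: why the new
// `ComputeCredentialKeys::None` arm in `set_keys` above returns `&mut self.config`.
// The existing arms call builder-style setters that return `&mut ConnCfg`, so
// every match arm has to produce that same type; the value is then discarded
// with the trailing `;`.
struct ConnCfgSketch {
    password: Option<Vec<u8>>,
}

impl ConnCfgSketch {
    fn password(&mut self, p: &[u8]) -> &mut Self {
        self.password = Some(p.to_vec());
        self
    }
}

enum CredentialKeysSketch {
    Password(Vec<u8>),
    None,
}

fn set_keys_sketch(cfg: &mut ConnCfgSketch, keys: &CredentialKeysSketch) {
    match keys {
        CredentialKeysSketch::Password(p) => cfg.password(p),
        // Nothing to apply: hand back the config untouched so this arm
        // type-checks against the builder-returning arm above.
        CredentialKeysSketch::None => cfg,
    };
}
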
@@ -357,6 +350,7 @@ pub enum ConsoleBackend { Postgres(mock::Api), /// Internal testing #[cfg(test)] + #[allow(private_interfaces)] Test(Box), } @@ -366,13 +360,14 @@ impl Api for ConsoleBackend { ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { - use ConsoleBackend::*; match self { - Console(api) => api.get_role_secret(ctx, user_info).await, + Self::Console(api) => api.get_role_secret(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] - Postgres(api) => api.get_role_secret(ctx, user_info).await, + Self::Postgres(api) => api.get_role_secret(ctx, user_info).await, #[cfg(test)] - Test(_) => unreachable!("this function should never be called in the test backend"), + Self::Test(_) => { + unreachable!("this function should never be called in the test backend") + } } } @@ -381,13 +376,12 @@ impl Api for ConsoleBackend { ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), errors::GetAuthInfoError> { - use ConsoleBackend::*; match self { - Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, + Self::Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] - Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, + Self::Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, #[cfg(test)] - Test(api) => api.get_allowed_ips_and_secret(), + Self::Test(api) => api.get_allowed_ips_and_secret(), } } @@ -396,14 +390,12 @@ impl Api for ConsoleBackend { ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { - use ConsoleBackend::*; - match self { - Console(api) => api.wake_compute(ctx, user_info).await, + Self::Console(api) => api.wake_compute(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] - Postgres(api) => api.wake_compute(ctx, user_info).await, + Self::Postgres(api) => api.wake_compute(ctx, user_info).await, #[cfg(test)] - Test(api) => api.wake_compute(), + Self::Test(api) => api.wake_compute(), } } } @@ -411,7 +403,7 @@ impl Api for ConsoleBackend { /// Various caches for [`console`](super). pub struct ApiCaches { /// Cache for the `wake_compute` API method. - pub node_info: NodeInfoCache, + pub(crate) node_info: NodeInfoCache, /// Cache which stores project_id -> endpoint_ids mapping. pub project_info: Arc, /// List of all valid endpoints. 
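
// Illustrative sketch with stand-in types: the dispatch pattern used by the
// backend enum above. Each trait method forwards to whichever variant is
// active, and variant paths are spelled out as `Self::...` instead of being
// glob-imported, matching the style applied throughout this diff.
trait ApiSketch {
    fn url(&self) -> &str;
}

enum BackendSketch {
    Console(String),
    #[cfg(test)]
    Test(String),
}

impl ApiSketch for BackendSketch {
    fn url(&self) -> &str {
        match self {
            Self::Console(url) => url,
            #[cfg(test)]
            Self::Test(url) => url,
        }
    }
}
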
@@ -448,7 +440,7 @@ pub struct ApiLocks { } #[derive(Debug, thiserror::Error)] -pub enum ApiLockError { +pub(crate) enum ApiLockError { #[error("timeout acquiring resource permit")] TimeoutError(#[from] tokio::time::error::Elapsed), } @@ -480,7 +472,7 @@ impl ApiLocks { }) } - pub async fn get_permit(&self, key: &K) -> Result { + pub(crate) async fn get_permit(&self, key: &K) -> Result { if self.config.initial_limit == 0 { return Ok(WakeComputePermit { permit: Token::disabled(), @@ -540,18 +532,18 @@ impl ApiLocks { } } -pub struct WakeComputePermit { +pub(crate) struct WakeComputePermit { permit: Token, } impl WakeComputePermit { - pub fn should_check_cache(&self) -> bool { + pub(crate) fn should_check_cache(&self) -> bool { !self.permit.is_disabled() } - pub fn release(self, outcome: Outcome) { - self.permit.release(outcome) + pub(crate) fn release(self, outcome: Outcome) { + self.permit.release(outcome); } - pub fn release_result(self, res: Result) -> Result { + pub(crate) fn release_result(self, res: Result) -> Result { match res { Ok(_) => self.release(Outcome::Success), Err(_) => self.release(Outcome::Overload), diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 2093da7562..08b87cd87a 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -48,7 +48,7 @@ impl Api { Self { endpoint } } - pub fn url(&self) -> &str { + pub(crate) fn url(&self) -> &str { self.endpoint.as_str() } @@ -64,7 +64,7 @@ impl Api { tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; tokio::spawn(connection); - let secret = match get_execute_postgres_query( + let secret = if let Some(entry) = get_execute_postgres_query( &client, "select rolpassword from pg_catalog.pg_authid where rolname = $1", &[&&*user_info.user], @@ -72,15 +72,12 @@ impl Api { ) .await? { - Some(entry) => { - info!("got a secret: {entry}"); // safe since it's not a prod scenario - let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); - secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) - } - None => { - warn!("user '{}' does not exist", user_info.user); - None - } + info!("got a secret: {entry}"); // safe since it's not a prod scenario + let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram); + secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5)) + } else { + warn!("user '{}' does not exist", user_info.user); + None }; let allowed_ips = match get_execute_postgres_query( &client, @@ -142,12 +139,11 @@ async fn get_execute_postgres_query( let rows = client.query(query, params).await?; // We can get at most one row, because `rolname` is unique. - let row = match rows.first() { - Some(row) => row, + let Some(row) = rows.first() else { // This means that the user doesn't exist, so there can be no secret. // However, this is still a *valid* outcome which is very similar // to getting `404 Not found` from the Neon console. 
- None => return Ok(None), + return Ok(None); }; let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?; diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 7eda238b66..33eda72e65 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -25,8 +25,8 @@ use tracing::{debug, error, info, info_span, warn, Instrument}; pub struct Api { endpoint: http::Endpoint, pub caches: &'static ApiCaches, - pub locks: &'static ApiLocks, - pub wake_compute_endpoint_rate_limiter: Arc, + pub(crate) locks: &'static ApiLocks, + pub(crate) wake_compute_endpoint_rate_limiter: Arc, jwt: String, } @@ -38,9 +38,9 @@ impl Api { locks: &'static ApiLocks, wake_compute_endpoint_rate_limiter: Arc, ) -> Self { - let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") { + let jwt = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") { Ok(v) => v, - Err(_) => "".to_string(), + Err(_) => String::new(), }; Self { endpoint, @@ -51,7 +51,7 @@ impl Api { } } - pub fn url(&self) -> &str { + pub(crate) fn url(&self) -> &str { self.endpoint.url().as_str() } @@ -96,10 +96,10 @@ impl Api { // Error 404 is special: it's ok not to have a secret. // TODO(anna): retry Err(e) => { - if e.get_reason().is_not_found() { - return Ok(AuthInfo::default()); + return if e.get_reason().is_not_found() { + Ok(AuthInfo::default()) } else { - return Err(e.into()); + Err(e.into()) } } }; diff --git a/proxy/src/context.rs b/proxy/src/context.rs index e925f67233..72e1fa1cee 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -22,8 +22,9 @@ use self::parquet::RequestData; pub mod parquet; -pub static LOG_CHAN: OnceCell> = OnceCell::new(); -pub static LOG_CHAN_DISCONNECT: OnceCell> = OnceCell::new(); +pub(crate) static LOG_CHAN: OnceCell> = OnceCell::new(); +pub(crate) static LOG_CHAN_DISCONNECT: OnceCell> = + OnceCell::new(); /// Context data for a single request to connect to a database. /// @@ -38,12 +39,12 @@ pub struct RequestMonitoring( ); struct RequestMonitoringInner { - pub peer_addr: IpAddr, - pub session_id: Uuid, - pub protocol: Protocol, + pub(crate) peer_addr: IpAddr, + pub(crate) session_id: Uuid, + pub(crate) protocol: Protocol, first_packet: chrono::DateTime, region: &'static str, - pub span: Span, + pub(crate) span: Span, // filled in as they are discovered project: Option, @@ -63,15 +64,15 @@ struct RequestMonitoringInner { sender: Option>, // This sender is only used to log the length of session in case of success. disconnect_sender: Option>, - pub latency_timer: LatencyTimer, + pub(crate) latency_timer: LatencyTimer, // Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane. 
rejected: Option, disconnect_timestamp: Option>, } #[derive(Clone, Debug)] -pub enum AuthMethod { - // aka link aka passwordless +pub(crate) enum AuthMethod { + // aka passwordless, fka link Web, ScramSha256, ScramSha256Plus, @@ -125,11 +126,11 @@ impl RequestMonitoring { } #[cfg(test)] - pub fn test() -> Self { + pub(crate) fn test() -> Self { RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test") } - pub fn console_application_name(&self) -> String { + pub(crate) fn console_application_name(&self) -> String { let this = self.0.try_lock().expect("should not deadlock"); format!( "{}/{}", @@ -138,19 +139,19 @@ impl RequestMonitoring { ) } - pub fn set_rejected(&self, rejected: bool) { + pub(crate) fn set_rejected(&self, rejected: bool) { let mut this = self.0.try_lock().expect("should not deadlock"); this.rejected = Some(rejected); } - pub fn set_cold_start_info(&self, info: ColdStartInfo) { + pub(crate) fn set_cold_start_info(&self, info: ColdStartInfo) { self.0 .try_lock() .expect("should not deadlock") .set_cold_start_info(info); } - pub fn set_db_options(&self, options: StartupMessageParams) { + pub(crate) fn set_db_options(&self, options: StartupMessageParams) { let mut this = self.0.try_lock().expect("should not deadlock"); this.set_application(options.get("application_name").map(SmolStr::from)); if let Some(user) = options.get("user") { @@ -163,43 +164,43 @@ impl RequestMonitoring { this.pg_options = Some(options); } - pub fn set_project(&self, x: MetricsAuxInfo) { + pub(crate) fn set_project(&self, x: MetricsAuxInfo) { let mut this = self.0.try_lock().expect("should not deadlock"); if this.endpoint_id.is_none() { - this.set_endpoint_id(x.endpoint_id.as_str().into()) + this.set_endpoint_id(x.endpoint_id.as_str().into()); } this.branch = Some(x.branch_id); this.project = Some(x.project_id); this.set_cold_start_info(x.cold_start_info); } - pub fn set_project_id(&self, project_id: ProjectIdInt) { + pub(crate) fn set_project_id(&self, project_id: ProjectIdInt) { let mut this = self.0.try_lock().expect("should not deadlock"); this.project = Some(project_id); } - pub fn set_endpoint_id(&self, endpoint_id: EndpointId) { + pub(crate) fn set_endpoint_id(&self, endpoint_id: EndpointId) { self.0 .try_lock() .expect("should not deadlock") .set_endpoint_id(endpoint_id); } - pub fn set_dbname(&self, dbname: DbName) { + pub(crate) fn set_dbname(&self, dbname: DbName) { self.0 .try_lock() .expect("should not deadlock") .set_dbname(dbname); } - pub fn set_user(&self, user: RoleName) { + pub(crate) fn set_user(&self, user: RoleName) { self.0 .try_lock() .expect("should not deadlock") .set_user(user); } - pub fn set_auth_method(&self, auth_method: AuthMethod) { + pub(crate) fn set_auth_method(&self, auth_method: AuthMethod) { let mut this = self.0.try_lock().expect("should not deadlock"); this.auth_method = Some(auth_method); } @@ -211,7 +212,7 @@ impl RequestMonitoring { .has_private_peer_addr() } - pub fn set_error_kind(&self, kind: ErrorKind) { + pub(crate) fn set_error_kind(&self, kind: ErrorKind) { let mut this = self.0.try_lock().expect("should not deadlock"); // Do not record errors from the private address to metrics. 
if !this.has_private_peer_addr() { @@ -237,30 +238,30 @@ impl RequestMonitoring { .log_connect(); } - pub fn protocol(&self) -> Protocol { + pub(crate) fn protocol(&self) -> Protocol { self.0.try_lock().expect("should not deadlock").protocol } - pub fn span(&self) -> Span { + pub(crate) fn span(&self) -> Span { self.0.try_lock().expect("should not deadlock").span.clone() } - pub fn session_id(&self) -> Uuid { + pub(crate) fn session_id(&self) -> Uuid { self.0.try_lock().expect("should not deadlock").session_id } - pub fn peer_addr(&self) -> IpAddr { + pub(crate) fn peer_addr(&self) -> IpAddr { self.0.try_lock().expect("should not deadlock").peer_addr } - pub fn cold_start_info(&self) -> ColdStartInfo { + pub(crate) fn cold_start_info(&self) -> ColdStartInfo { self.0 .try_lock() .expect("should not deadlock") .cold_start_info } - pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause { + pub(crate) fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> { LatencyTimerPause { ctx: self, start: tokio::time::Instant::now(), @@ -268,16 +269,16 @@ impl RequestMonitoring { } } - pub fn success(&self) { + pub(crate) fn success(&self) { self.0 .try_lock() .expect("should not deadlock") .latency_timer - .success() + .success(); } } -pub struct LatencyTimerPause<'a> { +pub(crate) struct LatencyTimerPause<'a> { ctx: &'a RequestMonitoring, start: tokio::time::Instant, waiting_for: Waiting, @@ -328,7 +329,7 @@ impl RequestMonitoringInner { fn has_private_peer_addr(&self) -> bool { match self.peer_addr { IpAddr::V4(ip) => ip.is_private(), - _ => false, + IpAddr::V6(_) => false, } } diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index bb02a476fc..88caa9a316 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -62,8 +62,8 @@ pub struct ParquetUploadArgs { // But after FAILED_UPLOAD_WARN_THRESHOLD retries, we start to log it at WARN // level instead, as repeated failures can mean a more serious problem. If it // fails more than FAILED_UPLOAD_RETRIES times, we give up -pub const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3; -pub const FAILED_UPLOAD_MAX_RETRIES: u32 = 10; +pub(crate) const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3; +pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10; // the parquet crate leaves a lot to be desired... // what follows is an attempt to write parquet files with minimal allocs. @@ -73,7 +73,7 @@ pub const FAILED_UPLOAD_MAX_RETRIES: u32 = 10; // * after each rowgroup write, we check the length of the file and upload to s3 if large enough #[derive(parquet_derive::ParquetRecordWriter)] -pub struct RequestData { +pub(crate) struct RequestData { region: &'static str, protocol: &'static str, /// Must be UTC. The derive macro doesn't like the timezones @@ -736,7 +736,7 @@ mod tests { while let Some(r) = s.next().await { tx.send(r).unwrap(); } - time::sleep(time::Duration::from_secs(70)).await + time::sleep(time::Duration::from_secs(70)).await; } }); diff --git a/proxy/src/error.rs b/proxy/src/error.rs index fdfe50a494..53f9f75c5b 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -3,12 +3,12 @@ use std::{error::Error as StdError, fmt, io}; use measured::FixedCardinalityLabel; /// Upcast (almost) any error into an opaque [`io::Error`]. -pub fn io_error(e: impl Into>) -> io::Error { +pub(crate) fn io_error(e: impl Into>) -> io::Error { io::Error::new(io::ErrorKind::Other, e) } /// A small combinator for pluggable error logging. 
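
// Illustrative sketch (std `Mutex` as a stand-in for the crate's own lock):
// the accessor pattern `RequestMonitoring` uses above. The shared inner state
// sits behind a lock, and every getter takes it with `try_lock()` plus a
// "should not deadlock" expectation, since accesses are short and expected to
// be uncontended.
use std::net::IpAddr;
use std::sync::Mutex;

struct RequestContextInnerSketch {
    peer_addr: IpAddr,
}

struct RequestContextSketch(Mutex<RequestContextInnerSketch>);

impl RequestContextSketch {
    fn peer_addr(&self) -> IpAddr {
        self.0.try_lock().expect("should not deadlock").peer_addr
    }

    fn has_private_peer_addr(&self) -> bool {
        // Exhaustive match rather than a `_` arm, as in the diff: only v4
        // addresses are classified, v6 is treated as non-private.
        match self.peer_addr() {
            IpAddr::V4(ip) => ip.is_private(),
            IpAddr::V6(_) => false,
        }
    }
}
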
-pub fn log_error(e: E) -> E { +pub(crate) fn log_error(e: E) -> E { tracing::error!("{e}"); e } @@ -19,7 +19,7 @@ pub fn log_error(e: E) -> E { /// NOTE: This trait should not be implemented for [`anyhow::Error`], since it /// is way too convenient and tends to proliferate all across the codebase, /// ultimately leading to accidental leaks of sensitive data. -pub trait UserFacingError: ReportableError { +pub(crate) trait UserFacingError: ReportableError { /// Format the error for client, stripping all sensitive info. /// /// Although this might be a no-op for many types, it's highly @@ -64,7 +64,7 @@ pub enum ErrorKind { } impl ErrorKind { - pub fn to_metric_label(&self) -> &'static str { + pub(crate) fn to_metric_label(self) -> &'static str { match self { ErrorKind::User => "user", ErrorKind::ClientDisconnect => "clientdisconnect", @@ -78,7 +78,7 @@ impl ErrorKind { } } -pub trait ReportableError: fmt::Display + Send + 'static { +pub(crate) trait ReportableError: fmt::Display + Send + 'static { fn get_error_kind(&self) -> ErrorKind; } diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 1f1dd8c415..fee634f67f 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -12,9 +12,9 @@ use http_body_util::BodyExt; use hyper1::body::Body; use serde::de::DeserializeOwned; -pub use reqwest::{Request, Response, StatusCode}; -pub use reqwest_middleware::{ClientWithMiddleware, Error}; -pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +pub(crate) use reqwest::{Request, Response}; +pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error}; +pub(crate) use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use crate::{ metrics::{ConsoleRequest, Metrics}, @@ -35,7 +35,7 @@ pub fn new_client() -> ClientWithMiddleware { .build() } -pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware { +pub(crate) fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware { let timeout_client = reqwest::ClientBuilder::new() .timeout(default_timout) .build() @@ -77,20 +77,20 @@ impl Endpoint { } #[inline(always)] - pub fn url(&self) -> &ApiUrl { + pub(crate) fn url(&self) -> &ApiUrl { &self.endpoint } /// Return a [builder](RequestBuilder) for a `GET` request, /// appending a single `path` segment to the base endpoint URL. - pub fn get(&self, path: &str) -> RequestBuilder { + pub(crate) fn get(&self, path: &str) -> RequestBuilder { let mut url = self.endpoint.clone(); url.path_segments_mut().push(path); self.client.get(url.into_inner()) } /// Execute a [request](reqwest::Request). 
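
// Illustrative sketch (assumes the `reqwest` and `url` crates): the shape of
// `Endpoint::get` above: append a single path segment to a fixed base URL and
// hand the result to the HTTP client's GET builder.
fn get_request_sketch(
    client: &reqwest::Client,
    base: &url::Url,
    path: &str,
) -> reqwest::RequestBuilder {
    let mut url = base.clone();
    url.path_segments_mut()
        .expect("base URL must not be a cannot-be-a-base URL")
        .push(path);
    client.get(url.as_str())
}
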
- pub async fn execute(&self, request: Request) -> Result { + pub(crate) async fn execute(&self, request: Request) -> Result { let _timer = Metrics::get() .proxy .console_request_latency @@ -102,7 +102,7 @@ impl Endpoint { } } -pub async fn parse_json_body_with_limit( +pub(crate) async fn parse_json_body_with_limit( mut b: impl Body + Unpin, limit: usize, ) -> anyhow::Result { diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index e38135dd22..e5144cfe2e 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -29,10 +29,10 @@ impl std::fmt::Display for InternedString { } impl InternedString { - pub fn as_str(&self) -> &'static str { + pub(crate) fn as_str(&self) -> &'static str { Id::get_interner().inner.resolve(&self.inner) } - pub fn get(s: &str) -> Option { + pub(crate) fn get(s: &str) -> Option { Id::get_interner().get(s) } } @@ -56,7 +56,7 @@ impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString { impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor { type Value = InternedString; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { formatter.write_str("a string") } @@ -78,7 +78,7 @@ impl serde::Serialize for InternedString { } impl StringInterner { - pub fn new() -> Self { + pub(crate) fn new() -> Self { StringInterner { inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher( Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()), @@ -90,26 +90,24 @@ impl StringInterner { } } - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - pub fn len(&self) -> usize { + #[cfg(test)] + fn len(&self) -> usize { self.inner.len() } - pub fn current_memory_usage(&self) -> usize { + #[cfg(test)] + fn current_memory_usage(&self) -> usize { self.inner.current_memory_usage() } - pub fn get_or_intern(&self, s: &str) -> InternedString { + pub(crate) fn get_or_intern(&self, s: &str) -> InternedString { InternedString { inner: self.inner.get_or_intern(s), _id: PhantomData, } } - pub fn get(&self, s: &str) -> Option> { + pub(crate) fn get(&self, s: &str) -> Option> { Some(InternedString { inner: self.inner.get(s)?, _id: PhantomData, @@ -132,14 +130,14 @@ impl Default for StringInterner { } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -pub struct RoleNameTag; +pub(crate) struct RoleNameTag; impl InternId for RoleNameTag { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } -pub type RoleNameInt = InternedString; +pub(crate) type RoleNameInt = InternedString; impl From<&RoleName> for RoleNameInt { fn from(value: &RoleName) -> Self { RoleNameTag::get_interner().get_or_intern(value) @@ -150,7 +148,7 @@ impl From<&RoleName> for RoleNameInt { pub struct EndpointIdTag; impl InternId for EndpointIdTag { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } @@ -170,7 +168,7 @@ impl From for EndpointIdInt { pub struct BranchIdTag; impl InternId for BranchIdTag { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } @@ -190,7 +188,7 @@ impl From for BranchIdInt { pub struct ProjectIdTag; impl InternId for ProjectIdTag { fn 
get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } @@ -217,7 +215,7 @@ mod tests { struct MyId; impl InternId for MyId { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + pub(crate) static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index ea92eaaa55..8d7e586b3d 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -1,6 +1,83 @@ -#![deny(clippy::undocumented_unsafe_blocks)] +// rustc lints/lint groups +// https://doc.rust-lang.org/rustc/lints/groups.html +#![deny( + deprecated, + future_incompatible, + // TODO: consider let_underscore + nonstandard_style, + rust_2024_compatibility +)] +#![warn(clippy::all, clippy::pedantic, clippy::cargo)] +// List of denied lints from the clippy::restriction group. +// https://rust-lang.github.io/rust-clippy/master/index.html#?groups=restriction +#![warn( + clippy::undocumented_unsafe_blocks, + // TODO: Enable once all individual checks are enabled. + //clippy::as_conversions, + clippy::dbg_macro, + clippy::empty_enum_variants_with_brackets, + clippy::exit, + clippy::float_cmp_const, + clippy::lossy_float_literal, + clippy::macro_use_imports, + clippy::manual_ok_or, + // TODO: consider clippy::map_err_ignore + // TODO: consider clippy::mem_forget + clippy::rc_mutex, + clippy::rest_pat_in_fully_bound_structs, + clippy::string_add, + clippy::string_to_string, + clippy::todo, + // TODO: consider clippy::unimplemented + // TODO: consider clippy::unwrap_used +)] +// List of permanently allowed lints. +#![allow( + // It's ok to cast bool to u8, etc. + clippy::cast_lossless, + // Seems unavoidable. + clippy::multiple_crate_versions, + // While #[must_use] is a great feature this check is too noisy. + clippy::must_use_candidate, + // Inline consts, structs, fns, imports, etc. are ok if they're used by + // the following statement(s). + clippy::items_after_statements, +)] +// List of temporarily allowed lints. +// TODO: Switch to except() once stable with 1.81. +// TODO: fix code and reduce list or move to permanent list above. +#![allow( + clippy::cargo_common_metadata, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_precision_loss, + clippy::cast_sign_loss, + clippy::doc_markdown, + clippy::implicit_hasher, + clippy::inline_always, + clippy::match_same_arms, + clippy::match_wild_err_arm, + clippy::missing_errors_doc, + clippy::missing_panics_doc, + clippy::module_name_repetitions, + clippy::needless_pass_by_value, + clippy::needless_raw_string_hashes, + clippy::redundant_closure_for_method_calls, + clippy::return_self_not_must_use, + clippy::similar_names, + clippy::single_match_else, + clippy::struct_excessive_bools, + clippy::struct_field_names, + clippy::too_many_lines, + clippy::unreadable_literal, + clippy::unused_async, + clippy::unused_self, + clippy::wildcard_imports +)] +// List of temporarily allowed lints to unblock beta/nightly. +#![allow(unknown_lints, clippy::manual_inspect)] -use std::convert::Infallible; +use std::{convert::Infallible, future::Future}; use anyhow::{bail, Context}; use intern::{EndpointIdInt, EndpointIdTag, InternId}; @@ -35,7 +112,14 @@ pub mod usage_metrics; pub mod waiters; /// Handle unix signals appropriately. 
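
// Illustrative sketch (tokio + tracing + anyhow, simplified return type): the
// shape of the reworked signal handler below, which now takes an async
// callback and invokes it whenever SIGHUP arrives instead of merely logging
// that reload is unsupported.
use std::future::Future;

async fn handle_sighup_sketch<F, Fut>(mut refresh_config: F) -> anyhow::Result<()>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = anyhow::Result<()>>,
{
    use tokio::signal::unix::{signal, SignalKind};

    let mut hangup = signal(SignalKind::hangup())?;
    loop {
        hangup.recv().await;
        tracing::warn!("received SIGHUP");
        // Whatever the caller wants to reload on SIGHUP (TLS certs,
        // allow-lists, ...) happens inside this callback.
        refresh_config().await?;
    }
}
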
-pub async fn handle_signals(token: CancellationToken) -> anyhow::Result { +pub async fn handle_signals( + token: CancellationToken, + mut refresh_config: F, +) -> anyhow::Result +where + F: FnMut() -> Fut, + Fut: Future>, +{ use tokio::signal::unix::{signal, SignalKind}; let mut hangup = signal(SignalKind::hangup())?; @@ -46,7 +130,8 @@ pub async fn handle_signals(token: CancellationToken) -> anyhow::Result { - warn!("received SIGHUP; config reload is not supported"); + warn!("received SIGHUP"); + refresh_config().await?; } // Shut down the whole application. _ = interrupt.recv() => { @@ -72,7 +157,8 @@ macro_rules! smol_str_wrapper { pub struct $name(smol_str::SmolStr); impl $name { - pub fn as_str(&self) -> &str { + #[allow(unused)] + pub(crate) fn as_str(&self) -> &str { self.0.as_str() } } @@ -167,19 +253,19 @@ smol_str_wrapper!(Host); // Endpoints are a bit tricky. Rare they might be branches or projects. impl EndpointId { - pub fn is_endpoint(&self) -> bool { + pub(crate) fn is_endpoint(&self) -> bool { self.0.starts_with("ep-") } - pub fn is_branch(&self) -> bool { + pub(crate) fn is_branch(&self) -> bool { self.0.starts_with("br-") } - pub fn is_project(&self) -> bool { - !self.is_endpoint() && !self.is_branch() - } - pub fn as_branch(&self) -> BranchId { + // pub(crate) fn is_project(&self) -> bool { + // !self.is_endpoint() && !self.is_branch() + // } + pub(crate) fn as_branch(&self) -> BranchId { BranchId(self.0.clone()) } - pub fn as_project(&self) -> ProjectId { + pub(crate) fn as_project(&self) -> ProjectId { ProjectId(self.0.clone()) } } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 0167553e30..ccef88231b 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -252,7 +252,7 @@ impl Drop for HttpEndpointPoolsGuard<'_> { } impl HttpEndpointPools { - pub fn guard(&self) -> HttpEndpointPoolsGuard { + pub fn guard(&self) -> HttpEndpointPoolsGuard<'_> { self.http_pool_endpoints_registered_total.inc(); HttpEndpointPoolsGuard { dec: &self.http_pool_endpoints_unregistered_total, diff --git a/proxy/src/parse.rs b/proxy/src/parse.rs index 0d03574901..8c0f251066 100644 --- a/proxy/src/parse.rs +++ b/proxy/src/parse.rs @@ -2,14 +2,14 @@ use std::ffi::CStr; -pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { +pub(crate) fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { let cstr = CStr::from_bytes_until_nul(bytes).ok()?; let (_, other) = bytes.split_at(cstr.to_bytes_with_nul().len()); Some((cstr, other)) } /// See . -pub fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { +pub(crate) fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { (bytes.len() >= N).then(|| { let (head, tail) = bytes.split_at(N); (head.try_into().unwrap(), tail) diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 1dd4563514..17764f78d1 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -13,9 +13,9 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; pin_project! 
{ /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough - pub struct ChainRW { + pub(crate) struct ChainRW { #[pin] - pub inner: T, + pub(crate) inner: T, buf: BytesMut, } } @@ -60,7 +60,7 @@ const HEADER: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; -pub async fn read_proxy_protocol( +pub(crate) async fn read_proxy_protocol( mut read: T, ) -> std::io::Result<(ChainRW, Option)> { let mut buf = BytesMut::with_capacity(128); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 2182f38fe7..ff199ac701 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,12 +1,12 @@ #[cfg(test)] mod tests; -pub mod connect_compute; +pub(crate) mod connect_compute; mod copy_bidirectional; -pub mod handshake; -pub mod passthrough; -pub mod retry; -pub mod wake_compute; +pub(crate) mod handshake; +pub(crate) mod passthrough; +pub(crate) mod retry; +pub(crate) mod wake_compute; pub use copy_bidirectional::copy_bidirectional_client_compute; pub use copy_bidirectional::ErrorSource; @@ -170,21 +170,21 @@ pub async fn task_main( Ok(()) } -pub enum ClientMode { +pub(crate) enum ClientMode { Tcp, Websockets { hostname: Option }, } /// Abstracts the logic of handling TCP vs WS clients impl ClientMode { - pub fn allow_cleartext(&self) -> bool { + pub(crate) fn allow_cleartext(&self) -> bool { match self { ClientMode::Tcp => false, ClientMode::Websockets { .. } => true, } } - pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { + pub(crate) fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { match self { ClientMode::Tcp => config.allow_self_signed_compute, ClientMode::Websockets { .. } => false, @@ -213,7 +213,7 @@ impl ClientMode { // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation, // we cannot be sure the client even understands our error message // 3. PrepareClient: The client disconnected, so we can't tell them anyway... -pub enum ClientRequestError { +pub(crate) enum ClientRequestError { #[error("{0}")] Cancellation(#[from] cancellation::CancelError), #[error("{0}")] @@ -238,7 +238,7 @@ impl ReportableError for ClientRequestError { } } -pub async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, ctx: &RequestMonitoring, cancellation_handler: Arc, @@ -254,7 +254,7 @@ pub async fn handle_client( let metrics = &Metrics::get().proxy; let proto = ctx.protocol(); - let _request_gauge = metrics.connection_requests.guard(proto); + let request_gauge = metrics.connection_requests.guard(proto); let tls = config.tls_config.as_ref(); @@ -283,7 +283,7 @@ pub async fn handle_client( let result = config .auth_backend .as_ref() - .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) + .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) .transpose(); let user_info = match result { @@ -340,9 +340,9 @@ pub async fn handle_client( client: stream, aux: node.aux.clone(), compute: node, - req: _request_gauge, - conn: conn_gauge, - cancel: session, + _req: request_gauge, + _conn: conn_gauge, + _cancel: session, })) } @@ -377,20 +377,20 @@ async fn prepare_client_connection
<P>
( } #[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct NeonOptions(Vec<(SmolStr, SmolStr)>); +pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>); impl NeonOptions { - pub fn parse_params(params: &StartupMessageParams) -> Self { + pub(crate) fn parse_params(params: &StartupMessageParams) -> Self { params .options_raw() .map(Self::parse_from_iter) .unwrap_or_default() } - pub fn parse_options_raw(options: &str) -> Self { + pub(crate) fn parse_options_raw(options: &str) -> Self { Self::parse_from_iter(StartupMessageParams::parse_options_raw(options)) } - pub fn is_ephemeral(&self) -> bool { + pub(crate) fn is_ephemeral(&self) -> bool { // Currently, neon endpoint options are all reserved for ephemeral endpoints. !self.0.is_empty() } @@ -404,7 +404,7 @@ impl NeonOptions { Self(options) } - pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { + pub(crate) fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { // prefix + format!(" {k}:{v}") // kinda jank because SmolStr is immutable std::iter::once(prefix) @@ -415,7 +415,7 @@ impl NeonOptions { /// DeepObject format /// `paramName[prop1]=value1¶mName[prop2]=value2&...` - pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> { + pub(crate) fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> { self.0 .iter() .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone())) @@ -423,7 +423,7 @@ impl NeonOptions { } } -pub fn neon_option(bytes: &str) -> Option<(&str, &str)> { +pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> { static RE: OnceCell = OnceCell::new(); let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap()); diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index f38e43ba5a..613548d4a0 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -25,14 +25,15 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. #[tracing::instrument(name = "invalidate_cache", skip_all)] -pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { +pub(crate) fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); } - let label = match is_cached { - true => ConnectionFailureKind::ComputeCached, - false => ConnectionFailureKind::ComputeUncached, + let label = if is_cached { + ConnectionFailureKind::ComputeCached + } else { + ConnectionFailureKind::ComputeUncached }; Metrics::get().proxy.connection_failures_total.inc(label); @@ -40,7 +41,7 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { } #[async_trait] -pub trait ConnectMechanism { +pub(crate) trait ConnectMechanism { type Connection; type ConnectError: ReportableError; type Error: From; @@ -55,21 +56,21 @@ pub trait ConnectMechanism { } #[async_trait] -pub trait ComputeConnectBackend { +pub(crate) trait ComputeConnectBackend { async fn wake_compute( &self, ctx: &RequestMonitoring, ) -> Result; - fn get_keys(&self) -> Option<&ComputeCredentialKeys>; + fn get_keys(&self) -> &ComputeCredentialKeys; } -pub struct TcpMechanism<'a> { +pub(crate) struct TcpMechanism<'a> { /// KV-dictionary with PostgreSQL connection params. 
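
// Illustrative sketch (`regex` crate, `OnceLock` instead of the crate's
// `OnceCell`): what `neon_option` above does, split a `neon_<key>:<value>`
// startup option into its key/value parts, compiling the pattern only once.
use std::sync::OnceLock;

use regex::Regex;

fn neon_option_sketch(bytes: &str) -> Option<(&str, &str)> {
    static RE: OnceLock<Regex> = OnceLock::new();
    let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());

    let (_, [k, v]) = re.captures(bytes)?.extract();
    Some((k, v))
}

// e.g. neon_option_sketch("neon_endpoint_type:read_write")
//        == Some(("endpoint_type", "read_write"))
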
- pub params: &'a StartupMessageParams, + pub(crate) params: &'a StartupMessageParams, /// connect_to_compute concurrency lock - pub locks: &'static ApiLocks, + pub(crate) locks: &'static ApiLocks, } #[async_trait] @@ -97,7 +98,7 @@ impl ConnectMechanism for TcpMechanism<'_> { /// Try to connect to the compute node, retrying if necessary. #[tracing::instrument(skip_all)] -pub async fn connect_to_compute( +pub(crate) async fn connect_to_compute( ctx: &RequestMonitoring, mechanism: &M, user_info: &B, @@ -112,9 +113,8 @@ where let mut num_retries = 0; let mut node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; - if let Some(keys) = user_info.get_keys() { - node_info.set_keys(keys); - } + + node_info.set_keys(user_info.get_keys()); node_info.allow_self_signed_compute = allow_self_signed_compute; // let mut node_info = credentials.get_node_info(ctx, user_info).await?; mechanism.update_connect_config(&mut node_info.config); diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 3c45fff969..4ebda013ac 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -14,7 +14,7 @@ enum TransferState { } #[derive(Debug)] -pub enum ErrorDirection { +pub(crate) enum ErrorDirection { Read(io::Error), Write(io::Error), } @@ -184,7 +184,7 @@ impl CopyBuffer { } Poll::Pending } - res => res.map_err(ErrorDirection::Write), + res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write), } } @@ -230,11 +230,10 @@ impl CopyBuffer { io::ErrorKind::WriteZero, "write zero byte into writer", )))); - } else { - self.pos += i; - self.amt += i as u64; - self.need_flush = true; } + self.pos += i; + self.amt += i as u64; + self.need_flush = true; } // If pos larger than cap, this loop will never stop. diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index c65a5558d9..5996b11c11 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -18,7 +18,7 @@ use crate::{ }; #[derive(Error, Debug)] -pub enum HandshakeError { +pub(crate) enum HandshakeError { #[error("data is sent before server replied with EncryptionResponse")] EarlyData, @@ -57,7 +57,7 @@ impl ReportableError for HandshakeError { } } -pub enum HandshakeData { +pub(crate) enum HandshakeData { Startup(PqStream>, StartupMessageParams), Cancel(CancelKeyData), } @@ -67,7 +67,7 @@ pub enum HandshakeData { /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. #[tracing::instrument(skip_all)] -pub async fn handshake( +pub(crate) async fn handshake( ctx: &RequestMonitoring, stream: S, mut tls: Option<&TlsConfig>, @@ -82,9 +82,8 @@ pub async fn handshake( let mut stream = PqStream::new(Stream::from_raw(stream)); loop { let msg = stream.read_startup_packet().await?; - use FeStartupPacket::*; match msg { - SslRequest { direct } => match stream.get_ref() { + FeStartupPacket::SslRequest { direct } => match stream.get_ref() { Stream::Raw { .. 
} if !tried_ssl => { tried_ssl = true; @@ -139,7 +138,7 @@ pub async fn handshake( let tls_stream = accept.await.inspect_err(|_| { if record_handshake_error { - Metrics::get().proxy.tls_handshake_failures.inc() + Metrics::get().proxy.tls_handshake_failures.inc(); } })?; @@ -182,7 +181,7 @@ pub async fn handshake( } _ => return Err(HandshakeError::ProtocolViolation), }, - GssEncRequest => match stream.get_ref() { + FeStartupPacket::GssEncRequest => match stream.get_ref() { Stream::Raw { .. } if !tried_gss => { tried_gss = true; @@ -191,7 +190,7 @@ pub async fn handshake( } _ => return Err(HandshakeError::ProtocolViolation), }, - StartupMessage { params, version } + FeStartupPacket::StartupMessage { params, version } if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST => { // Check that the config has been consumed during upgrade @@ -211,7 +210,7 @@ pub async fn handshake( break Ok(HandshakeData::Startup(stream, params)); } // downgrade protocol version - StartupMessage { params, version } + FeStartupPacket::StartupMessage { params, version } if version.major() == 3 && version > PG_PROTOCOL_LATEST => { warn!(?version, "unsupported minor version"); @@ -241,7 +240,7 @@ pub async fn handshake( ); break Ok(HandshakeData::Startup(stream, params)); } - StartupMessage { version, .. } => { + FeStartupPacket::StartupMessage { version, .. } => { warn!( ?version, session_type = "normal", @@ -249,7 +248,7 @@ pub async fn handshake( ); return Err(HandshakeError::ProtocolViolation); } - CancelRequest(cancel_key_data) => { + FeStartupPacket::CancelRequest(cancel_key_data) => { info!(session_type = "cancellation", "successful handshake"); break Ok(HandshakeData::Cancel(cancel_key_data)); } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 9942fac383..c17108de0a 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -14,7 +14,7 @@ use super::copy_bidirectional::ErrorSource; /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(skip_all)] -pub async fn proxy_pass( +pub(crate) async fn proxy_pass( client: impl AsyncRead + AsyncWrite + Unpin, compute: impl AsyncRead + AsyncWrite + Unpin, aux: MetricsAuxInfo, @@ -57,18 +57,18 @@ pub async fn proxy_pass( Ok(()) } -pub struct ProxyPassthrough { - pub client: Stream, - pub compute: PostgresConnection, - pub aux: MetricsAuxInfo, +pub(crate) struct ProxyPassthrough { + pub(crate) client: Stream, + pub(crate) compute: PostgresConnection, + pub(crate) aux: MetricsAuxInfo, - pub req: NumConnectionRequestsGuard<'static>, - pub conn: NumClientConnectionsGuard<'static>, - pub cancel: cancellation::Session

<P>,
+    pub(crate) _req: NumConnectionRequestsGuard<'static>,
+    pub(crate) _conn: NumClientConnectionsGuard<'static>,
+    pub(crate) _cancel: cancellation::Session<P>
, } impl ProxyPassthrough { - pub async fn proxy_pass(self) -> Result<(), ErrorSource> { + pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { tracing::error!(?err, "could not cancel the query in the database"); diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 644b183a91..15895d37e6 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -2,18 +2,18 @@ use crate::{compute, config::RetryConfig}; use std::{error::Error, io}; use tokio::time; -pub trait CouldRetry { +pub(crate) trait CouldRetry { /// Returns true if the error could be retried fn could_retry(&self) -> bool; } -pub trait ShouldRetryWakeCompute { +pub(crate) trait ShouldRetryWakeCompute { /// Returns true if we need to invalidate the cache for this node. /// If false, we can continue retrying with the current node cache. fn should_retry_wake_compute(&self) -> bool; } -pub fn should_retry(err: &impl CouldRetry, num_retries: u32, config: RetryConfig) -> bool { +pub(crate) fn should_retry(err: &impl CouldRetry, num_retries: u32, config: RetryConfig) -> bool { num_retries < config.max_retries && err.could_retry() } @@ -101,7 +101,7 @@ impl ShouldRetryWakeCompute for compute::ConnectionError { } } -pub fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration { +pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration { config .base_delay .mul_f64(config.backoff_factor.powi((num_retries as i32) - 1)) diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index d8308c4f2a..4264dbae0f 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -11,14 +11,14 @@ use crate::auth::backend::{ ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend, }; use crate::config::{CertResolver, RetryConfig}; -use crate::console::caches::NodeInfoCache; use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status}; -use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; +use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend, NodeInfoCache}; use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::error::ErrorKind; -use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId}; +use crate::{sasl, scram, BranchId, EndpointId, ProjectId}; use anyhow::{bail, Context}; use async_trait::async_trait; +use http::StatusCode; use retry::{retry_after, ShouldRetryWakeCompute}; use rstest::rstest; use rustls::pki_types; @@ -433,7 +433,7 @@ impl ReportableError for TestConnectError { impl std::fmt::Display for TestConnectError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } @@ -475,7 +475,7 @@ impl ConnectMechanism for TestConnectMechanism { retryable: false, kind: ErrorKind::Compute, }), - x => panic!("expecting action {:?}, connect is called instead", x), + x => panic!("expecting action {x:?}, connect is called instead"), } } @@ -491,7 +491,7 @@ impl TestBackend for TestConnectMechanism { ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)), ConnectAction::WakeFail => { let err = console::errors::ApiError::Console(ConsoleError { - http_status_code: http::StatusCode::BAD_REQUEST, + http_status_code: StatusCode::BAD_REQUEST, error: "TEST".into(), status: None, }); @@ -500,7 +500,7 @@ impl TestBackend for 
TestConnectMechanism { } ConnectAction::WakeRetry => { let err = console::errors::ApiError::Console(ConsoleError { - http_status_code: http::StatusCode::BAD_REQUEST, + http_status_code: StatusCode::BAD_REQUEST, error: "TEST".into(), status: Some(Status { code: "error".into(), @@ -515,7 +515,7 @@ impl TestBackend for TestConnectMechanism { assert!(err.could_retry()); Err(console::errors::WakeComputeError::ApiError(err)) } - x => panic!("expecting action {:?}, wake_compute is called instead", x), + x => panic!("expecting action {x:?}, wake_compute is called instead"), } } @@ -525,9 +525,6 @@ impl TestBackend for TestConnectMechanism { { unimplemented!("not used in tests") } - fn get_role_secret(&self) -> Result { - unimplemented!("not used in tests") - } } fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { @@ -547,8 +544,8 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> auth::BackendType<'static, ComputeCredentials, &()> { - let user_info = auth::BackendType::Console( +) -> auth::Backend<'static, ComputeCredentials, &()> { + let user_info = auth::Backend::Console( MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), ComputeCredentials { info: ComputeUserInfo { diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index c8ec2b2db6..33a2162bc7 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -68,7 +68,7 @@ async fn proxy_mitm( end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap(); continue; } - end_client.send(message).await.unwrap() + end_client.send(message).await.unwrap(); } _ => break, } @@ -88,7 +88,7 @@ async fn proxy_mitm( end_server.send(buf.freeze()).await.unwrap(); continue; } - end_server.send(message).await.unwrap() + end_server.send(message).await.unwrap(); } _ => break, } @@ -102,7 +102,7 @@ async fn proxy_mitm( } /// taken from tokio-postgres -pub async fn connect_tls(mut stream: S, tls: T) -> T::Stream +pub(crate) async fn connect_tls(mut stream: S, tls: T) -> T::Stream where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, @@ -115,9 +115,7 @@ where let mut buf = [0]; stream.read_exact(&mut buf).await.unwrap(); - if buf[0] != b'S' { - panic!("ssl not supported by server"); - } + assert!(buf[0] == b'S', "ssl not supported by server"); tls.connect(stream).await.unwrap() } diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 5b06e8f054..9b8ac6d29d 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -12,7 +12,7 @@ use tracing::{error, info, warn}; use super::connect_compute::ComputeConnectBackend; -pub async fn wake_compute( +pub(crate) async fn wake_compute( num_retries: &mut u32, ctx: &RequestMonitoring, api: &B, diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index 222cd431d2..e5f5867998 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -1,10 +1,16 @@ +mod leaky_bucket; mod limit_algorithm; mod limiter; -pub use limit_algorithm::{ - aimd::Aimd, DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, + +#[cfg(test)] +pub(crate) use limit_algorithm::aimd::Aimd; + +pub(crate) use limit_algorithm::{ + DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, }; -pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; -mod leaky_bucket; +pub(crate) use 
limiter::GlobalRateLimiter; + pub use leaky_bucket::{ EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter, LeakyBucketState, }; +pub use limiter::{BucketRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 2d5e056540..fa8cb75256 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -35,7 +35,7 @@ impl LeakyBucketRateLimiter { } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, n: u32) -> bool { let now = Instant::now(); if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { @@ -73,8 +73,9 @@ pub struct LeakyBucketState { time: Instant, } +#[cfg(test)] impl LeakyBucketConfig { - pub fn new(rps: f64, max: f64) -> Self { + pub(crate) fn new(rps: f64, max: f64) -> Self { assert!(rps > 0.0, "rps must be positive"); assert!(max > 0.0, "max must be positive"); Self { rps, max } @@ -82,7 +83,7 @@ impl LeakyBucketConfig { } impl LeakyBucketState { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self { filled: 0.0, time: Instant::now(), @@ -100,7 +101,7 @@ impl LeakyBucketState { self.filled == 0.0 } - pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool { + pub(crate) fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool { self.update(info, now); if self.filled + n > info.max { @@ -119,6 +120,7 @@ impl Default for LeakyBucketState { } #[cfg(test)] +#[allow(clippy::float_cmp)] mod tests { use std::time::Duration; diff --git a/proxy/src/rate_limiter/limit_algorithm.rs b/proxy/src/rate_limiter/limit_algorithm.rs index 3842ce269e..25607b7e10 100644 --- a/proxy/src/rate_limiter/limit_algorithm.rs +++ b/proxy/src/rate_limiter/limit_algorithm.rs @@ -8,13 +8,13 @@ use tokio::{ use self::aimd::Aimd; -pub mod aimd; +pub(crate) mod aimd; /// Whether a job succeeded or failed as a result of congestion/overload. /// /// Errors not considered to be caused by overload should be ignored. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Outcome { +pub(crate) enum Outcome { /// The job succeeded, or failed in a way unrelated to overload. Success, /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal @@ -23,14 +23,14 @@ pub enum Outcome { } /// An algorithm for controlling a concurrency limit. -pub trait LimitAlgorithm: Send + Sync + 'static { +pub(crate) trait LimitAlgorithm: Send + Sync + 'static { /// Update the concurrency limit in response to a new job completion. fn update(&self, old_limit: usize, sample: Sample) -> usize; } /// The result of a job (or jobs), including the [`Outcome`] (loss) and latency (delay). #[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub struct Sample { +pub(crate) struct Sample { pub(crate) latency: Duration, /// Jobs in flight when the sample was taken. 
pub(crate) in_flight: usize, @@ -39,7 +39,7 @@ pub struct Sample { #[derive(Clone, Copy, Debug, Default, serde::Deserialize, PartialEq)] #[serde(rename_all = "snake_case")] -pub enum RateLimitAlgorithm { +pub(crate) enum RateLimitAlgorithm { #[default] Fixed, Aimd { @@ -48,7 +48,7 @@ pub enum RateLimitAlgorithm { }, } -pub struct Fixed; +pub(crate) struct Fixed; impl LimitAlgorithm for Fixed { fn update(&self, old_limit: usize, _sample: Sample) -> usize { @@ -59,12 +59,12 @@ impl LimitAlgorithm for Fixed { #[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)] pub struct RateLimiterConfig { #[serde(flatten)] - pub algorithm: RateLimitAlgorithm, - pub initial_limit: usize, + pub(crate) algorithm: RateLimitAlgorithm, + pub(crate) initial_limit: usize, } impl RateLimiterConfig { - pub fn create_rate_limit_algorithm(self) -> Box { + pub(crate) fn create_rate_limit_algorithm(self) -> Box { match self.algorithm { RateLimitAlgorithm::Fixed => Box::new(Fixed), RateLimitAlgorithm::Aimd { conf } => Box::new(conf), @@ -72,7 +72,7 @@ impl RateLimiterConfig { } } -pub struct LimiterInner { +pub(crate) struct LimiterInner { alg: Box, available: usize, limit: usize, @@ -114,7 +114,7 @@ impl LimiterInner { /// /// The limit will be automatically adjusted based on observed latency (delay) and/or failures /// caused by overload (loss). -pub struct DynamicLimiter { +pub(crate) struct DynamicLimiter { config: RateLimiterConfig, inner: Mutex, // to notify when a token is available @@ -124,7 +124,7 @@ pub struct DynamicLimiter { /// A concurrency token, required to run a job. /// /// Release the token back to the [`DynamicLimiter`] after the job is complete. -pub struct Token { +pub(crate) struct Token { start: Instant, limiter: Option>, } @@ -133,14 +133,14 @@ pub struct Token { /// /// Not guaranteed to be consistent under high concurrency. #[derive(Debug, Clone, Copy)] -pub struct LimiterState { +#[cfg(test)] +struct LimiterState { limit: usize, - in_flight: usize, } impl DynamicLimiter { /// Create a limiter with a given limit control algorithm. - pub fn new(config: RateLimiterConfig) -> Arc { + pub(crate) fn new(config: RateLimiterConfig) -> Arc { let ready = Notify::new(); ready.notify_one(); @@ -157,7 +157,10 @@ impl DynamicLimiter { } /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available. - pub async fn acquire_timeout(self: &Arc, duration: Duration) -> Result { + pub(crate) async fn acquire_timeout( + self: &Arc, + duration: Duration, + ) -> Result { tokio::time::timeout(duration, self.acquire()).await? } @@ -174,9 +177,8 @@ impl DynamicLimiter { let mut inner = self.inner.lock(); if inner.take(&self.ready).is_some() { break Ok(Token::new(self.clone())); - } else { - notified.set(self.ready.notified()); } + notified.set(self.ready.notified()); } notified.as_mut().await; ready = true; @@ -209,12 +211,10 @@ impl DynamicLimiter { } /// The current state of the limiter. 
- pub fn state(&self) -> LimiterState { + #[cfg(test)] + fn state(&self) -> LimiterState { let inner = self.inner.lock(); - LimiterState { - limit: inner.limit, - in_flight: inner.in_flight, - } + LimiterState { limit: inner.limit } } } @@ -225,22 +225,22 @@ impl Token { limiter: Some(limiter), } } - pub fn disabled() -> Self { + pub(crate) fn disabled() -> Self { Self { start: Instant::now(), limiter: None, } } - pub fn is_disabled(&self) -> bool { + pub(crate) fn is_disabled(&self) -> bool { self.limiter.is_none() } - pub fn release(mut self, outcome: Outcome) { - self.release_mut(Some(outcome)) + pub(crate) fn release(mut self, outcome: Outcome) { + self.release_mut(Some(outcome)); } - pub fn release_mut(&mut self, outcome: Option) { + pub(crate) fn release_mut(&mut self, outcome: Option) { if let Some(limiter) = self.limiter.take() { limiter.release_inner(self.start, outcome); } @@ -249,17 +249,14 @@ impl Token { impl Drop for Token { fn drop(&mut self) { - self.release_mut(None) + self.release_mut(None); } } +#[cfg(test)] impl LimiterState { /// The current concurrency limit. - pub fn limit(&self) -> usize { + fn limit(self) -> usize { self.limit } - /// The number of jobs in flight. - pub fn in_flight(&self) -> usize { - self.in_flight - } } diff --git a/proxy/src/rate_limiter/limit_algorithm/aimd.rs b/proxy/src/rate_limiter/limit_algorithm/aimd.rs index b39740bb21..86b56e38fb 100644 --- a/proxy/src/rate_limiter/limit_algorithm/aimd.rs +++ b/proxy/src/rate_limiter/limit_algorithm/aimd.rs @@ -10,24 +10,23 @@ use super::{LimitAlgorithm, Outcome, Sample}; /// /// Reduces available concurrency by a factor when load-based errors are detected. #[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)] -pub struct Aimd { +pub(crate) struct Aimd { /// Minimum limit for AIMD algorithm. - pub min: usize, + pub(crate) min: usize, /// Maximum limit for AIMD algorithm. - pub max: usize, + pub(crate) max: usize, /// Decrease AIMD decrease by value in case of error. - pub dec: f32, + pub(crate) dec: f32, /// Increase AIMD increase by value in case of success. - pub inc: usize, + pub(crate) inc: usize, /// A threshold below which the limit won't be increased. - pub utilisation: f32, + pub(crate) utilisation: f32, } impl LimitAlgorithm for Aimd { fn update(&self, old_limit: usize, sample: Sample) -> usize { - use Outcome::*; match sample.outcome { - Success => { + Outcome::Success => { let utilisation = sample.in_flight as f32 / old_limit as f32; if utilisation > self.utilisation { @@ -42,7 +41,7 @@ impl LimitAlgorithm for Aimd { old_limit } } - Overload => { + Outcome::Overload => { let limit = old_limit as f32 * self.dec; // Floor instead of round, so the limit reduces even with small numbers. diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 5db4efed37..be529f174d 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -17,13 +17,13 @@ use tracing::info; use crate::intern::EndpointIdInt; -pub struct GlobalRateLimiter { +pub(crate) struct GlobalRateLimiter { data: Vec, info: Vec, } impl GlobalRateLimiter { - pub fn new(info: Vec) -> Self { + pub(crate) fn new(info: Vec) -> Self { Self { data: vec![ RateBucket { @@ -37,7 +37,7 @@ impl GlobalRateLimiter { } /// Check that number of connections is below `max_rps` rps. 
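The `Aimd` struct that follows implements the classic additive-increase/multiplicative-decrease rule: grow the limit by `inc` while utilisation stays above a threshold, shrink it by the factor `dec` on overload (floored, so small limits still shrink). A self-contained sketch of that update rule with made-up parameter values; clamping to `[min, max]` on both branches is an assumption of this sketch:

```rust
#[derive(Clone, Copy)]
struct Aimd {
    min: usize,
    max: usize,
    inc: usize,        // additive step on success
    dec: f32,          // multiplicative factor on overload, e.g. 0.5
    utilisation: f32,  // only grow when in_flight / limit exceeds this
}

impl Aimd {
    fn update(&self, old_limit: usize, in_flight: usize, overloaded: bool) -> usize {
        if overloaded {
            // Floor instead of round so the limit shrinks even for small values.
            let reduced = (old_limit as f32 * self.dec).floor() as usize;
            reduced.clamp(self.min, self.max)
        } else if (in_flight as f32 / old_limit as f32) > self.utilisation {
            (old_limit + self.inc).clamp(self.min, self.max)
        } else {
            old_limit
        }
    }
}

fn main() {
    let aimd = Aimd { min: 1, max: 100, inc: 10, dec: 0.5, utilisation: 0.8 };
    assert_eq!(aimd.update(10, 9, false), 20);  // busy enough: grow additively
    assert_eq!(aimd.update(10, 2, false), 10);  // under-utilised: hold steady
    assert_eq!(aimd.update(20, 20, true), 10);  // overload: halve
    assert_eq!(aimd.update(1, 1, true), 1);     // never drop below `min`
}
```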
- pub fn check(&mut self) -> bool { + pub(crate) fn check(&mut self) -> bool { let now = Instant::now(); let should_allow_request = self @@ -96,9 +96,9 @@ impl RateBucket { #[derive(Clone, Copy, PartialEq)] pub struct RateBucketInfo { - pub interval: Duration, + pub(crate) interval: Duration, // requests per interval - pub max_rpi: u32, + pub(crate) max_rpi: u32, } impl std::fmt::Display for RateBucketInfo { @@ -192,7 +192,7 @@ impl BucketRateLimiter { } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, n: u32) -> bool { // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map. // worst case memory usage is about: // = 2 * 2048 * 64 * (48B + 72B) @@ -228,7 +228,7 @@ impl BucketRateLimiter { /// Clean the map. Simple strategy: remove all entries in a random shard. /// At worst, we'll double the effective max_rps during the cleanup. /// But that way deletion does not aquire mutex on each entry access. - pub fn do_gc(&self) { + pub(crate) fn do_gc(&self) { info!( "cleaning up bucket rate limiter, current size = {}", self.map.len() diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index c9a946fa4a..95bdfc0965 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -109,7 +109,7 @@ impl RedisPublisherClient { let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?; Ok(()) } - pub async fn try_connect(&mut self) -> anyhow::Result<()> { + pub(crate) async fn try_connect(&mut self) -> anyhow::Result<()> { match self.client.connect().await { Ok(()) => {} Err(e) => { diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index b02ce472c0..7d222e2dec 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -81,7 +81,7 @@ impl ConnectionWithCredentialsProvider { redis::cmd("PING").query_async(con).await } - pub async fn connect(&mut self) -> anyhow::Result<()> { + pub(crate) async fn connect(&mut self) -> anyhow::Result<()> { let _guard = self.mutex.lock().await; if let Some(con) = self.con.as_mut() { match Self::ping(con).await { @@ -98,7 +98,7 @@ impl ConnectionWithCredentialsProvider { info!("Establishing a new connection..."); self.con = None; if let Some(f) = self.refresh_token_task.take() { - f.abort() + f.abort(); } let mut con = self .get_client() @@ -149,7 +149,7 @@ impl ConnectionWithCredentialsProvider { // PubSub does not support credentials refresh. // Requires manual reconnection every 12h. - pub async fn get_async_pubsub(&self) -> anyhow::Result { + pub(crate) async fn get_async_pubsub(&self) -> anyhow::Result { Ok(self.get_client().await?.get_async_pubsub().await?) } @@ -187,7 +187,10 @@ impl ConnectionWithCredentialsProvider { } /// Sends an already encoded (packed) command into the TCP socket and /// reads the single response from it. 
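`RateBucketInfo` above pairs an `interval` with a per-interval budget (`max_rpi`), and a request has to fit every configured bucket. (The memory bound sketched in the GC comment, 2 * 2048 * 64 * (48 B + 72 B), works out to roughly 30 MiB.) A simplified fixed-window sketch of the multi-bucket check, without the sharded map or partial GC:

```rust
use std::time::{Duration, Instant};

#[derive(Clone, Copy)]
struct BucketInfo { interval: Duration, max_rpi: u32 }

struct Bucket { start: Instant, count: u32 }

struct RateLimiter { info: Vec<BucketInfo>, buckets: Vec<Bucket> }

impl RateLimiter {
    fn new(info: Vec<BucketInfo>) -> Self {
        let now = Instant::now();
        let buckets = info.iter().map(|_| Bucket { start: now, count: 0 }).collect();
        Self { info, buckets }
    }

    fn check(&mut self, now: Instant, n: u32) -> bool {
        // Roll every bucket whose window has elapsed.
        for (b, i) in self.buckets.iter_mut().zip(&self.info) {
            if now.duration_since(b.start) >= i.interval {
                b.start = now;
                b.count = 0;
            }
        }
        // Admit only if *all* buckets have room, then charge them all.
        let allowed = self
            .buckets
            .iter()
            .zip(&self.info)
            .all(|(b, i)| b.count + n <= i.max_rpi);
        if allowed {
            for b in &mut self.buckets {
                b.count += n;
            }
        }
        allowed
    }
}

fn main() {
    let mut rl = RateLimiter::new(vec![
        BucketInfo { interval: Duration::from_secs(1), max_rpi: 10 },
        BucketInfo { interval: Duration::from_secs(60), max_rpi: 100 },
    ]);
    let now = Instant::now();
    assert!(rl.check(now, 10));
    assert!(!rl.check(now, 1)); // per-second budget exhausted
}
```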
- pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult { + pub(crate) async fn send_packed_command( + &mut self, + cmd: &redis::Cmd, + ) -> RedisResult { // Clone connection to avoid having to lock the ArcSwap in write mode let con = self.con.as_mut().ok_or(redis::RedisError::from(( redis::ErrorKind::IoError, @@ -199,7 +202,7 @@ impl ConnectionWithCredentialsProvider { /// Sends multiple already encoded (packed) command into the TCP socket /// and reads `count` responses from it. This is used to implement /// pipelining. - pub async fn send_packed_commands( + pub(crate) async fn send_packed_commands( &mut self, cmd: &redis::Pipeline, offset: usize, diff --git a/proxy/src/redis/elasticache.rs b/proxy/src/redis/elasticache.rs index eded8250af..d118c8f412 100644 --- a/proxy/src/redis/elasticache.rs +++ b/proxy/src/redis/elasticache.rs @@ -51,7 +51,7 @@ impl CredentialsProvider { credentials_provider, } } - pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { + pub(crate) async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { let aws_credentials = self .credentials_provider .provide_credentials() diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index efd7437d5d..36a3443603 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -58,9 +58,9 @@ pub(crate) struct PasswordUpdate { } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] pub(crate) struct CancelSession { - pub region_id: Option, - pub cancel_key_data: CancelKeyData, - pub session_id: Uuid, + pub(crate) region_id: Option, + pub(crate) cancel_key_data: CancelKeyData, + pub(crate) session_id: Uuid, } fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result @@ -89,7 +89,7 @@ impl Clone for MessageHandler { } impl MessageHandler { - pub fn new( + pub(crate) fn new( cache: Arc, cancellation_handler: Arc>, region_id: String, @@ -100,15 +100,14 @@ impl MessageHandler { region_id, } } - pub async fn increment_active_listeners(&self) { + pub(crate) async fn increment_active_listeners(&self) { self.cache.increment_active_listeners().await; } - pub async fn decrement_active_listeners(&self) { + pub(crate) async fn decrement_active_listeners(&self) { self.cache.decrement_active_listeners().await; } #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))] async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> { - use Notification::*; let payload: String = msg.get_payload()?; tracing::debug!(?payload, "received a message payload"); @@ -124,7 +123,7 @@ impl MessageHandler { }; tracing::debug!(?msg, "received a message"); match msg { - Cancel(cancel_session) => { + Notification::Cancel(cancel_session) => { tracing::Span::current().record( "session_id", tracing::field::display(cancel_session.session_id), @@ -151,14 +150,14 @@ impl MessageHandler { } } } - _ => { + Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => { invalidate_cache(self.cache.clone(), msg.clone()); - if matches!(msg, AllowedIpsUpdate { .. }) { + if matches!(msg, Notification::AllowedIpsUpdate { .. }) { Metrics::get() .proxy .redis_events_count .inc(RedisEventsCount::AllowedIpsUpdate); - } else if matches!(msg, PasswordUpdate { .. }) { + } else if matches!(msg, Notification::PasswordUpdate { .. 
}) { Metrics::get() .proxy .redis_events_count @@ -180,16 +179,16 @@ impl MessageHandler { } fn invalidate_cache(cache: Arc, msg: Notification) { - use Notification::*; match msg { - AllowedIpsUpdate { allowed_ips_update } => { - cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id) + Notification::AllowedIpsUpdate { allowed_ips_update } => { + cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); } - PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project( - password_update.project_id, - password_update.role_name, - ), - Cancel(_) => unreachable!("cancel message should be handled separately"), + Notification::PasswordUpdate { password_update } => cache + .invalidate_role_secret_for_project( + password_update.project_id, + password_update.role_name, + ), + Notification::Cancel(_) => unreachable!("cancel message should be handled separately"), } } diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index 0811416ca2..0a36694359 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -14,13 +14,13 @@ use crate::error::{ReportableError, UserFacingError}; use std::io; use thiserror::Error; -pub use channel_binding::ChannelBinding; -pub use messages::FirstMessage; -pub use stream::{Outcome, SaslStream}; +pub(crate) use channel_binding::ChannelBinding; +pub(crate) use messages::FirstMessage; +pub(crate) use stream::{Outcome, SaslStream}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] -pub enum Error { +pub(crate) enum Error { #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), @@ -42,10 +42,9 @@ pub enum Error { impl UserFacingError for Error { fn to_string_client(&self) -> String { - use Error::*; match self { - ChannelBindingFailed(m) => m.to_string(), - ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"), + Self::ChannelBindingFailed(m) => (*m).to_string(), + Self::ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"), _ => "authentication protocol violation".to_string(), } } @@ -65,11 +64,11 @@ impl ReportableError for Error { } /// A convenient result type for SASL exchange. -pub type Result = std::result::Result; +pub(crate) type Result = std::result::Result; /// A result of one SASL exchange. #[must_use] -pub enum Step { +pub(crate) enum Step { /// We should continue exchanging messages. Continue(T, String), /// The client has been authenticated successfully. @@ -79,7 +78,7 @@ pub enum Step { } /// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. -pub trait Mechanism: Sized { +pub(crate) trait Mechanism: Sized { /// What's produced as a result of successful authentication. type Output; diff --git a/proxy/src/sasl/channel_binding.rs b/proxy/src/sasl/channel_binding.rs index 13d681de6d..fdd011448e 100644 --- a/proxy/src/sasl/channel_binding.rs +++ b/proxy/src/sasl/channel_binding.rs @@ -2,7 +2,7 @@ /// Channel binding flag (possibly with params). #[derive(Debug, PartialEq, Eq)] -pub enum ChannelBinding { +pub(crate) enum ChannelBinding { /// Client doesn't support channel binding. NotSupportedClient, /// Client thinks server doesn't support channel binding. 
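For readers new to the `sasl` module being made crate-private here: authentication is driven by repeatedly feeding client messages into a `Mechanism` until it yields `Step::Success` or `Step::Failure`. A toy, self-contained version of that driving loop; the mechanism and its messages are invented purely for illustration:

```rust
enum Step<T, R> {
    Continue(T, String),   // keep exchanging; reply with the String
    Success(R, String),    // authenticated, with a final reply
    Failure(&'static str), // denied, with a reason
}

trait Mechanism: Sized {
    type Output;
    fn exchange(self, input: &str) -> Step<Self, Self::Output>;
}

/// A fake two-round mechanism: round one echoes a challenge,
/// round two accepts any non-empty answer.
enum Toy { First, Second }

impl Mechanism for Toy {
    type Output = ();
    fn exchange(self, input: &str) -> Step<Self, ()> {
        match self {
            Toy::First => Step::Continue(Toy::Second, format!("challenge:{input}")),
            Toy::Second if !input.is_empty() => Step::Success((), "ok".into()),
            Toy::Second => Step::Failure("empty final message"),
        }
    }
}

fn main() {
    let mut mech = Toy::First;
    for msg in ["hello", "proof"] {
        match mech.exchange(msg) {
            Step::Continue(next, reply) => {
                println!("server replies {reply}");
                mech = next;
            }
            Step::Success((), reply) => {
                println!("authenticated: {reply}");
                return;
            }
            Step::Failure(reason) => panic!("denied: {reason}"),
        }
    }
}
```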
@@ -12,45 +12,45 @@ pub enum ChannelBinding { } impl ChannelBinding { - pub fn and_then(self, f: impl FnOnce(T) -> Result) -> Result, E> { - use ChannelBinding::*; + pub(crate) fn and_then( + self, + f: impl FnOnce(T) -> Result, + ) -> Result, E> { Ok(match self { - NotSupportedClient => NotSupportedClient, - NotSupportedServer => NotSupportedServer, - Required(x) => Required(f(x)?), + Self::NotSupportedClient => ChannelBinding::NotSupportedClient, + Self::NotSupportedServer => ChannelBinding::NotSupportedServer, + Self::Required(x) => ChannelBinding::Required(f(x)?), }) } } impl<'a> ChannelBinding<&'a str> { // NB: FromStr doesn't work with lifetimes - pub fn parse(input: &'a str) -> Option { - use ChannelBinding::*; + pub(crate) fn parse(input: &'a str) -> Option { Some(match input { - "n" => NotSupportedClient, - "y" => NotSupportedServer, - other => Required(other.strip_prefix("p=")?), + "n" => Self::NotSupportedClient, + "y" => Self::NotSupportedServer, + other => Self::Required(other.strip_prefix("p=")?), }) } } impl ChannelBinding { /// Encode channel binding data as base64 for subsequent checks. - pub fn encode<'a, E>( + pub(crate) fn encode<'a, E>( &self, get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>, ) -> Result, E> { - use ChannelBinding::*; Ok(match self { - NotSupportedClient => { + Self::NotSupportedClient => { // base64::encode("n,,") "biws".into() } - NotSupportedServer => { + Self::NotSupportedServer => { // base64::encode("y,,") "eSws".into() } - Required(mode) => { + Self::Required(mode) => { use std::io::Write; let mut cbind_input = vec![]; write!(&mut cbind_input, "p={mode},,",).unwrap(); diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index b9208f6f1f..6c9a42b2db 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -5,16 +5,16 @@ use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). #[derive(Debug)] -pub struct FirstMessage<'a> { +pub(crate) struct FirstMessage<'a> { /// Authentication method, e.g. `"SCRAM-SHA-256"`. - pub method: &'a str, + pub(crate) method: &'a str, /// Initial client message. - pub message: &'a str, + pub(crate) message: &'a str, } impl<'a> FirstMessage<'a> { // NB: FromStr doesn't work with lifetimes - pub fn parse(bytes: &'a [u8]) -> Option { + pub(crate) fn parse(bytes: &'a [u8]) -> Option { let (method_cstr, tail) = split_cstr(bytes)?; let method = method_cstr.to_str().ok()?; @@ -42,10 +42,9 @@ pub(super) enum ServerMessage { impl<'a> ServerMessage<&'a str> { pub(super) fn to_reply(&self) -> BeMessage<'a> { - use BeAuthenticationSaslMessage::*; BeMessage::AuthenticationSasl(match self { - ServerMessage::Continue(s) => Continue(s.as_bytes()), - ServerMessage::Final(s) => Final(s.as_bytes()), + ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()), + ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()), }) } } diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 9115b0f61a..b6becd28e1 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -7,7 +7,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; /// Abstracts away all peculiarities of the libpq's protocol. -pub struct SaslStream<'a, S> { +pub(crate) struct SaslStream<'a, S> { /// The underlying stream. stream: &'a mut PqStream, /// Current password message we received from client. 
@@ -17,7 +17,7 @@ pub struct SaslStream<'a, S> { } impl<'a, S> SaslStream<'a, S> { - pub fn new(stream: &'a mut PqStream, first: &'a str) -> Self { + pub(crate) fn new(stream: &'a mut PqStream, first: &'a str) -> Self { Self { stream, current: bytes::Bytes::new(), @@ -53,7 +53,7 @@ impl SaslStream<'_, S> { /// It's much easier to match on those two variants /// than to peek into a noisy protocol error type. #[must_use = "caller must explicitly check for success"] -pub enum Outcome { +pub(crate) enum Outcome { /// Authentication succeeded and produced some value. Success(R), /// Authentication failed (reason attached). @@ -63,7 +63,7 @@ pub enum Outcome { impl SaslStream<'_, S> { /// Perform SASL message exchange according to the underlying algorithm /// until user is either authenticated or denied access. - pub async fn authenticate( + pub(crate) async fn authenticate( mut self, mut mechanism: M, ) -> super::Result> { diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index 862facb4e5..d058f1c3f8 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -15,9 +15,9 @@ mod secret; mod signature; pub mod threadpool; -pub use exchange::{exchange, Exchange}; -pub use key::ScramKey; -pub use secret::ServerSecret; +pub(crate) use exchange::{exchange, Exchange}; +pub(crate) use key::ScramKey; +pub(crate) use secret::ServerSecret; use hmac::{Hmac, Mac}; use sha2::{Digest, Sha256}; @@ -26,8 +26,8 @@ const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; /// A list of supported SCRAM methods. -pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256]; -pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256]; +pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256]; +pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256]; /// Decode base64 into array without any heap allocations fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N]> { @@ -137,12 +137,12 @@ mod tests { #[tokio::test] async fn round_trip() { - run_round_trip_test("pencil", "pencil").await + run_round_trip_test("pencil", "pencil").await; } #[tokio::test] #[should_panic(expected = "password doesn't match")] async fn failure() { - run_round_trip_test("pencil", "eraser").await + run_round_trip_test("pencil", "eraser").await; } } diff --git a/proxy/src/scram/countmin.rs b/proxy/src/scram/countmin.rs index e8e7ef5c86..255694b33e 100644 --- a/proxy/src/scram/countmin.rs +++ b/proxy/src/scram/countmin.rs @@ -2,7 +2,7 @@ use std::hash::Hash; /// estimator of hash jobs per second. 
/// -pub struct CountMinSketch { +pub(crate) struct CountMinSketch { // one for each depth hashers: Vec, width: usize, @@ -20,7 +20,7 @@ impl CountMinSketch { /// actual <= estimate /// estimate <= actual + ε * N with probability 1 - δ /// where N is the cardinality of the stream - pub fn with_params(epsilon: f64, delta: f64) -> Self { + pub(crate) fn with_params(epsilon: f64, delta: f64) -> Self { CountMinSketch::new( (std::f64::consts::E / epsilon).ceil() as usize, (1.0_f64 / delta).ln().ceil() as usize, @@ -49,7 +49,7 @@ impl CountMinSketch { } } - pub fn inc_and_return(&mut self, t: &T, x: u32) -> u32 { + pub(crate) fn inc_and_return(&mut self, t: &T, x: u32) -> u32 { let mut min = u32::MAX; for row in 0..self.depth { let col = (self.hashers[row].hash_one(t) as usize) % self.width; @@ -61,7 +61,7 @@ impl CountMinSketch { min } - pub fn reset(&mut self) { + pub(crate) fn reset(&mut self) { self.buckets.clear(); self.buckets.resize(self.width * self.depth, 0); } @@ -98,8 +98,6 @@ mod tests { // q% of counts will be within p of the actual value let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q); - dbg!(sketch.buckets.len()); - // insert a bunch of entries in a random order let mut ids2 = ids.clone(); while !ids2.is_empty() { diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index d0adbc780e..7fdadc7038 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -56,14 +56,14 @@ enum ExchangeState { } /// Server's side of SCRAM auth algorithm. -pub struct Exchange<'a> { +pub(crate) struct Exchange<'a> { state: ExchangeState, secret: &'a ServerSecret, tls_server_end_point: config::TlsServerEndPoint, } impl<'a> Exchange<'a> { - pub fn new( + pub(crate) fn new( secret: &'a ServerSecret, nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], tls_server_end_point: config::TlsServerEndPoint, @@ -101,7 +101,7 @@ async fn derive_client_key( make_key(b"Client Key").into() } -pub async fn exchange( +pub(crate) async fn exchange( pool: &ThreadPool, endpoint: EndpointIdInt, secret: &ServerSecret, @@ -210,23 +210,23 @@ impl sasl::Mechanism for Exchange<'_> { type Output = super::ScramKey; fn exchange(mut self, input: &str) -> sasl::Result> { - use {sasl::Step::*, ExchangeState::*}; + use {sasl::Step, ExchangeState}; match &self.state { - Initial(init) => { + ExchangeState::Initial(init) => { match init.transition(self.secret, &self.tls_server_end_point, input)? { - Continue(sent, msg) => { - self.state = SaltSent(sent); - Ok(Continue(self, msg)) + Step::Continue(sent, msg) => { + self.state = ExchangeState::SaltSent(sent); + Ok(Step::Continue(self, msg)) } - Success(x, _) => match x {}, - Failure(msg) => Ok(Failure(msg)), + Step::Success(x, _) => match x {}, + Step::Failure(msg) => Ok(Step::Failure(msg)), } } - SaltSent(sent) => { + ExchangeState::SaltSent(sent) => { match sent.transition(self.secret, &self.tls_server_end_point, input)? { - Success(keys, msg) => Ok(Success(keys, msg)), - Continue(x, _) => match x {}, - Failure(msg) => Ok(Failure(msg)), + Step::Success(keys, msg) => Ok(Step::Success(keys, msg)), + Step::Continue(x, _) => match x {}, + Step::Failure(msg) => Ok(Step::Failure(msg)), } } } diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index 32a3dbd203..fe55ff493b 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -3,14 +3,14 @@ use subtle::ConstantTimeEq; /// Faithfully taken from PostgreSQL. 
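The `CountMinSketch` above sizes itself from an accuracy target: width = ceil(e / epsilon) and depth = ceil(ln(1 / delta)), and the estimate returned by `inc_and_return` is the minimum over one counter per hash row, so it can only over-count. A standalone sketch using std's `RandomState` hashers and omitting `reset`:

```rust
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash};

struct CountMinSketch {
    hashers: Vec<RandomState>,
    width: usize,
    depth: usize,
    buckets: Vec<u32>, // depth * width counters, row-major
}

impl CountMinSketch {
    /// estimate <= actual + epsilon * N, with probability 1 - delta
    /// (N is the total count inserted into the sketch).
    fn with_params(epsilon: f64, delta: f64) -> Self {
        let width = (std::f64::consts::E / epsilon).ceil() as usize;
        let depth = (1.0 / delta).ln().ceil() as usize;
        Self {
            hashers: (0..depth).map(|_| RandomState::new()).collect(),
            width,
            depth,
            buckets: vec![0; width * depth],
        }
    }

    fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
        let mut min = u32::MAX;
        for row in 0..self.depth {
            let col = (self.hashers[row].hash_one(t) as usize) % self.width;
            let idx = row * self.width + col;
            self.buckets[idx] += x;
            min = min.min(self.buckets[idx]);
        }
        min
    }
}

fn main() {
    let mut sketch = CountMinSketch::with_params(0.01, 0.01);
    for _ in 0..5 {
        sketch.inc_and_return(&"endpoint-a", 1);
    }
    // The estimate is at least the true count (now 6) and usually equal to it.
    assert!(sketch.inc_and_return(&"endpoint-a", 1) >= 6);
}
```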
-pub const SCRAM_KEY_LEN: usize = 32; +pub(crate) const SCRAM_KEY_LEN: usize = 32; /// One of the keys derived from the user's password. /// We use the same structure for all keys, i.e. /// `ClientKey`, `StoredKey`, and `ServerKey`. #[derive(Clone, Default, Eq, Debug)] #[repr(transparent)] -pub struct ScramKey { +pub(crate) struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], } @@ -27,11 +27,11 @@ impl ConstantTimeEq for ScramKey { } impl ScramKey { - pub fn sha256(&self) -> Self { + pub(crate) fn sha256(&self) -> Self { super::sha256([self.as_ref()]).into() } - pub fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { + pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { self.bytes } } diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs index cf677a3334..fd9e77764c 100644 --- a/proxy/src/scram/messages.rs +++ b/proxy/src/scram/messages.rs @@ -8,7 +8,7 @@ use std::fmt; use std::ops::Range; /// Faithfully taken from PostgreSQL. -pub const SCRAM_RAW_NONCE_LEN: usize = 18; +pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18; /// Although we ignore all extensions, we still have to validate the message. fn validate_sasl_extensions<'a>(parts: impl Iterator) -> Option<()> { @@ -27,18 +27,18 @@ fn validate_sasl_extensions<'a>(parts: impl Iterator) -> Option< } #[derive(Debug)] -pub struct ClientFirstMessage<'a> { +pub(crate) struct ClientFirstMessage<'a> { /// `client-first-message-bare`. - pub bare: &'a str, + pub(crate) bare: &'a str, /// Channel binding mode. - pub cbind_flag: ChannelBinding<&'a str>, + pub(crate) cbind_flag: ChannelBinding<&'a str>, /// Client nonce. - pub nonce: &'a str, + pub(crate) nonce: &'a str, } impl<'a> ClientFirstMessage<'a> { // NB: FromStr doesn't work with lifetimes - pub fn parse(input: &'a str) -> Option { + pub(crate) fn parse(input: &'a str) -> Option { let mut parts = input.split(','); let cbind_flag = ChannelBinding::parse(parts.next()?)?; @@ -59,7 +59,7 @@ impl<'a> ClientFirstMessage<'a> { // https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14 if !username.is_empty() { - tracing::warn!(username, "scram username provided, but is not expected") + tracing::warn!(username, "scram username provided, but is not expected"); // TODO(conrad): // return None; } @@ -77,7 +77,7 @@ impl<'a> ClientFirstMessage<'a> { } /// Build a response to [`ClientFirstMessage`]. - pub fn build_server_first_message( + pub(crate) fn build_server_first_message( &self, nonce: &[u8; SCRAM_RAW_NONCE_LEN], salt_base64: &str, @@ -89,7 +89,7 @@ impl<'a> ClientFirstMessage<'a> { write!(&mut message, "r={}", self.nonce).unwrap(); base64::encode_config_buf(nonce, base64::STANDARD, &mut message); let combined_nonce = 2..message.len(); - write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap(); + write!(&mut message, ",s={salt_base64},i={iterations}").unwrap(); // This design guarantees that it's impossible to create a // server-first-message without receiving a client-first-message @@ -101,20 +101,20 @@ impl<'a> ClientFirstMessage<'a> { } #[derive(Debug)] -pub struct ClientFinalMessage<'a> { +pub(crate) struct ClientFinalMessage<'a> { /// `client-final-message-without-proof`. - pub without_proof: &'a str, + pub(crate) without_proof: &'a str, /// Channel binding data (base64). - pub channel_binding: &'a str, + pub(crate) channel_binding: &'a str, /// Combined client & server nonce. - pub nonce: &'a str, + pub(crate) nonce: &'a str, /// Client auth proof. 
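`build_server_first_message` assembles the SCRAM server-first-message as `r=<client nonce><server nonce>,s=<salt>,i=<iterations>`, remembering where the combined nonce sits inside the string. A tiny sketch of just the formatting step, assuming the server nonce and salt are already base64 encoded; the concrete values below are illustrative (borrowed from the RFC examples), not from the proxy:

```rust
fn build_server_first_message(
    client_nonce: &str,
    server_nonce_b64: &str,
    salt_b64: &str,
    iterations: u32,
) -> String {
    format!("r={client_nonce}{server_nonce_b64},s={salt_b64},i={iterations}")
}

fn main() {
    let msg = build_server_first_message(
        "rOprNGfwEbeRWgbNEkqO",
        "Gy0NVjPO",
        "W22ZaJ0SNY7soEsUEjb6gQ==",
        4096,
    );
    assert_eq!(
        msg,
        "r=rOprNGfwEbeRWgbNEkqOGy0NVjPO,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"
    );
}
```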
- pub proof: [u8; SCRAM_KEY_LEN], + pub(crate) proof: [u8; SCRAM_KEY_LEN], } impl<'a> ClientFinalMessage<'a> { // NB: FromStr doesn't work with lifetimes - pub fn parse(input: &'a str) -> Option { + pub(crate) fn parse(input: &'a str) -> Option { let (without_proof, proof) = input.rsplit_once(',')?; let mut parts = without_proof.split(','); @@ -135,9 +135,9 @@ impl<'a> ClientFinalMessage<'a> { } /// Build a response to [`ClientFinalMessage`]. - pub fn build_server_final_message( + pub(crate) fn build_server_final_message( &self, - signature_builder: SignatureBuilder, + signature_builder: SignatureBuilder<'_>, server_key: &ScramKey, ) -> String { let mut buf = String::from("v="); @@ -153,7 +153,7 @@ impl<'a> ClientFinalMessage<'a> { /// We need to keep a convenient representation of this /// message for the next authentication step. -pub struct OwnedServerFirstMessage { +pub(crate) struct OwnedServerFirstMessage { /// Owned `server-first-message`. message: String, /// Slice into `message`. @@ -163,13 +163,13 @@ pub struct OwnedServerFirstMessage { impl OwnedServerFirstMessage { /// Extract combined nonce from the message. #[inline(always)] - pub fn nonce(&self) -> &str { + pub(crate) fn nonce(&self) -> &str { &self.message[self.nonce.clone()] } /// Get reference to a text representation of the message. #[inline(always)] - pub fn as_str(&self) -> &str { + pub(crate) fn as_str(&self) -> &str { &self.message } } @@ -212,7 +212,7 @@ mod tests { #[test] fn parse_client_first_message_with_invalid_gs2_authz() { - assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none()) + assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none()); } #[test] diff --git a/proxy/src/scram/pbkdf2.rs b/proxy/src/scram/pbkdf2.rs index a803ba7e1b..d5ed9002ad 100644 --- a/proxy/src/scram/pbkdf2.rs +++ b/proxy/src/scram/pbkdf2.rs @@ -4,7 +4,7 @@ use hmac::{ }; use sha2::Sha256; -pub struct Pbkdf2 { +pub(crate) struct Pbkdf2 { hmac: Hmac, prev: GenericArray, hi: GenericArray, @@ -13,7 +13,7 @@ pub struct Pbkdf2 { // inspired from impl Pbkdf2 { - pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { + pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { let hmac = Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); @@ -33,11 +33,11 @@ impl Pbkdf2 { } } - pub fn cost(&self) -> u32 { + pub(crate) fn cost(&self) -> u32 { (self.iterations).clamp(0, 4096) } - pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> { + pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> { let Self { hmac, prev, @@ -84,6 +84,6 @@ mod tests { }; let expected = pbkdf2_hmac_array::(pass, salt, 600000); - assert_eq!(hash, expected) + assert_eq!(hash, expected); } } diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 44c4f9e44a..8c6a08d432 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -8,22 +8,22 @@ use super::key::ScramKey; /// Server secret is produced from user's password, /// and is used throughout the authentication process. #[derive(Clone, Eq, PartialEq, Debug)] -pub struct ServerSecret { +pub(crate) struct ServerSecret { /// Number of iterations for `PBKDF2` function. - pub iterations: u32, + pub(crate) iterations: u32, /// Salt used to hash user's password. - pub salt_base64: String, + pub(crate) salt_base64: String, /// Hashed `ClientKey`. - pub stored_key: ScramKey, + pub(crate) stored_key: ScramKey, /// Used by client to verify server's signature. 
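The `Pbkdf2` helper above computes PBKDF2-HMAC-SHA256 incrementally so a worker thread can yield between `turn()` slices. Written as the plain single-block loop that it slices up, the computation looks roughly like this; a sketch assuming the `hmac = "0.12"` and `sha2 = "0.10"` crates, not the proxy's actual code:

```rust
use hmac::{Hmac, Mac};
use sha2::Sha256;

fn pbkdf2_sha256_block1(password: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
    let prf = Hmac::<Sha256>::new_from_slice(password).expect("HMAC accepts any key size");

    // U1 = PRF(password, salt || INT_32_BE(1))
    let mut mac = prf.clone();
    mac.update(salt);
    mac.update(&1u32.to_be_bytes());
    let mut u = mac.finalize().into_bytes();

    let mut out = [0u8; 32];
    out.copy_from_slice(&u);

    // U_j = PRF(password, U_{j-1});  result = U_1 xor U_2 xor ... xor U_c
    for _ in 1..iterations {
        let mut mac = prf.clone();
        mac.update(&u);
        u = mac.finalize().into_bytes();
        for (o, b) in out.iter_mut().zip(u.iter()) {
            *o ^= b;
        }
    }
    out
}

fn main() {
    let hash = pbkdf2_sha256_block1(b"password", b"salt", 4096);
    println!("{hash:02x?}");
}
```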
- pub server_key: ScramKey, + pub(crate) server_key: ScramKey, /// Should auth fail no matter what? /// This is exactly the case for mocked secrets. - pub doomed: bool, + pub(crate) doomed: bool, } impl ServerSecret { - pub fn parse(input: &str) -> Option { + pub(crate) fn parse(input: &str) -> Option { // SCRAM-SHA-256$:$: let s = input.strip_prefix("SCRAM-SHA-256$")?; let (params, keys) = s.split_once('$')?; @@ -42,7 +42,7 @@ impl ServerSecret { Some(secret) } - pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice { + pub(crate) fn is_password_invalid(&self, client_key: &ScramKey) -> Choice { // constant time to not leak partial key match client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8) } @@ -50,7 +50,7 @@ impl ServerSecret { /// To avoid revealing information to an attacker, we use a /// mocked server secret even if the user doesn't exist. /// See `auth-scram.c : mock_scram_secret` for details. - pub fn mock(nonce: [u8; 32]) -> Self { + pub(crate) fn mock(nonce: [u8; 32]) -> Self { Self { // this doesn't reveal much information as we're going to use // iteration count 1 for our generated passwords going forward. @@ -66,7 +66,7 @@ impl ServerSecret { /// Build a new server secret from the prerequisites. /// XXX: We only use this function in tests. #[cfg(test)] - pub async fn build(password: &str) -> Option { + pub(crate) async fn build(password: &str) -> Option { Self::parse(&postgres_protocol::password::scram_sha_256(password.as_bytes()).await) } } @@ -82,13 +82,7 @@ mod tests { let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns="; let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI="; - let secret = format!( - "SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}", - iterations = iterations, - salt = salt, - stored_key = stored_key, - server_key = server_key, - ); + let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}"); let parsed = ServerSecret::parse(&secret).unwrap(); assert_eq!(parsed.iterations, iterations); diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs index 1c2811d757..d3255cf2ca 100644 --- a/proxy/src/scram/signature.rs +++ b/proxy/src/scram/signature.rs @@ -4,14 +4,14 @@ use super::key::{ScramKey, SCRAM_KEY_LEN}; /// A collection of message parts needed to derive the client's signature. #[derive(Debug)] -pub struct SignatureBuilder<'a> { - pub client_first_message_bare: &'a str, - pub server_first_message: &'a str, - pub client_final_message_without_proof: &'a str, +pub(crate) struct SignatureBuilder<'a> { + pub(crate) client_first_message_bare: &'a str, + pub(crate) server_first_message: &'a str, + pub(crate) client_final_message_without_proof: &'a str, } impl SignatureBuilder<'_> { - pub fn build(&self, key: &ScramKey) -> Signature { + pub(crate) fn build(&self, key: &ScramKey) -> Signature { let parts = [ self.client_first_message_bare.as_bytes(), b",", @@ -28,13 +28,13 @@ impl SignatureBuilder<'_> { /// produces `ClientKey` that we need for authentication. #[derive(Debug)] #[repr(transparent)] -pub struct Signature { +pub(crate) struct Signature { bytes: [u8; SCRAM_KEY_LEN], } impl Signature { /// Derive `ClientKey` from client's signature and proof. - pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey { + pub(crate) fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey { // This is how the proof is calculated: // // 1. 
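`ServerSecret::parse` splits the stored SCRAM verifier, which has the same shape PostgreSQL keeps in `pg_authid.rolpassword`: `SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>`. A sketch of that split that keeps the keys as base64 strings; the key values below come from the test above, while the iteration count and salt are illustrative:

```rust
struct ServerSecret<'a> {
    iterations: u32,
    salt_base64: &'a str,
    stored_key_base64: &'a str,
    server_key_base64: &'a str,
}

fn parse(input: &str) -> Option<ServerSecret<'_>> {
    let s = input.strip_prefix("SCRAM-SHA-256$")?;
    let (params, keys) = s.split_once('$')?;
    let (iterations, salt_base64) = params.split_once(':')?;
    let (stored_key_base64, server_key_base64) = keys.split_once(':')?;
    Some(ServerSecret {
        iterations: iterations.parse().ok()?,
        salt_base64,
        stored_key_base64,
        server_key_base64,
    })
}

fn main() {
    let raw = "SCRAM-SHA-256$4096:QSXCR+Q6sek8bf92$D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=:Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
    let secret = parse(raw).unwrap();
    assert_eq!(secret.iterations, 4096);
    assert_eq!(secret.salt_base64, "QSXCR+Q6sek8bf92");
    assert!(secret.stored_key_base64.ends_with('='));
    assert!(secret.server_key_base64.ends_with('='));
}
```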
sha256(ClientKey) -> StoredKey diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index 7701b869a3..262c6d146e 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -68,7 +68,7 @@ impl ThreadPool { pool } - pub fn spawn_job( + pub(crate) fn spawn_job( &self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2, @@ -222,12 +222,11 @@ fn thread_rt(pool: Arc, worker: Worker, index: usize) { } for i in 0.. { - let mut job = match worker + let Some(mut job) = worker .pop() .or_else(|| pool.steal(&mut rng, index, &worker)) - { - Some(job) => job, - None => continue 'wait, + else { + continue 'wait; }; pool.metrics @@ -270,7 +269,7 @@ fn thread_rt(pool: Arc, worker: Worker, index: usize) { .inc(ThreadPoolWorkerId(index)); // skip for now - worker.push(job) + worker.push(job); } } @@ -316,6 +315,6 @@ mod tests { 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242, 178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140, ]; - assert_eq!(actual, expected) + assert_eq!(actual, expected); } } diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 115bef7375..84f98cb8ad 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -10,6 +10,7 @@ mod json; mod sql_over_http; mod websocket; +use async_trait::async_trait; use atomic_take::AtomicTake; use bytes::Bytes; pub use conn_pool::GlobalConnPoolOptions; @@ -24,10 +25,9 @@ use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder; use rand::rngs::StdRng; use rand::SeedableRng; -pub use reqwest_middleware::{ClientWithMiddleware, Error}; -pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::timeout; -use tokio_rustls::{server::TlsStream, TlsAcceptor}; +use tokio_rustls::TlsAcceptor; use tokio_util::task::TaskTracker; use crate::cancellation::CancellationHandlerMain; @@ -41,14 +41,14 @@ use crate::serverless::backend::PoolingBackend; use crate::serverless::http_util::{api_error_into_response, json_response}; use std::net::{IpAddr, SocketAddr}; -use std::pin::pin; +use std::pin::{pin, Pin}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio_util::sync::CancellationToken; use tracing::{error, info, warn, Instrument}; use utils::http::error::ApiError; -pub const SERVERLESS_DRIVER_SNI: &str = "api"; +pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub async fn task_main( config: &'static ProxyConfig, @@ -86,18 +86,18 @@ pub async fn task_main( config, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), }); - - let tls_config = match config.tls_config.as_ref() { - Some(config) => config, + let tls_acceptor: Arc = match config.tls_config.as_ref() { + Some(config) => { + let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config()); + // prefer http2, but support http/1.1 + tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + Arc::new(tls_server_config) + } None => { - warn!("TLS config is missing, WebSocket Secure server will not be started"); - return Ok(()); + warn!("TLS config is missing"); + Arc::new(NoTls) } }; - let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config()); - // prefer http2, but support http/1.1 - tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into(); let connections = tokio_util::task::task_tracker::TaskTracker::new(); 
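The reworked `thread_rt` loop above is a pop-or-steal scheduler: take local work first, otherwise try to steal from a sibling worker. The `Worker`/steal shape matches crossbeam-deque's API, so a standalone sketch assuming that crate (`crossbeam-deque = "0.8"`) looks like this:

```rust
use crossbeam_deque::{Steal, Worker};

fn main() {
    let local: Worker<u32> = Worker::new_fifo();
    let other: Worker<u32> = Worker::new_fifo();
    let other_stealer = other.stealer();

    other.push(42);

    // Prefer local work; fall back to stealing from a sibling worker,
    // retrying on transient contention.
    let job = local.pop().or_else(|| loop {
        match other_stealer.steal() {
            Steal::Success(job) => break Some(job),
            Steal::Empty => break None,
            Steal::Retry => continue,
        }
    });
    assert_eq!(job, Some(42));
}
```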
connections.close(); // allows `connections.wait to complete` @@ -120,7 +120,7 @@ pub async fn task_main( tracing::trace!("attempting to cancel a random connection"); if let Some(token) = config.http_config.cancel_set.take() { tracing::debug!("cancelling a random connection"); - token.cancel() + token.cancel(); } } @@ -176,16 +176,41 @@ pub async fn task_main( Ok(()) } +pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {} +impl AsyncReadWrite for T {} +pub(crate) type AsyncRW = Pin>; + +#[async_trait] +trait MaybeTlsAcceptor: Send + Sync + 'static { + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result; +} + +#[async_trait] +impl MaybeTlsAcceptor for rustls::ServerConfig { + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result { + Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?)) + } +} + +struct NoTls; + +#[async_trait] +impl MaybeTlsAcceptor for NoTls { + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result { + Ok(Box::pin(conn)) + } +} + /// Handles the TCP startup lifecycle. /// 1. Parses PROXY protocol V2 /// 2. Handles TLS handshake async fn connection_startup( config: &ProxyConfig, - tls_acceptor: TlsAcceptor, + tls_acceptor: Arc, session_id: uuid::Uuid, conn: TcpStream, peer_addr: SocketAddr, -) -> Option<(TlsStream>, IpAddr)> { +) -> Option<(AsyncRW, IpAddr)> { // handle PROXY protocol let (conn, peer) = match read_proxy_protocol(conn).await { Ok(c) => c, @@ -198,7 +223,7 @@ async fn connection_startup( let peer_addr = peer.unwrap_or(peer_addr).ip(); let has_private_peer_addr = match peer_addr { IpAddr::V4(ip) => ip.is_private(), - _ => false, + IpAddr::V6(_) => false, }; info!(?session_id, %peer_addr, "accepted new TCP connection"); @@ -241,7 +266,7 @@ async fn connection_handler( cancellation_handler: Arc, endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, - conn: TlsStream>, + conn: AsyncRW, peer_addr: IpAddr, session_id: uuid::Uuid, ) { @@ -326,7 +351,9 @@ async fn request_handler( .map(|s| s.to_string()); // Check if the request is a websocket upgrade request. 
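`MaybeTlsAcceptor` erases "TLS or plaintext" behind one boxed stream type so the HTTP server code no longer cares which one it got. A synchronous analogue with std traits (the real trait is async and hands back a pinned `AsyncRead + AsyncWrite` box):

```rust
use std::io::{Read, Result, Write};
use std::net::{TcpListener, TcpStream};

/// Anything we can treat as a byte stream after (optional) TLS setup.
trait ReadWrite: Read + Write {}
impl<T: Read + Write> ReadWrite for T {}
type BoxedStream = Box<dyn ReadWrite>;

trait MaybeTlsAcceptor {
    fn accept(&self, conn: TcpStream) -> Result<BoxedStream>;
}

/// Plaintext fallback: hand the TCP stream straight back.
/// A TLS-backed acceptor would wrap `conn` in a TLS session here instead.
struct NoTls;

impl MaybeTlsAcceptor for NoTls {
    fn accept(&self, conn: TcpStream) -> Result<BoxedStream> {
        Ok(Box::new(conn))
    }
}

fn main() -> Result<()> {
    // Loopback pair just to exercise the abstraction end to end.
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let addr = listener.local_addr()?;
    let _client = TcpStream::connect(addr)?;
    let (server_side, _) = listener.accept()?;

    let acceptor: Box<dyn MaybeTlsAcceptor> = Box::new(NoTls);
    let mut stream: BoxedStream = acceptor.accept(server_side)?;
    stream.write_all(b"hello")?;
    Ok(())
}
```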
- if framed_websockets::upgrade::is_upgrade_request(&request) { + if config.http_config.accept_websockets + && framed_websockets::upgrade::is_upgrade_request(&request) + { let ctx = RequestMonitoring::new( session_id, peer_addr, @@ -378,7 +405,7 @@ async fn request_handler( .header("Access-Control-Allow-Origin", "*") .header( "Access-Control-Allow-Headers", - "Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level", + "Authorization, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level", ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 295ea1a1c7..f24e0478be 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -4,7 +4,10 @@ use async_trait::async_trait; use tracing::{field::display, info}; use crate::{ - auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, + auth::{ + backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo}, + check_peer_addr_is_in_list, AuthError, + }, compute, config::{AuthenticationConfig, ProxyConfig}, console::{ @@ -24,30 +27,35 @@ use crate::{ Host, }; -use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; +use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool}; -pub struct PoolingBackend { - pub pool: Arc>, - pub config: &'static ProxyConfig, - pub endpoint_rate_limiter: Arc, +pub(crate) struct PoolingBackend { + pub(crate) pool: Arc>, + pub(crate) config: &'static ProxyConfig, + pub(crate) endpoint_rate_limiter: Arc, } impl PoolingBackend { - pub async fn authenticate( + pub(crate) async fn authenticate_with_password( &self, ctx: &RequestMonitoring, config: &AuthenticationConfig, - conn_info: &ConnInfo, + user_info: &ComputeUserInfo, + password: &[u8], ) -> Result { - let user_info = conn_info.user_info.clone(); - let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); + let user_info = user_info.clone(); + let backend = self + .config + .auth_backend + .as_ref() + .map(|()| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); } if !self .endpoint_rate_limiter - .check(conn_info.user_info.endpoint.clone().into(), 1) + .check(user_info.endpoint.clone().into(), 1) { return Err(AuthError::too_many_connections()); } @@ -70,14 +78,10 @@ impl PoolingBackend { return Err(AuthError::auth_failed(&*user_info.user)); } }; - let ep = EndpointIdInt::from(&conn_info.user_info.endpoint); - let auth_outcome = crate::auth::validate_password_and_exchange( - &config.thread_pool, - ep, - &conn_info.password, - secret, - ) - .await?; + let ep = EndpointIdInt::from(&user_info.endpoint); + let auth_outcome = + crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret) + .await?; let res = match auth_outcome { crate::sasl::Outcome::Success(key) => { info!("user successfully authenticated"); @@ -85,7 +89,7 @@ impl PoolingBackend { } crate::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); - Err(AuthError::auth_failed(&*conn_info.user_info.user)) 
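The WebSocket branch is now gated on the new `accept_websockets` config flag before the upgrade headers are even inspected. A sketch of that gate using the `http` crate's header types; in the real code `framed_websockets::upgrade::is_upgrade_request` does the header inspection, so the check below is only a simplified stand-in:

```rust
use http::{header, HeaderMap, HeaderValue};

fn is_websocket_upgrade(headers: &HeaderMap<HeaderValue>) -> bool {
    // `Connection` may carry a list, e.g. "keep-alive, Upgrade".
    let connection_has_upgrade = headers
        .get(header::CONNECTION)
        .and_then(|v| v.to_str().ok())
        .map_or(false, |v| v.to_ascii_lowercase().contains("upgrade"));
    let upgrade_is_websocket = headers
        .get(header::UPGRADE)
        .and_then(|v| v.to_str().ok())
        .map_or(false, |v| v.eq_ignore_ascii_case("websocket"));
    connection_has_upgrade && upgrade_is_websocket
}

fn should_handle_websocket(accept_websockets: bool, headers: &HeaderMap<HeaderValue>) -> bool {
    // Config wins: a deployment without WebSockets never looks at the headers.
    accept_websockets && is_websocket_upgrade(headers)
}

fn main() {
    let mut headers = HeaderMap::new();
    headers.insert(header::CONNECTION, HeaderValue::from_static("Upgrade"));
    headers.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
    assert!(should_handle_websocket(true, &headers));
    assert!(!should_handle_websocket(false, &headers));
}
```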
+ Err(AuthError::auth_failed(&*user_info.user)) } }; res.map(|key| ComputeCredentials { @@ -94,23 +98,56 @@ impl PoolingBackend { }) } + pub(crate) async fn authenticate_with_jwt( + &self, + ctx: &RequestMonitoring, + user_info: &ComputeUserInfo, + jwt: &str, + ) -> Result { + match &self.config.auth_backend { + crate::auth::Backend::Console(_, ()) => { + Err(AuthError::auth_failed("JWT login is not yet supported")) + } + crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed( + "JWT login over web auth proxy is not supported", + )), + crate::auth::Backend::Local(cache) => { + cache + .jwks_cache + .check_jwt( + ctx, + user_info.endpoint.clone(), + user_info.user.clone(), + &StaticAuthRules, + jwt, + ) + .await + .map_err(|e| AuthError::auth_failed(e.to_string()))?; + Ok(ComputeCredentials { + info: user_info.clone(), + keys: crate::auth::backend::ComputeCredentialKeys::None, + }) + } + } + } + // Wake up the destination if needed. Code here is a bit involved because // we reuse the code from the usual proxy and we need to prepare few structures // that this code expects. #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] - pub async fn connect_to_compute( + pub(crate) async fn connect_to_compute( &self, ctx: &RequestMonitoring, conn_info: ConnInfo, keys: ComputeCredentials, force_new: bool, ) -> Result, HttpConnError> { - let maybe_client = if !force_new { - info!("pool: looking for an existing connection"); - self.pool.get(ctx, &conn_info)? - } else { + let maybe_client = if force_new { info!("pool: pool is disabled"); None + } else { + info!("pool: looking for an existing connection"); + self.pool.get(ctx, &conn_info)? }; if let Some(client) = maybe_client { @@ -119,7 +156,7 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.config.auth_backend.as_ref().map(|_| keys); + let backend = self.config.auth_backend.as_ref().map(|()| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -138,7 +175,7 @@ impl PoolingBackend { } #[derive(Debug, thiserror::Error)] -pub enum HttpConnError { +pub(crate) enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), #[error("could not connection to compute")] @@ -232,10 +269,16 @@ impl ConnectMechanism for TokioMechanism { let mut config = (*node_info.config).clone(); let config = config .user(&self.conn_info.user_info.user) - .password(&*self.conn_info.password) .dbname(&self.conn_info.dbname) .connect_timeout(timeout); + match &self.conn_info.auth { + AuthData::Jwt(_) => {} + AuthData::Password(pw) => { + config.password(pw); + } + } + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let res = config.connect(tokio_postgres::NoTls).await; drop(pause); diff --git a/proxy/src/serverless/cancel_set.rs b/proxy/src/serverless/cancel_set.rs index 390df7f4f7..7659745473 100644 --- a/proxy/src/serverless/cancel_set.rs +++ b/proxy/src/serverless/cancel_set.rs @@ -22,7 +22,7 @@ pub struct CancelSet { hasher: Hasher, } -pub struct CancelShard { +pub(crate) struct CancelShard { tokens: IndexMap, } @@ -40,7 +40,7 @@ impl CancelSet { } } - pub fn take(&self) -> Option { + pub(crate) fn take(&self) -> Option { for _ in 0..4 { if let Some(token) = self.take_raw(thread_rng().gen()) { return Some(token); @@ -50,12 +50,12 @@ impl CancelSet { None } 
- pub fn take_raw(&self, rng: usize) -> Option { + pub(crate) fn take_raw(&self, rng: usize) -> Option { NonZeroUsize::new(self.shards.len()) .and_then(|len| self.shards[rng % len].lock().take(rng / len)) } - pub fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> { + pub(crate) fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> { let shard = NonZeroUsize::new(self.shards.len()).map(|len| { let hash = self.hasher.hash_one(id) as usize; let shard = &self.shards[hash % len]; @@ -88,7 +88,7 @@ impl CancelShard { } } -pub struct CancelGuard<'a> { +pub(crate) struct CancelGuard<'a> { shard: Option<&'a Mutex>, id: Uuid, } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index e1dc44dc1c..bea599e9b9 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -30,19 +30,25 @@ use tracing::{info, info_span, Instrument}; use super::backend::HttpConnError; #[derive(Debug, Clone)] -pub struct ConnInfo { - pub user_info: ComputeUserInfo, - pub dbname: DbName, - pub password: SmallVec<[u8; 16]>, +pub(crate) struct ConnInfo { + pub(crate) user_info: ComputeUserInfo, + pub(crate) dbname: DbName, + pub(crate) auth: AuthData, +} + +#[derive(Debug, Clone)] +pub(crate) enum AuthData { + Password(SmallVec<[u8; 16]>), + Jwt(String), } impl ConnInfo { // hm, change to hasher to avoid cloning? - pub fn db_and_user(&self) -> (DbName, RoleName) { + pub(crate) fn db_and_user(&self) -> (DbName, RoleName) { (self.dbname.clone(), self.user_info.user.clone()) } - pub fn endpoint_cache_key(&self) -> Option { + pub(crate) fn endpoint_cache_key(&self) -> Option { // We don't want to cache http connections for ephemeral endpoints. if self.user_info.options.is_ephemeral() { None @@ -73,7 +79,7 @@ struct ConnPoolEntry { // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. 
-pub struct EndpointConnPool { +pub(crate) struct EndpointConnPool { pools: HashMap<(DbName, RoleName), DbUserConnPool>, total_conns: usize, max_conns: usize, @@ -192,7 +198,7 @@ impl Drop for EndpointConnPool { } } -pub struct DbUserConnPool { +pub(crate) struct DbUserConnPool { conns: Vec>, } @@ -235,7 +241,7 @@ impl DbUserConnPool { } } -pub struct GlobalConnPool { +pub(crate) struct GlobalConnPool { // endpoint -> per-endpoint connection pool // // That should be a fairly conteded map, so return reference to the per-endpoint @@ -276,7 +282,7 @@ pub struct GlobalConnPoolOptions { } impl GlobalConnPool { - pub fn new(config: &'static crate::config::HttpConfig) -> Arc { + pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { let shards = config.pool_options.pool_shards; Arc::new(Self { global_pool: DashMap::with_shard_amount(shards), @@ -287,21 +293,21 @@ impl GlobalConnPool { } #[cfg(test)] - pub fn get_global_connections_count(&self) -> usize { + pub(crate) fn get_global_connections_count(&self) -> usize { self.global_connections_count .load(atomic::Ordering::Relaxed) } - pub fn get_idle_timeout(&self) -> Duration { + pub(crate) fn get_idle_timeout(&self) -> Duration { self.config.pool_options.idle_timeout } - pub fn shutdown(&self) { + pub(crate) fn shutdown(&self) { // drops all strong references to endpoint-pools self.global_pool.clear(); } - pub async fn gc_worker(&self, mut rng: impl Rng) { + pub(crate) async fn gc_worker(&self, mut rng: impl Rng) { let epoch = self.config.pool_options.gc_epoch; let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32); loop { @@ -333,9 +339,9 @@ impl GlobalConnPool { } = pool.get_mut(); // ensure that closed clients are removed - pools.iter_mut().for_each(|(_, db_pool)| { + for db_pool in pools.values_mut() { clients_removed += db_pool.clear_closed_clients(total_conns); - }); + } // we only remove this pool if it has no active connections if *total_conns == 0 { @@ -375,7 +381,7 @@ impl GlobalConnPool { } } - pub fn get( + pub(crate) fn get( self: &Arc, ctx: &RequestMonitoring, conn_info: &ConnInfo, @@ -390,7 +396,7 @@ impl GlobalConnPool { .write() .get_conn_entry(conn_info.db_and_user()) { - client = Some(entry.conn) + client = Some(entry.conn); } let endpoint_pool = Arc::downgrade(&endpoint_pool); @@ -399,21 +405,20 @@ impl GlobalConnPool { if client.is_closed() { info!("pool: cached connection '{conn_info}' is closed, opening a new one"); return Ok(None); - } else { - tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); - tracing::Span::current().record( - "pid", - tracing::field::display(client.inner.get_process_id()), - ); - info!( - cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), - "pool: reusing connection '{conn_info}'" - ); - client.session.send(ctx.session_id())?; - ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); - ctx.success(); - return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); } + tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); + tracing::Span::current().record( + "pid", + tracing::field::display(client.inner.get_process_id()), + ); + info!( + cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), + "pool: reusing connection '{conn_info}'" + ); + client.session.send(ctx.session_id())?; + ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); + ctx.success(); + return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); } Ok(None) } @@ -463,7 +468,7 @@ impl GlobalConnPool { 
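For context on `EndpointConnPool::get`: connections are grouped per (database, role), and `get` discards connections that closed while idle until it finds a live one. A much-simplified sketch of that shape, without the max-conns accounting, idle timeouts, sharded global map, or GC:

```rust
use std::collections::HashMap;

struct Conn { id: u32, closed: bool }

#[derive(Default)]
struct EndpointConnPool {
    pools: HashMap<(String, String), Vec<Conn>>,
    total_conns: usize,
}

impl EndpointConnPool {
    fn put(&mut self, db: &str, user: &str, conn: Conn) {
        self.pools.entry((db.into(), user.into())).or_default().push(conn);
        self.total_conns += 1;
    }

    fn get(&mut self, db: &str, user: &str) -> Option<Conn> {
        let conns = self.pools.get_mut(&(db.to_owned(), user.to_owned()))?;
        while let Some(conn) = conns.pop() {
            self.total_conns -= 1;
            if !conn.closed {
                return Some(conn); // reuse this one
            }
            // Closed while idle: drop it and keep looking.
        }
        None
    }
}

fn main() {
    let mut pool = EndpointConnPool::default();
    pool.put("neondb", "api_user", Conn { id: 1, closed: true });
    pool.put("neondb", "api_user", Conn { id: 2, closed: false });
    assert_eq!(pool.get("neondb", "api_user").map(|c| c.id), Some(2));
    assert_eq!(pool.get("neondb", "api_user").map(|c| c.id), None);
}
```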
} } -pub fn poll_client( +pub(crate) fn poll_client( global_pool: Arc>, ctx: &RequestMonitoring, conn_info: ConnInfo, @@ -591,7 +596,7 @@ impl Drop for ClientInner { } } -pub trait ClientInnerExt: Sync + Send + 'static { +pub(crate) trait ClientInnerExt: Sync + Send + 'static { fn is_closed(&self) -> bool; fn get_process_id(&self) -> i32; } @@ -606,13 +611,13 @@ impl ClientInnerExt for tokio_postgres::Client { } impl ClientInner { - pub fn is_closed(&self) -> bool { + pub(crate) fn is_closed(&self) -> bool { self.inner.is_closed() } } impl Client { - pub fn metrics(&self) -> Arc { + pub(crate) fn metrics(&self) -> Arc { let aux = &self.inner.as_ref().unwrap().aux; USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id, @@ -621,14 +626,14 @@ impl Client { } } -pub struct Client { +pub(crate) struct Client { span: Span, inner: Option>, conn_info: ConnInfo, pool: Weak>>, } -pub struct Discard<'a, C: ClientInnerExt> { +pub(crate) struct Discard<'a, C: ClientInnerExt> { conn_info: &'a ConnInfo, pool: &'a mut Weak>>, } @@ -646,7 +651,7 @@ impl Client { pool, } } - pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) { + pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) { let Self { inner, pool, @@ -654,21 +659,21 @@ impl Client { span: _, } = self; let inner = inner.as_mut().expect("client inner should not be removed"); - (&mut inner.inner, Discard { pool, conn_info }) + (&mut inner.inner, Discard { conn_info, pool }) } } impl Discard<'_, C> { - pub fn check_idle(&mut self, status: ReadyForQueryStatus) { + pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { let conn_info = &self.conn_info; if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!("pool: throwing away connection '{conn_info}' because connection is not idle") + info!("pool: throwing away connection '{conn_info}' because connection is not idle"); } } - pub fn discard(&mut self) { + pub(crate) fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { - info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state") + info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); } } } @@ -716,7 +721,9 @@ impl Drop for Client { mod tests { use std::{mem, sync::atomic::AtomicBool}; - use crate::{serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId}; + use crate::{ + proxy::NeonOptions, serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId, + }; use super::*; @@ -758,6 +765,7 @@ mod tests { async fn test_pool() { let _ = env_logger::try_init(); let config = Box::leak(Box::new(crate::config::HttpConfig { + accept_websockets: false, pool_options: GlobalConnPoolOptions { max_conns_per_endpoint: 2, gc_epoch: Duration::from_secs(1), @@ -774,10 +782,10 @@ mod tests { user_info: ComputeUserInfo { user: "user".into(), endpoint: "endpoint".into(), - options: Default::default(), + options: NeonOptions::default(), }, dbname: "dbname".into(), - password: "password".as_bytes().into(), + auth: AuthData::Password("password".as_bytes().into()), }; let ep_pool = Arc::downgrade( &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()), @@ -832,10 +840,10 @@ mod tests { user_info: ComputeUserInfo { user: "user".into(), endpoint: "endpoint-2".into(), - options: Default::default(), + options: NeonOptions::default(), }, dbname: "dbname".into(), - password: "password".as_bytes().into(), + auth: 
AuthData::Password("password".as_bytes().into()), }; let ep_pool = Arc::downgrade( &pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()), diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index 701ab58f63..abf0ffe290 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -11,7 +11,7 @@ use serde::Serialize; use utils::http::error::ApiError; /// Like [`ApiError::into_response`] -pub fn api_error_into_response(this: ApiError) -> Response> { +pub(crate) fn api_error_into_response(this: ApiError) -> Response> { match this { ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status( format!("{err:#?}"), // use debug printing so that we give the cause @@ -59,7 +59,7 @@ pub fn api_error_into_response(this: ApiError) -> Response> { /// Same as [`utils::http::error::HttpErrorBody`] #[derive(Serialize)] struct HttpErrorBody { - pub msg: String, + pub(crate) msg: String, } impl HttpErrorBody { @@ -80,7 +80,7 @@ impl HttpErrorBody { } /// Same as [`utils::http::json::json_response`] -pub fn json_response( +pub(crate) fn json_response( status: StatusCode, data: T, ) -> Result>, ApiError> { diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index c22c63e85b..9f328a0e1d 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -8,7 +8,7 @@ use tokio_postgres::Row; // Convert json non-string types to strings, so that they can be passed to Postgres // as parameters. // -pub fn json_to_pg_text(json: Vec) -> Vec> { +pub(crate) fn json_to_pg_text(json: Vec) -> Vec> { json.iter().map(json_value_to_pg_text).collect() } @@ -55,13 +55,13 @@ fn json_array_to_pg_array(value: &Value) -> Option { .collect::>() .join(","); - Some(format!("{{{}}}", vals)) + Some(format!("{{{vals}}}")) } } } #[derive(Debug, thiserror::Error)] -pub enum JsonConversionError { +pub(crate) enum JsonConversionError { #[error("internal error compute returned invalid data: {0}")] AsTextError(tokio_postgres::Error), #[error("parse int error: {0}")] @@ -77,7 +77,7 @@ pub enum JsonConversionError { // // Convert postgres row with text-encoded values to JSON object // -pub fn pg_text_row_to_json( +pub(crate) fn pg_text_row_to_json( row: &Row, columns: &[Type], raw_output: bool, diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index c41df07a4d..5b36f5e91d 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -7,6 +7,7 @@ use futures::future::try_join; use futures::future::Either; use futures::StreamExt; use futures::TryFutureExt; +use http::header::AUTHORIZATION; use http_body_util::BodyExt; use http_body_util::Full; use hyper1::body::Body; @@ -56,6 +57,7 @@ use crate::DbName; use crate::RoleName; use super::backend::PoolingBackend; +use super::conn_pool::AuthData; use super::conn_pool::Client; use super::conn_pool::ConnInfo; use super::http_util::json_response; @@ -88,6 +90,7 @@ enum Payload { const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB +static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string"); static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); @@ -107,9 +110,9 @@ where } #[derive(Debug, thiserror::Error)] -pub enum ConnInfoError { 
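For context on the `json.rs` hunk above (`json_to_pg_text` / `json_array_to_pg_array`): JSON request parameters are converted to Postgres text-format parameters, with JSON arrays rendered as `{...}` array literals, which is what the `format!("{{{vals}}}")` change touches. A simplified sketch of that shape, with quoting/escaping of array elements deliberately omitted:

```rust
use serde_json::Value;

// Simplified model: JSON value -> optional Postgres text parameter.
fn to_pg_text(v: &Value) -> Option<String> {
    match v {
        Value::Null => None,
        Value::Bool(b) => Some(b.to_string()),
        Value::Number(n) => Some(n.to_string()),
        Value::String(s) => Some(s.clone()),
        Value::Array(items) => {
            let vals: Vec<String> = items
                .iter()
                .map(|i| to_pg_text(i).unwrap_or_else(|| "NULL".to_string()))
                .collect();
            // Postgres array literal: {a,b,c}
            Some(format!("{{{}}}", vals.join(",")))
        }
        Value::Object(_) => Some(v.to_string()),
    }
}

fn main() {
    let params = serde_json::json!([42, "hello", [1, 2, 3], null]);
    let texts: Vec<Option<String>> =
        params.as_array().unwrap().iter().map(to_pg_text).collect();
    assert_eq!(texts[2].as_deref(), Some("{1,2,3}"));
    assert_eq!(texts[3], None);
}
```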
+pub(crate) enum ConnInfoError { #[error("invalid header: {0}")] - InvalidHeader(&'static str), + InvalidHeader(&'static HeaderName), #[error("invalid connection string: {0}")] UrlParseError(#[from] url::ParseError), #[error("incorrect scheme")] @@ -147,16 +150,16 @@ impl UserFacingError for ConnInfoError { fn get_conn_info( ctx: &RequestMonitoring, headers: &HeaderMap, - tls: &TlsConfig, + tls: Option<&TlsConfig>, ) -> Result { // HTTP only uses cleartext (for now and likely always) ctx.set_auth_method(crate::context::AuthMethod::Cleartext); let connection_string = headers - .get("Neon-Connection-String") - .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))? + .get(&CONN_STRING) + .ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))? .to_str() - .map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?; + .map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?; let connection_url = Url::parse(connection_string)?; @@ -179,17 +182,40 @@ fn get_conn_info( } ctx.set_user(username.clone()); - let password = connection_url - .password() - .ok_or(ConnInfoError::MissingPassword)?; - let password = urlencoding::decode_binary(password.as_bytes()); + let auth = if let Some(auth) = headers.get(&AUTHORIZATION) { + let auth = auth + .to_str() + .map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?; + AuthData::Jwt( + auth.strip_prefix("Bearer ") + .ok_or(ConnInfoError::MissingPassword)? + .into(), + ) + } else if let Some(pass) = connection_url.password() { + AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) { + std::borrow::Cow::Borrowed(b) => b.into(), + std::borrow::Cow::Owned(b) => b.into(), + }) + } else { + return Err(ConnInfoError::MissingPassword); + }; - let hostname = connection_url - .host_str() - .ok_or(ConnInfoError::MissingHostname)?; - - let endpoint = - endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?; + let endpoint = match connection_url.host() { + Some(url::Host::Domain(hostname)) => { + if let Some(tls) = tls { + endpoint_sni(hostname, &tls.common_names)? + .ok_or(ConnInfoError::MalformedEndpoint)? 
+ } else { + hostname + .split_once('.') + .map_or(hostname, |(prefix, _)| prefix) + .into() + } + } + Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => { + return Err(ConnInfoError::MissingHostname) + } + }; ctx.set_endpoint_id(endpoint.clone()); let pairs = connection_url.query_pairs(); @@ -215,15 +241,12 @@ fn get_conn_info( Ok(ConnInfo { user_info, dbname, - password: match password { - std::borrow::Cow::Borrowed(b) => b.into(), - std::borrow::Cow::Owned(b) => b.into(), - }, + auth, }) } // TODO: return different http error codes -pub async fn handle( +pub(crate) async fn handle( config: &'static ProxyConfig, ctx: RequestMonitoring, request: Request, @@ -336,7 +359,7 @@ pub async fn handle( } #[derive(Debug, thiserror::Error)] -pub enum SqlOverHttpError { +pub(crate) enum SqlOverHttpError { #[error("{0}")] ReadPayload(#[from] ReadPayloadError), #[error("{0}")] @@ -390,7 +413,7 @@ impl UserFacingError for SqlOverHttpError { } #[derive(Debug, thiserror::Error)] -pub enum ReadPayloadError { +pub(crate) enum ReadPayloadError { #[error("could not read the HTTP request body: {0}")] Read(#[from] hyper1::Error), #[error("could not parse the HTTP request body: {0}")] @@ -407,7 +430,7 @@ impl ReportableError for ReadPayloadError { } #[derive(Debug, thiserror::Error)] -pub enum SqlOverHttpCancel { +pub(crate) enum SqlOverHttpCancel { #[error("query was cancelled")] Postgres, #[error("query was cancelled while stuck trying to connect to the database")] @@ -502,7 +525,7 @@ async fn handle_inner( let headers = request.headers(); // TLS config should be there. - let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?; + let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?; info!(user = conn_info.user_info.user.as_str(), "credentials"); // Allow connection pooling only if explicitly requested @@ -540,9 +563,24 @@ async fn handle_inner( let authenticate_and_connect = Box::pin( async { - let keys = backend - .authenticate(ctx, &config.authentication_config, &conn_info) - .await?; + let keys = match &conn_info.auth { + AuthData::Password(pw) => { + backend + .authenticate_with_password( + ctx, + &config.authentication_config, + &conn_info.user_info, + pw, + ) + .await? + } + AuthData::Jwt(jwt) => { + backend + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await? + } + }; + let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) .await?; diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 4fba4d141c..3d257223b8 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -27,7 +27,7 @@ use tracing::warn; pin_project! { /// This is a wrapper around a [`WebSocketStream`] that /// implements [`AsyncRead`] and [`AsyncWrite`]. - pub struct WebSocketRw { + pub(crate) struct WebSocketRw { #[pin] stream: WebSocketServer, recv: Bytes, @@ -36,7 +36,7 @@ pin_project! { } impl WebSocketRw { - pub fn new(stream: WebSocketServer) -> Self { + pub(crate) fn new(stream: WebSocketServer) -> Self { Self { stream, recv: Bytes::new(), @@ -127,7 +127,7 @@ impl AsyncBufRead for WebSocketRw { } } -pub async fn serve_websocket( +pub(crate) async fn serve_websocket( config: &'static ProxyConfig, ctx: RequestMonitoring, websocket: OnUpgrade, diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 690e92ffb1..332dc27787 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -35,7 +35,7 @@ impl PqStream { } /// Get a shared reference to the underlying stream. 
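The `get_conn_info` changes above add a second credential source and relax the TLS requirement: an `Authorization: Bearer <jwt>` header takes precedence over a password embedded in the `Neon-Connection-String` URL, and when no TLS config is present the endpoint id falls back to the first DNS label of the host. A sketch of both rules; `AuthData` here is a simplified stand-in for the type introduced in `conn_pool`, and URL-decoding of the password is omitted:

```rust
enum AuthData {
    Password(Vec<u8>),
    Jwt(String),
}

fn pick_auth(
    authorization_header: Option<&str>,
    url_password: Option<&str>,
) -> Result<AuthData, &'static str> {
    if let Some(header) = authorization_header {
        // The Authorization header wins over a password embedded in the URL.
        let jwt = header.strip_prefix("Bearer ").ok_or("missing password")?;
        Ok(AuthData::Jwt(jwt.to_string()))
    } else if let Some(pass) = url_password {
        Ok(AuthData::Password(pass.as_bytes().to_vec()))
    } else {
        Err("missing password")
    }
}

// Without a TLS config there is no SNI list to validate against, so the endpoint
// id is taken as the first DNS label of the connection-string host, e.g. an
// illustrative "ep-example-123456.region.provider.example" -> "ep-example-123456".
fn endpoint_from_hostname(hostname: &str) -> &str {
    hostname.split_once('.').map_or(hostname, |(prefix, _)| prefix)
}
```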
- pub fn get_ref(&self) -> &S { + pub(crate) fn get_ref(&self) -> &S { self.framed.get_ref() } } @@ -62,12 +62,12 @@ impl PqStream { .ok_or_else(err_connection) } - pub async fn read_password_message(&mut self) -> io::Result { + pub(crate) async fn read_password_message(&mut self) -> io::Result { match self.read_message().await? { FeMessage::PasswordMessage(msg) => Ok(msg), bad => Err(io::Error::new( io::ErrorKind::InvalidData, - format!("unexpected message type: {:?}", bad), + format!("unexpected message type: {bad:?}"), )), } } @@ -99,7 +99,10 @@ impl ReportableError for ReportedError { impl PqStream { /// Write the message into an internal buffer, but don't flush the underlying stream. - pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { + pub(crate) fn write_message_noflush( + &mut self, + message: &BeMessage<'_>, + ) -> io::Result<&mut Self> { self.framed .write_message(message) .map_err(ProtocolError::into_io_error)?; @@ -114,7 +117,7 @@ impl PqStream { } /// Flush the output buffer into the underlying stream. - pub async fn flush(&mut self) -> io::Result<&mut Self> { + pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> { self.framed.flush().await?; Ok(self) } @@ -146,7 +149,7 @@ impl PqStream { /// Write the error message using [`Self::write_message`], then re-throw it. /// Trait [`UserFacingError`] acts as an allowlist for error types. - pub async fn throw_error(&mut self, error: E) -> Result + pub(crate) async fn throw_error(&mut self, error: E) -> Result where E: UserFacingError + Into, { @@ -200,7 +203,7 @@ impl Stream { } } - pub fn tls_server_end_point(&self) -> TlsServerEndPoint { + pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint { match self { Stream::Raw { .. } => TlsServerEndPoint::Undefined, Stream::Tls { @@ -234,7 +237,7 @@ impl Stream { .await .inspect_err(|_| { if record_handshake_error { - Metrics::get().proxy.tls_handshake_failures.inc() + Metrics::get().proxy.tls_handshake_failures.inc(); } })?), Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls), diff --git a/proxy/src/url.rs b/proxy/src/url.rs index 92c64bb8ad..28ac7efdfc 100644 --- a/proxy/src/url.rs +++ b/proxy/src/url.rs @@ -7,12 +7,12 @@ pub struct ApiUrl(url::Url); impl ApiUrl { /// Consume the wrapper and return inner [url](url::Url). - pub fn into_inner(self) -> url::Url { + pub(crate) fn into_inner(self) -> url::Url { self.0 } /// See [`url::Url::path_segments_mut`]. - pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut { + pub(crate) fn path_segments_mut(&mut self) -> url::PathSegmentsMut<'_> { // We've already verified that it works during construction. self.0.path_segments_mut().expect("bad API url") } diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index a8735fe0bb..aa8c7ba319 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -43,12 +43,12 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60); /// so while the project-id is unique across regions the whole pipeline will work correctly /// because we enrich the event with project_id in the control-plane endpoint. 
#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)] -pub struct Ids { - pub endpoint_id: EndpointIdInt, - pub branch_id: BranchIdInt, +pub(crate) struct Ids { + pub(crate) endpoint_id: EndpointIdInt, + pub(crate) branch_id: BranchIdInt, } -pub trait MetricCounterRecorder { +pub(crate) trait MetricCounterRecorder { /// Record that some bytes were sent from the proxy to the client fn record_egress(&self, bytes: u64); /// Record that some connections were opened @@ -92,7 +92,7 @@ impl MetricCounterReporter for MetricBackupCounter { } #[derive(Debug)] -pub struct MetricCounter { +pub(crate) struct MetricCounter { transmitted: AtomicU64, opened_connections: AtomicUsize, backup: Arc, @@ -173,14 +173,14 @@ impl Clearable for C { type FastHasher = std::hash::BuildHasherDefault; #[derive(Default)] -pub struct Metrics { +pub(crate) struct Metrics { endpoints: DashMap, FastHasher>, backup_endpoints: DashMap, FastHasher>, } impl Metrics { /// Register a new byte metrics counter for this endpoint - pub fn register(&self, ids: Ids) -> Arc { + pub(crate) fn register(&self, ids: Ids) -> Arc { let backup = if let Some(entry) = self.backup_endpoints.get(&ids) { entry.clone() } else { @@ -215,7 +215,7 @@ impl Metrics { } } -pub static USAGE_METRICS: Lazy = Lazy::new(Metrics::default); +pub(crate) static USAGE_METRICS: Lazy = Lazy::new(Metrics::default); pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result { info!("metrics collector config: {config:?}"); @@ -450,12 +450,9 @@ async fn upload_events_chunk( remote_path: &RemotePath, cancel: &CancellationToken, ) -> anyhow::Result<()> { - let storage = match storage { - Some(storage) => storage, - None => { - error!("no remote storage configured"); - return Ok(()); - } + let Some(storage) = storage else { + error!("no remote storage configured"); + return Ok(()); }; let data = serde_json::to_vec(&chunk).context("serialize metrics")?; let mut encoder = GzipEncoder::new(Vec::new()); diff --git a/proxy/src/waiters.rs b/proxy/src/waiters.rs index 888ad38048..86d0f9e8b2 100644 --- a/proxy/src/waiters.rs +++ b/proxy/src/waiters.rs @@ -7,13 +7,13 @@ use thiserror::Error; use tokio::sync::oneshot; #[derive(Debug, Error)] -pub enum RegisterError { +pub(crate) enum RegisterError { #[error("Waiter `{0}` already registered")] Occupied(String), } #[derive(Debug, Error)] -pub enum NotifyError { +pub(crate) enum NotifyError { #[error("Notify failed: waiter `{0}` not registered")] NotFound(String), @@ -22,21 +22,21 @@ pub enum NotifyError { } #[derive(Debug, Error)] -pub enum WaitError { +pub(crate) enum WaitError { #[error("Wait failed: channel hangup")] Hangup, } -pub struct Waiters(pub(self) Mutex>>); +pub(crate) struct Waiters(pub(self) Mutex>>); impl Default for Waiters { fn default() -> Self { - Waiters(Default::default()) + Waiters(Mutex::default()) } } impl Waiters { - pub fn register(&self, key: String) -> Result, RegisterError> { + pub(crate) fn register(&self, key: String) -> Result, RegisterError> { let (tx, rx) = oneshot::channel(); self.0 @@ -53,7 +53,7 @@ impl Waiters { }) } - pub fn notify(&self, key: &str, value: T) -> Result<(), NotifyError> + pub(crate) fn notify(&self, key: &str, value: T) -> Result<(), NotifyError> where T: Send + Sync, { @@ -79,7 +79,7 @@ impl<'a, T> Drop for DropKey<'a, T> { } pin_project! 
{ - pub struct Waiter<'a, T> { + pub(crate) struct Waiter<'a, T> { #[pin] receiver: oneshot::Receiver, guard: DropKey<'a, T>, diff --git a/safekeeper/src/http/openapi_spec.yaml b/safekeeper/src/http/openapi_spec.yaml index a617e0310c..70999853c2 100644 --- a/safekeeper/src/http/openapi_spec.yaml +++ b/safekeeper/src/http/openapi_spec.yaml @@ -86,42 +86,6 @@ paths: default: $ref: "#/components/responses/GenericError" - /v1/tenant/{tenant_id}/timeline/{source_timeline_id}/copy: - parameters: - - name: tenant_id - in: path - required: true - schema: - type: string - format: hex - - name: source_timeline_id - in: path - required: true - schema: - type: string - format: hex - - post: - tags: - - "Timeline" - summary: Register new timeline as copy of existing timeline - description: "" - operationId: v1CopyTenantTimeline - requestBody: - content: - application/json: - schema: - $ref: "#/components/schemas/TimelineCopyRequest" - responses: - "201": - description: Timeline created - # TODO: return timeline info? - "403": - $ref: "#/components/responses/ForbiddenError" - default: - $ref: "#/components/responses/GenericError" - - /v1/tenant/{tenant_id}/timeline/{timeline_id}: parameters: - name: tenant_id @@ -179,6 +143,40 @@ paths: default: $ref: "#/components/responses/GenericError" + /v1/tenant/{tenant_id}/timeline/{source_timeline_id}/copy: + parameters: + - name: tenant_id + in: path + required: true + schema: + type: string + format: hex + - name: source_timeline_id + in: path + required: true + schema: + type: string + format: hex + + post: + tags: + - "Timeline" + summary: Register new timeline as copy of existing timeline + description: "" + operationId: v1CopyTenantTimeline + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/TimelineCopyRequest" + responses: + "201": + description: Timeline created + # TODO: return timeline info? + "403": + $ref: "#/components/responses/ForbiddenError" + default: + $ref: "#/components/responses/GenericError" /v1/record_safekeeper_info/{tenant_id}/{timeline_id}: parameters: diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index c9defb0bcf..91ffa95c21 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -114,6 +114,64 @@ fn check_permission(request: &Request, tenant_id: Option) -> Res }) } +/// Deactivates all timelines for the tenant and removes its data directory. +/// See `timeline_delete_handler`. 
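As a side note on the `waiters.rs` changes above: the module is a small key-to-oneshot registry, where one task registers under a key and awaits, and another task notifies by key. A stripped-down model of that pattern (the real implementation additionally rejects duplicate keys and removes the entry when the waiter is dropped):

```rust
use std::collections::HashMap;
use std::sync::Mutex;
use tokio::sync::oneshot;

struct Waiters<T>(Mutex<HashMap<String, oneshot::Sender<T>>>);

impl<T> Waiters<T> {
    // Register a waiter under `key` and hand back the receiving half.
    fn register(&self, key: String) -> oneshot::Receiver<T> {
        let (tx, rx) = oneshot::channel();
        self.0.lock().unwrap().insert(key, tx);
        rx
    }

    // Deliver `value` to the waiter registered under `key`, if any.
    fn notify(&self, key: &str, value: T) -> Result<(), &'static str> {
        let tx = self
            .0
            .lock()
            .unwrap()
            .remove(key)
            .ok_or("waiter not registered")?;
        tx.send(value).map_err(|_| "waiter dropped")
    }
}
```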
+async fn tenant_delete_handler(mut request: Request) -> Result, ApiError> { + let tenant_id = parse_request_param(&request, "tenant_id")?; + let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false); + check_permission(&request, Some(tenant_id))?; + ensure_no_body(&mut request).await?; + // FIXME: `delete_force_all_for_tenant` can return an error for multiple different reasons; + // Using an `InternalServerError` should be fixed when the types support it + let delete_info = GlobalTimelines::delete_force_all_for_tenant(&tenant_id, only_local) + .await + .map_err(ApiError::InternalServerError)?; + json_response( + StatusCode::OK, + delete_info + .iter() + .map(|(ttid, resp)| (format!("{}", ttid.timeline_id), *resp)) + .collect::>(), + ) +} + +async fn timeline_create_handler(mut request: Request) -> Result, ApiError> { + let request_data: TimelineCreateRequest = json_request(&mut request).await?; + + let ttid = TenantTimelineId { + tenant_id: request_data.tenant_id, + timeline_id: request_data.timeline_id, + }; + check_permission(&request, Some(ttid.tenant_id))?; + + let server_info = ServerInfo { + pg_version: request_data.pg_version, + system_id: request_data.system_id.unwrap_or(0), + wal_seg_size: request_data.wal_seg_size.unwrap_or(WAL_SEGMENT_SIZE as u32), + }; + let local_start_lsn = request_data.local_start_lsn.unwrap_or_else(|| { + request_data + .commit_lsn + .segment_lsn(server_info.wal_seg_size as usize) + }); + GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn) + .await + .map_err(ApiError::InternalServerError)?; + + json_response(StatusCode::OK, ()) +} + +/// List all (not deleted) timelines. +/// Note: it is possible to do the same with debug_dump. +async fn timeline_list_handler(request: Request) -> Result, ApiError> { + check_permission(&request, None)?; + let res: Vec = GlobalTimelines::get_all() + .iter() + .map(|tli| tli.ttid) + .collect(); + json_response(StatusCode::OK, res) +} + /// Report info about timeline. async fn timeline_status_handler(request: Request) -> Result, ApiError> { let ttid = TenantTimelineId::new( @@ -164,30 +222,21 @@ async fn timeline_status_handler(request: Request) -> Result) -> Result, ApiError> { - let request_data: TimelineCreateRequest = json_request(&mut request).await?; - - let ttid = TenantTimelineId { - tenant_id: request_data.tenant_id, - timeline_id: request_data.timeline_id, - }; +/// Deactivates the timeline and removes its data directory. +async fn timeline_delete_handler(mut request: Request) -> Result, ApiError> { + let ttid = TenantTimelineId::new( + parse_request_param(&request, "tenant_id")?, + parse_request_param(&request, "timeline_id")?, + ); + let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false); check_permission(&request, Some(ttid.tenant_id))?; - - let server_info = ServerInfo { - pg_version: request_data.pg_version, - system_id: request_data.system_id.unwrap_or(0), - wal_seg_size: request_data.wal_seg_size.unwrap_or(WAL_SEGMENT_SIZE as u32), - }; - let local_start_lsn = request_data.local_start_lsn.unwrap_or_else(|| { - request_data - .commit_lsn - .segment_lsn(server_info.wal_seg_size as usize) - }); - GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn) + ensure_no_body(&mut request).await?; + // FIXME: `delete_force` can fail from both internal errors and bad requests. Add better + // error handling here when we're able to. 
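In the relocated `timeline_create_handler` above, a missing `local_start_lsn` defaults to the start of the WAL segment containing `commit_lsn`. A small worked example of that arithmetic, assuming `Lsn::segment_lsn` rounds down to the containing segment boundary and using plain integers instead of the `Lsn` type:

```rust
// Round an LSN down to the start of its WAL segment.
fn segment_start(lsn: u64, wal_seg_size: u64) -> u64 {
    lsn - (lsn % wal_seg_size)
}

fn main() {
    let wal_seg_size: u64 = 16 * 1024 * 1024; // default 16 MiB WAL segments
    let commit_lsn = 3 * wal_seg_size + 1234; // somewhere inside the 4th segment
    assert_eq!(segment_start(commit_lsn, wal_seg_size), 3 * wal_seg_size);
}
```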
+ let resp = GlobalTimelines::delete(&ttid, only_local) .await .map_err(ApiError::InternalServerError)?; - - json_response(StatusCode::OK, ()) + json_response(StatusCode::OK, resp) } /// Pull timeline from peer safekeeper instances. @@ -269,6 +318,46 @@ async fn timeline_copy_handler(mut request: Request) -> Result, +) -> Result, ApiError> { + check_permission(&request, None)?; + + let ttid = TenantTimelineId::new( + parse_request_param(&request, "tenant_id")?, + parse_request_param(&request, "timeline_id")?, + ); + + let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?; + + let patch_request: patch_control_file::Request = json_request(&mut request).await?; + let response = patch_control_file::handle_request(tli, patch_request) + .await + .map_err(ApiError::InternalServerError)?; + + json_response(StatusCode::OK, response) +} + +/// Force persist control file. +async fn timeline_checkpoint_handler(request: Request) -> Result, ApiError> { + check_permission(&request, None)?; + + let ttid = TenantTimelineId::new( + parse_request_param(&request, "tenant_id")?, + parse_request_param(&request, "timeline_id")?, + ); + + let tli = GlobalTimelines::get(ttid)?; + tli.write_shared_state() + .await + .sk + .state_mut() + .flush() + .await + .map_err(ApiError::InternalServerError)?; + json_response(StatusCode::OK, ()) +} + async fn timeline_digest_handler(request: Request) -> Result, ApiError> { let ttid = TenantTimelineId::new( parse_request_param(&request, "tenant_id")?, @@ -300,64 +389,6 @@ async fn timeline_digest_handler(request: Request) -> Result) -> Result, ApiError> { - check_permission(&request, None)?; - - let ttid = TenantTimelineId::new( - parse_request_param(&request, "tenant_id")?, - parse_request_param(&request, "timeline_id")?, - ); - - let tli = GlobalTimelines::get(ttid)?; - tli.write_shared_state() - .await - .sk - .state_mut() - .flush() - .await - .map_err(ApiError::InternalServerError)?; - json_response(StatusCode::OK, ()) -} - -/// Deactivates the timeline and removes its data directory. -async fn timeline_delete_handler(mut request: Request) -> Result, ApiError> { - let ttid = TenantTimelineId::new( - parse_request_param(&request, "tenant_id")?, - parse_request_param(&request, "timeline_id")?, - ); - let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false); - check_permission(&request, Some(ttid.tenant_id))?; - ensure_no_body(&mut request).await?; - // FIXME: `delete_force` can fail from both internal errors and bad requests. Add better - // error handling here when we're able to. - let resp = GlobalTimelines::delete(&ttid, only_local) - .await - .map_err(ApiError::InternalServerError)?; - json_response(StatusCode::OK, resp) -} - -/// Deactivates all timelines for the tenant and removes its data directory. -/// See `timeline_delete_handler`. 
-async fn tenant_delete_handler(mut request: Request) -> Result, ApiError> { - let tenant_id = parse_request_param(&request, "tenant_id")?; - let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false); - check_permission(&request, Some(tenant_id))?; - ensure_no_body(&mut request).await?; - // FIXME: `delete_force_all_for_tenant` can return an error for multiple different reasons; - // Using an `InternalServerError` should be fixed when the types support it - let delete_info = GlobalTimelines::delete_force_all_for_tenant(&tenant_id, only_local) - .await - .map_err(ApiError::InternalServerError)?; - json_response( - StatusCode::OK, - delete_info - .iter() - .map(|(ttid, resp)| (format!("{}", ttid.timeline_id), *resp)) - .collect::>(), - ) -} - /// Used only in tests to hand craft required data. async fn record_safekeeper_info(mut request: Request) -> Result, ApiError> { let ttid = TenantTimelineId::new( @@ -499,26 +530,6 @@ async fn dump_debug_handler(mut request: Request) -> Result Ok(response) } -async fn patch_control_file_handler( - mut request: Request, -) -> Result, ApiError> { - check_permission(&request, None)?; - - let ttid = TenantTimelineId::new( - parse_request_param(&request, "tenant_id")?, - parse_request_param(&request, "timeline_id")?, - ); - - let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?; - - let patch_request: patch_control_file::Request = json_request(&mut request).await?; - let response = patch_control_file::handle_request(tli, patch_request) - .await - .map_err(ApiError::InternalServerError)?; - - json_response(StatusCode::OK, response) -} - /// Safekeeper http router. pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder { let mut router = endpoint::make_router(); @@ -558,26 +569,29 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder failpoints_handler(r, cancel).await }) }) + .delete("/v1/tenant/:tenant_id", |r| { + request_span(r, tenant_delete_handler) + }) // Will be used in the future instead of implicit timeline creation .post("/v1/tenant/timeline", |r| { request_span(r, timeline_create_handler) }) + .get("/v1/tenant/timeline", |r| { + request_span(r, timeline_list_handler) + }) .get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| { request_span(r, timeline_status_handler) }) .delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| { request_span(r, timeline_delete_handler) }) - .delete("/v1/tenant/:tenant_id", |r| { - request_span(r, tenant_delete_handler) + .post("/v1/pull_timeline", |r| { + request_span(r, timeline_pull_handler) }) .get( "/v1/tenant/:tenant_id/timeline/:timeline_id/snapshot/:destination_id", |r| request_span(r, timeline_snapshot_handler), ) - .post("/v1/pull_timeline", |r| { - request_span(r, timeline_pull_handler) - }) .post( "/v1/tenant/:tenant_id/timeline/:source_timeline_id/copy", |r| request_span(r, timeline_copy_handler), @@ -590,14 +604,13 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder "/v1/tenant/:tenant_id/timeline/:timeline_id/checkpoint", |r| request_span(r, timeline_checkpoint_handler), ) - // for tests + .get("/v1/tenant/:tenant_id/timeline/:timeline_id/digest", |r| { + request_span(r, timeline_digest_handler) + }) .post("/v1/record_safekeeper_info/:tenant_id/:timeline_id", |r| { request_span(r, record_safekeeper_info) }) .get("/v1/debug_dump", |r| request_span(r, dump_debug_handler)) - .get("/v1/tenant/:tenant_id/timeline/:timeline_id/digest", |r| { - request_span(r, timeline_digest_handler) - }) } #[cfg(test)] diff --git a/scripts/comment-test-report.js 
b/scripts/comment-test-report.js index f42262cf48..e8e0b3c23a 100755 --- a/scripts/comment-test-report.js +++ b/scripts/comment-test-report.js @@ -68,16 +68,29 @@ const parseReportJson = async ({ reportJsonUrl, fetch }) => { console.info(`Cannot get BUILD_TYPE and Postgres Version from test name: "${test.name}", defaulting to "release" and "14"`) buildType = "release" - pgVersion = "14" + pgVersion = "16" } pgVersions.add(pgVersion) + // We use `arch` as it is returned by GitHub Actions + // (RUNNER_ARCH env var): X86, X64, ARM, or ARM64 + // Ref https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables + let arch = "" + if (test.parameters.includes("'X64'")) { + arch = "x86-64" + } else if (test.parameters.includes("'ARM64'")) { + arch = "arm64" + } else { + arch = "unknown" + } + // Removing build type and PostgreSQL version from the test name to make it shorter const testName = test.name.replace(new RegExp(`${buildType}-pg${pgVersion}-?`), "").replace("[]", "") test.pytestName = `${parentSuite.name.replace(".", "/")}/${suite.name}.py::${testName}` test.pgVersion = pgVersion test.buildType = buildType + test.arch = arch if (test.status === "passed") { passedTests[pgVersion][testName].push(test) @@ -144,7 +157,7 @@ const reportSummary = async (params) => { const links = [] for (const test of tests) { const allureLink = `${reportUrl}#suites/${test.parentUid}/${test.uid}` - links.push(`[${test.buildType}](${allureLink})`) + links.push(`[${test.buildType}-${test.arch}](${allureLink})`) } summary += `- \`${testName}\`: ${links.join(", ")}\n` } @@ -175,7 +188,7 @@ const reportSummary = async (params) => { const links = [] for (const test of tests) { const allureLink = `${reportUrl}#suites/${test.parentUid}/${test.uid}/retries` - links.push(`[${test.buildType}](${allureLink})`) + links.push(`[${test.buildType}-${test.arch}](${allureLink})`) } summary += `- \`${testName}\`: ${links.join(", ")}\n` } diff --git a/scripts/ingest_regress_test_result-new-format.py b/scripts/ingest_regress_test_result-new-format.py index cff1d9875f..40d7254e00 100644 --- a/scripts/ingest_regress_test_result-new-format.py +++ b/scripts/ingest_regress_test_result-new-format.py @@ -18,6 +18,7 @@ import psycopg2 from psycopg2.extras import execute_values CREATE_TABLE = """ +CREATE TYPE arch AS ENUM ('ARM64', 'X64', 'UNKNOWN'); CREATE TABLE IF NOT EXISTS results ( id BIGSERIAL PRIMARY KEY, parent_suite TEXT NOT NULL, @@ -28,6 +29,7 @@ CREATE TABLE IF NOT EXISTS results ( stopped_at TIMESTAMPTZ NOT NULL, duration INT NOT NULL, flaky BOOLEAN NOT NULL, + arch arch DEFAULT 'X64', build_type TEXT NOT NULL, pg_version INT NOT NULL, run_id BIGINT NOT NULL, @@ -35,7 +37,7 @@ CREATE TABLE IF NOT EXISTS results ( reference TEXT NOT NULL, revision CHAR(40) NOT NULL, raw JSONB COMPRESSION lz4 NOT NULL, - UNIQUE (parent_suite, suite, name, build_type, pg_version, started_at, stopped_at, run_id) + UNIQUE (parent_suite, suite, name, arch, build_type, pg_version, started_at, stopped_at, run_id) ); """ @@ -50,6 +52,7 @@ class Row: stopped_at: datetime duration: int flaky: bool + arch: str build_type: str pg_version: int run_id: int @@ -121,6 +124,14 @@ def ingest_test_result( raw.pop("labels") raw.pop("extra") + # All allure parameters are prefixed with "__", see test_runner/fixtures/parametrize.py + parameters = { + p["name"].removeprefix("__"): p["value"] + for p in test["parameters"] + if p["name"].startswith("__") + } + arch = 
parameters.get("arch", "UNKNOWN").strip("'") + build_type, pg_version, unparametrized_name = parse_test_name(test["name"]) labels = {label["name"]: label["value"] for label in test["labels"]} row = Row( @@ -132,6 +143,7 @@ def ingest_test_result( stopped_at=datetime.fromtimestamp(test["time"]["stop"] / 1000, tz=timezone.utc), duration=test["time"]["duration"], flaky=test["flaky"] or test["retriesStatusChange"], + arch=arch, build_type=build_type, pg_version=pg_version, run_id=run_id, diff --git a/scripts/ps_ec2_setup_instance_store b/scripts/ps_ec2_setup_instance_store index 1f88f252eb..7c383e322f 100755 --- a/scripts/ps_ec2_setup_instance_store +++ b/scripts/ps_ec2_setup_instance_store @@ -44,7 +44,7 @@ run the following commands from the top of the neon.git checkout # test suite run export TEST_OUTPUT="$TEST_OUTPUT" - DEFAULT_PG_VERSION=15 BUILD_TYPE=release ./scripts/pytest test_runner/performance/test_latency.py + DEFAULT_PG_VERSION=16 BUILD_TYPE=release ./scripts/pytest test_runner/performance/test_latency.py # for interactive use export NEON_REPO_DIR="$NEON_REPO_DIR" diff --git a/storage_controller/migrations/2024-08-23-170149_tenant_id_index/down.sql b/storage_controller/migrations/2024-08-23-170149_tenant_id_index/down.sql new file mode 100644 index 0000000000..518c747100 --- /dev/null +++ b/storage_controller/migrations/2024-08-23-170149_tenant_id_index/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +DROP INDEX tenant_shards_tenant_id; \ No newline at end of file diff --git a/storage_controller/migrations/2024-08-23-170149_tenant_id_index/up.sql b/storage_controller/migrations/2024-08-23-170149_tenant_id_index/up.sql new file mode 100644 index 0000000000..dd6b37781a --- /dev/null +++ b/storage_controller/migrations/2024-08-23-170149_tenant_id_index/up.sql @@ -0,0 +1,2 @@ +-- Your SQL goes here +CREATE INDEX tenant_shards_tenant_id ON tenant_shards (tenant_id); \ No newline at end of file diff --git a/storage_controller/src/heartbeater.rs b/storage_controller/src/heartbeater.rs index 1bb9c17f30..b7e66d33eb 100644 --- a/storage_controller/src/heartbeater.rs +++ b/storage_controller/src/heartbeater.rs @@ -6,10 +6,7 @@ use std::{ }; use tokio_util::sync::CancellationToken; -use pageserver_api::{ - controller_api::{NodeAvailability, UtilizationScore}, - models::PageserverUtilization, -}; +use pageserver_api::{controller_api::NodeAvailability, models::PageserverUtilization}; use thiserror::Error; use utils::id::NodeId; @@ -87,9 +84,12 @@ impl Heartbeater { pageservers, reply: sender, }) - .unwrap(); + .map_err(|_| HeartbeaterError::Cancel)?; - receiver.await.unwrap() + receiver + .await + .map_err(|_| HeartbeaterError::Cancel) + .and_then(|x| x) } } @@ -144,7 +144,8 @@ impl HeartbeaterTask { // goes through to the pageserver even when the node is marked offline. // This doesn't impact the availability observed by [`crate::service::Service`]. 
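The heartbeater hunk above replaces `unwrap()` on the request channel with error mapping, so a shut-down heartbeater surfaces as `HeartbeaterError::Cancel` instead of a panic. A self-contained sketch of that request/response pattern with illustrative stand-in types:

```rust
use tokio::sync::{mpsc, oneshot};

#[derive(Debug)]
enum HeartbeaterError {
    Cancel,
}

struct Request {
    payload: String,
    reply: oneshot::Sender<Result<String, HeartbeaterError>>,
}

async fn heartbeat(
    tx: &mpsc::UnboundedSender<Request>,
    payload: String,
) -> Result<String, HeartbeaterError> {
    let (reply, rx) = oneshot::channel();
    // If the worker task is gone, treat it as cancellation rather than panicking.
    tx.send(Request { payload, reply })
        .map_err(|_| HeartbeaterError::Cancel)?;
    // The worker may itself return an error, hence the nested Result and `and_then`.
    rx.await.map_err(|_| HeartbeaterError::Cancel).and_then(|x| x)
}
```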
let mut node_clone = node.clone(); - node_clone.set_availability(NodeAvailability::Active(UtilizationScore::worst())); + node_clone + .set_availability(NodeAvailability::Active(PageserverUtilization::full())); async move { let response = node_clone @@ -176,7 +177,7 @@ impl HeartbeaterTask { node.get_availability() { PageserverState::WarmingUp { - started_at: last_seen_at, + started_at: *last_seen_at, } } else { PageserverState::Offline diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 7bbd1541cf..207bd5a1e6 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -1074,7 +1074,6 @@ pub fn make_router( RequestName("control_v1_metadata_health_list_outdated"), ) }) - // TODO(vlad): endpoint for cancelling drain and fill // Tenant Shard operations .put("/control/v1/tenant/:tenant_shard_id/migrate", |r| { tenant_service_handler( diff --git a/storage_controller/src/leadership.rs b/storage_controller/src/leadership.rs new file mode 100644 index 0000000000..5fae8991ec --- /dev/null +++ b/storage_controller/src/leadership.rs @@ -0,0 +1,135 @@ +use std::sync::Arc; + +use hyper::Uri; +use tokio_util::sync::CancellationToken; + +use crate::{ + peer_client::{GlobalObservedState, PeerClient}, + persistence::{ControllerPersistence, DatabaseError, DatabaseResult, Persistence}, + service::Config, +}; + +/// Helper for storage controller leadership acquisition +pub(crate) struct Leadership { + persistence: Arc, + config: Config, + cancel: CancellationToken, +} + +#[derive(thiserror::Error, Debug)] +pub(crate) enum Error { + #[error(transparent)] + Database(#[from] DatabaseError), +} + +pub(crate) type Result = std::result::Result; + +impl Leadership { + pub(crate) fn new( + persistence: Arc, + config: Config, + cancel: CancellationToken, + ) -> Self { + Self { + persistence, + config, + cancel, + } + } + + /// Find the current leader in the database and request it to step down if required. + /// Should be called early on in within the start-up sequence. + /// + /// Returns a tuple of two optionals: the current leader and its observed state + pub(crate) async fn step_down_current_leader( + &self, + ) -> Result<(Option, Option)> { + let leader = self.current_leader().await?; + let leader_step_down_state = if let Some(ref leader) = leader { + if self.config.start_as_candidate { + self.request_step_down(leader).await + } else { + None + } + } else { + tracing::info!("No leader found to request step down from. Will build observed state."); + None + }; + + Ok((leader, leader_step_down_state)) + } + + /// Mark the current storage controller instance as the leader in the database + pub(crate) async fn become_leader( + &self, + current_leader: Option, + ) -> Result<()> { + if let Some(address_for_peers) = &self.config.address_for_peers { + // TODO: `address-for-peers` can become a mandatory cli arg + // after we update the k8s setup + let proposed_leader = ControllerPersistence { + address: address_for_peers.to_string(), + started_at: chrono::Utc::now(), + }; + + self.persistence + .update_leader(current_leader, proposed_leader) + .await + .map_err(Error::Database) + } else { + tracing::info!("No address-for-peers provided. 
Skipping leader persistence."); + Ok(()) + } + } + + async fn current_leader(&self) -> DatabaseResult> { + let res = self.persistence.get_leader().await; + if let Err(DatabaseError::Query(diesel::result::Error::DatabaseError(_kind, ref err))) = res + { + const REL_NOT_FOUND_MSG: &str = "relation \"controllers\" does not exist"; + if err.message().trim() == REL_NOT_FOUND_MSG { + // Special case: if this is a brand new storage controller, migrations will not + // have run at this point yet, and, hence, the controllers table does not exist. + // Detect this case via the error string (diesel doesn't type it) and allow it. + tracing::info!("Detected first storage controller start-up. Allowing missing controllers table ..."); + return Ok(None); + } + } + + res + } + + /// Request step down from the currently registered leader in the database + /// + /// If such an entry is persisted, the success path returns the observed + /// state and details of the leader. Otherwise, None is returned indicating + /// there is no leader currently. + async fn request_step_down( + &self, + leader: &ControllerPersistence, + ) -> Option { + tracing::info!("Sending step down request to {leader:?}"); + + let client = PeerClient::new( + Uri::try_from(leader.address.as_str()).expect("Failed to build leader URI"), + self.config.peer_jwt_token.clone(), + ); + let state = client.step_down(&self.cancel).await; + match state { + Ok(state) => Some(state), + Err(err) => { + // TODO: Make leaders periodically update a timestamp field in the + // database and, if the leader is not reachable from the current instance, + // but inferred as alive from the timestamp, abort start-up. This avoids + // a potential scenario in which we have two controllers acting as leaders. + tracing::error!( + "Leader ({}) did not respond to step-down request: {}", + leader.address, + err + ); + + None + } + } + } +} diff --git a/storage_controller/src/lib.rs b/storage_controller/src/lib.rs index 2034addbe1..60e613bb5c 100644 --- a/storage_controller/src/lib.rs +++ b/storage_controller/src/lib.rs @@ -8,6 +8,7 @@ mod drain_utils; mod heartbeater; pub mod http; mod id_lock_map; +mod leadership; pub mod metrics; mod node; mod pageserver_client; diff --git a/storage_controller/src/main.rs b/storage_controller/src/main.rs index 7387d36690..e3f29b84e7 100644 --- a/storage_controller/src/main.rs +++ b/storage_controller/src/main.rs @@ -1,6 +1,5 @@ use anyhow::{anyhow, Context}; use clap::Parser; -use diesel::Connection; use hyper::Uri; use metrics::launch_timestamp::LaunchTimestamp; use metrics::BuildInfo; @@ -27,9 +26,6 @@ use utils::{project_build_tag, project_git_version, tcp_listener}; project_git_version!(GIT_VERSION); project_build_tag!(BUILD_TAG); -use diesel_migrations::{embed_migrations, EmbeddedMigrations}; -pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations"); - #[derive(Parser)] #[command(author, version, about, long_about = None)] #[command(arg_required_else_help(true))] @@ -51,6 +47,9 @@ struct Cli { #[arg(long)] control_plane_jwt_token: Option, + #[arg(long)] + peer_jwt_token: Option, + /// URL to control plane compute notification endpoint #[arg(long)] compute_hook_url: Option, @@ -130,28 +129,28 @@ struct Secrets { public_key: Option, jwt_token: Option, control_plane_jwt_token: Option, + peer_jwt_token: Option, } impl Secrets { const DATABASE_URL_ENV: &'static str = "DATABASE_URL"; const PAGESERVER_JWT_TOKEN_ENV: &'static str = "PAGESERVER_JWT_TOKEN"; const CONTROL_PLANE_JWT_TOKEN_ENV: &'static str = 
"CONTROL_PLANE_JWT_TOKEN"; + const PEER_JWT_TOKEN_ENV: &'static str = "PEER_JWT_TOKEN"; const PUBLIC_KEY_ENV: &'static str = "PUBLIC_KEY"; /// Load secrets from, in order of preference: /// - CLI args if database URL is provided on the CLI /// - Environment variables if DATABASE_URL is set. - /// - AWS Secrets Manager secrets async fn load(args: &Cli) -> anyhow::Result { - let Some(database_url) = - Self::load_secret(&args.database_url, Self::DATABASE_URL_ENV).await + let Some(database_url) = Self::load_secret(&args.database_url, Self::DATABASE_URL_ENV) else { anyhow::bail!( "Database URL is not set (set `--database-url`, or `DATABASE_URL` environment)" ) }; - let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV).await { + let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV) { Some(v) => Some(JwtAuth::from_key(v).context("Loading public key")?), None => None, }; @@ -159,18 +158,18 @@ impl Secrets { let this = Self { database_url, public_key, - jwt_token: Self::load_secret(&args.jwt_token, Self::PAGESERVER_JWT_TOKEN_ENV).await, + jwt_token: Self::load_secret(&args.jwt_token, Self::PAGESERVER_JWT_TOKEN_ENV), control_plane_jwt_token: Self::load_secret( &args.control_plane_jwt_token, Self::CONTROL_PLANE_JWT_TOKEN_ENV, - ) - .await, + ), + peer_jwt_token: Self::load_secret(&args.peer_jwt_token, Self::PEER_JWT_TOKEN_ENV), }; Ok(this) } - async fn load_secret(cli: &Option, env_name: &str) -> Option { + fn load_secret(cli: &Option, env_name: &str) -> Option { if let Some(v) = cli { Some(v.clone()) } else if let Ok(v) = std::env::var(env_name) { @@ -181,20 +180,6 @@ impl Secrets { } } -/// Execute the diesel migrations that are built into this binary -async fn migration_run(database_url: &str) -> anyhow::Result<()> { - use diesel::PgConnection; - use diesel_migrations::{HarnessWithOutput, MigrationHarness}; - let mut conn = PgConnection::establish(database_url)?; - - HarnessWithOutput::write_to_stdout(&mut conn) - .run_pending_migrations(MIGRATIONS) - .map(|_| ()) - .map_err(|e| anyhow::anyhow!(e))?; - - Ok(()) -} - fn main() -> anyhow::Result<()> { logging::init( LogFormat::Plain, @@ -284,6 +269,7 @@ async fn async_main() -> anyhow::Result<()> { let config = Config { jwt_token: secrets.jwt_token, control_plane_jwt_token: secrets.control_plane_jwt_token, + peer_jwt_token: secrets.peer_jwt_token, compute_hook_url: args.compute_hook_url, max_offline_interval: args .max_offline_interval @@ -304,13 +290,9 @@ async fn async_main() -> anyhow::Result<()> { http_service_port: args.listen.port() as i32, }; - // After loading secrets & config, but before starting anything else, apply database migrations + // Validate that we can connect to the database Persistence::await_connection(&secrets.database_url, args.db_connect_timeout.into()).await?; - migration_run(&secrets.database_url) - .await - .context("Running database migrations")?; - let persistence = Arc::new(Persistence::new(secrets.database_url)); let service = Service::spawn(config, persistence.clone()).await?; diff --git a/storage_controller/src/metrics.rs b/storage_controller/src/metrics.rs index c2303e7a7f..5cfcfb4b1f 100644 --- a/storage_controller/src/metrics.rs +++ b/storage_controller/src/metrics.rs @@ -230,6 +230,7 @@ pub(crate) enum DatabaseErrorLabel { Connection, ConnectionPool, Logical, + Migration, } impl DatabaseError { @@ -239,6 +240,7 @@ impl DatabaseError { Self::Connection(_) => DatabaseErrorLabel::Connection, Self::ConnectionPool(_) => DatabaseErrorLabel::ConnectionPool, 
Self::Logical(_) => DatabaseErrorLabel::Logical, + Self::Migration(_) => DatabaseErrorLabel::Migration, } } } diff --git a/storage_controller/src/node.rs b/storage_controller/src/node.rs index ea765ca123..61a44daca9 100644 --- a/storage_controller/src/node.rs +++ b/storage_controller/src/node.rs @@ -92,15 +92,15 @@ impl Node { } } - pub(crate) fn get_availability(&self) -> NodeAvailability { - self.availability + pub(crate) fn get_availability(&self) -> &NodeAvailability { + &self.availability } pub(crate) fn set_availability(&mut self, availability: NodeAvailability) { use AvailabilityTransition::*; use NodeAvailability::WarmingUp; - match self.get_availability_transition(availability) { + match self.get_availability_transition(&availability) { ToActive => { // Give the node a new cancellation token, effectively resetting it to un-cancelled. Any // users of previously-cloned copies of the node will still see the old cancellation @@ -115,8 +115,8 @@ impl Node { Unchanged | ToWarmingUpFromOffline => {} } - if let (WarmingUp(crnt), WarmingUp(proposed)) = (self.availability, availability) { - self.availability = WarmingUp(std::cmp::max(crnt, proposed)); + if let (WarmingUp(crnt), WarmingUp(proposed)) = (&self.availability, &availability) { + self.availability = WarmingUp(std::cmp::max(*crnt, *proposed)); } else { self.availability = availability; } @@ -126,12 +126,12 @@ impl Node { /// into a description of the transition. pub(crate) fn get_availability_transition( &self, - availability: NodeAvailability, + availability: &NodeAvailability, ) -> AvailabilityTransition { use AvailabilityTransition::*; use NodeAvailability::*; - match (self.availability, availability) { + match (&self.availability, availability) { (Offline, Active(_)) => ToActive, (Active(_), Offline) => ToOffline, (Active(_), WarmingUp(_)) => ToWarmingUpFromActive, @@ -153,15 +153,15 @@ impl Node { /// Is this node elegible to have work scheduled onto it? pub(crate) fn may_schedule(&self) -> MaySchedule { - let score = match self.availability { - NodeAvailability::Active(score) => score, + let utilization = match &self.availability { + NodeAvailability::Active(u) => u.clone(), NodeAvailability::Offline | NodeAvailability::WarmingUp(_) => return MaySchedule::No, }; match self.scheduling { - NodeSchedulingPolicy::Active => MaySchedule::Yes(score), + NodeSchedulingPolicy::Active => MaySchedule::Yes(utilization), NodeSchedulingPolicy::Draining => MaySchedule::No, - NodeSchedulingPolicy::Filling => MaySchedule::Yes(score), + NodeSchedulingPolicy::Filling => MaySchedule::Yes(utilization), NodeSchedulingPolicy::Pause => MaySchedule::No, NodeSchedulingPolicy::PauseForRestart => MaySchedule::No, } @@ -285,7 +285,7 @@ impl Node { pub(crate) fn describe(&self) -> NodeDescribeResponse { NodeDescribeResponse { id: self.id, - availability: self.availability.into(), + availability: self.availability.clone().into(), scheduling: self.scheduling, listen_http_addr: self.listen_http_addr.clone(), listen_http_port: self.listen_http_port, diff --git a/storage_controller/src/persistence.rs b/storage_controller/src/persistence.rs index aebbdec0d1..1a905753a1 100644 --- a/storage_controller/src/persistence.rs +++ b/storage_controller/src/persistence.rs @@ -25,6 +25,9 @@ use crate::metrics::{ }; use crate::node::Node; +use diesel_migrations::{embed_migrations, EmbeddedMigrations}; +const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations"); + /// ## What do we store? 
/// /// The storage controller service does not store most of its state durably. @@ -72,6 +75,8 @@ pub(crate) enum DatabaseError { ConnectionPool(#[from] r2d2::Error), #[error("Logical error: {0}")] Logical(String), + #[error("Migration error: {0}")] + Migration(String), } #[derive(measured::FixedCardinalityLabel, Copy, Clone)] @@ -86,6 +91,7 @@ pub(crate) enum DatabaseOperation { Detach, ReAttach, IncrementGeneration, + PeekGenerations, ListTenantShards, InsertTenantShards, UpdateTenantShard, @@ -167,6 +173,19 @@ impl Persistence { } } + /// Execute the diesel migrations that are built into this binary + pub(crate) async fn migration_run(&self) -> DatabaseResult<()> { + use diesel_migrations::{HarnessWithOutput, MigrationHarness}; + + self.with_conn(move |conn| -> DatabaseResult<()> { + HarnessWithOutput::write_to_stdout(conn) + .run_pending_migrations(MIGRATIONS) + .map(|_| ()) + .map_err(|e| DatabaseError::Migration(e.to_string())) + }) + .await + } + /// Wraps `with_conn` in order to collect latency and error metrics async fn with_measured_conn(&self, op: DatabaseOperation, func: F) -> DatabaseResult where @@ -484,6 +503,43 @@ impl Persistence { Ok(Generation::new(g as u32)) } + /// When we want to call out to the running shards for a tenant, e.g. during timeline CRUD operations, + /// we need to know where the shard is attached, _and_ the generation, so that we can re-check the generation + /// afterwards to confirm that our timeline CRUD operation is truly persistent (it must have happened in the + /// latest generation) + /// + /// If the tenant doesn't exist, an empty vector is returned. + /// + /// Output is sorted by shard number + pub(crate) async fn peek_generations( + &self, + filter_tenant_id: TenantId, + ) -> Result, Option)>, DatabaseError> { + use crate::schema::tenant_shards::dsl::*; + let rows = self + .with_measured_conn(DatabaseOperation::PeekGenerations, move |conn| { + let result = tenant_shards + .filter(tenant_id.eq(filter_tenant_id.to_string())) + .select(TenantShardPersistence::as_select()) + .order(shard_number) + .load(conn)?; + Ok(result) + }) + .await?; + + Ok(rows + .into_iter() + .map(|p| { + ( + p.get_tenant_shard_id() + .expect("Corrupt tenant shard id in database"), + p.generation.map(|g| Generation::new(g as u32)), + p.generation_pageserver.map(|n| NodeId(n as u64)), + ) + }) + .collect()) + } + #[allow(non_local_definitions)] /// For use when updating a persistent property of a tenant, such as its config or placement_policy. 
/// diff --git a/storage_controller/src/scheduler.rs b/storage_controller/src/scheduler.rs index 843159010d..060e3cc6ca 100644 --- a/storage_controller/src/scheduler.rs +++ b/storage_controller/src/scheduler.rs @@ -1,6 +1,6 @@ use crate::{node::Node, tenant_shard::TenantShard}; use itertools::Itertools; -use pageserver_api::controller_api::UtilizationScore; +use pageserver_api::models::PageserverUtilization; use serde::Serialize; use std::collections::HashMap; use utils::{http::error::ApiError, id::NodeId}; @@ -20,9 +20,9 @@ impl From for ApiError { } } -#[derive(Serialize, Eq, PartialEq)] +#[derive(Serialize)] pub enum MaySchedule { - Yes(UtilizationScore), + Yes(PageserverUtilization), No, } @@ -282,6 +282,28 @@ impl Scheduler { node.shard_count -= 1; } } + + // Maybe update PageserverUtilization + match update { + RefCountUpdate::AddSecondary | RefCountUpdate::Attach => { + // Referencing the node: if this takes our shard_count above the utilzation structure's + // shard count, then artifically bump it: this ensures that the scheduler immediately + // recognizes that this node has more work on it, without waiting for the next heartbeat + // to update the utilization. + if let MaySchedule::Yes(utilization) = &mut node.may_schedule { + utilization.adjust_shard_count_max(node.shard_count as u32); + } + } + RefCountUpdate::PromoteSecondary + | RefCountUpdate::Detach + | RefCountUpdate::RemoveSecondary + | RefCountUpdate::DemoteAttached => { + // De-referencing the node: leave the utilization's shard_count at a stale higher + // value until some future heartbeat after we have physically removed this shard + // from the node: this prevents the scheduler over-optimistically trying to schedule + // more work onto the node before earlier detaches are done. + } + } } // Check if the number of shards attached to a given node is lagging below @@ -326,7 +348,18 @@ impl Scheduler { use std::collections::hash_map::Entry::*; match self.nodes.entry(node.get_id()) { Occupied(mut entry) => { - entry.get_mut().may_schedule = node.may_schedule(); + // Updates to MaySchedule are how we receive updated PageserverUtilization: adjust these values + // to account for any shards scheduled on the controller but not yet visible to the pageserver. + let mut may_schedule = node.may_schedule(); + match &mut may_schedule { + MaySchedule::Yes(utilization) => { + utilization.adjust_shard_count_max(entry.get().shard_count as u32); + } + MaySchedule::No => { // Nothing to tweak + } + } + + entry.get_mut().may_schedule = may_schedule; } Vacant(entry) => { entry.insert(SchedulerNode { @@ -363,7 +396,7 @@ impl Scheduler { let may_schedule = self .nodes .get(node_id) - .map(|n| n.may_schedule != MaySchedule::No) + .map(|n| !matches!(n.may_schedule, MaySchedule::No)) .unwrap_or(false); (*node_id, may_schedule) }) @@ -383,7 +416,7 @@ impl Scheduler { /// the same tenant on the same node. This is a soft constraint: the context will never /// cause us to fail to schedule a shard. 
pub(crate) fn schedule_shard( - &self, + &mut self, hard_exclude: &[NodeId], context: &ScheduleContext, ) -> Result { @@ -391,31 +424,41 @@ impl Scheduler { return Err(ScheduleError::NoPageservers); } - let mut scores: Vec<(NodeId, AffinityScore, usize, usize)> = self + let mut scores: Vec<(NodeId, AffinityScore, u64, usize)> = self .nodes - .iter() - .filter_map(|(k, v)| { - if hard_exclude.contains(k) || v.may_schedule == MaySchedule::No { - None - } else { - Some(( - *k, - context.nodes.get(k).copied().unwrap_or(AffinityScore::FREE), - v.shard_count, - v.attached_shard_count, - )) - } + .iter_mut() + .filter_map(|(k, v)| match &mut v.may_schedule { + MaySchedule::No => None, + MaySchedule::Yes(_) if hard_exclude.contains(k) => None, + MaySchedule::Yes(utilization) => Some(( + *k, + context.nodes.get(k).copied().unwrap_or(AffinityScore::FREE), + utilization.cached_score(), + v.attached_shard_count, + )), }) .collect(); + // Exclude nodes whose utilization is critically high, if there are alternatives available. This will + // cause us to violate affinity rules if it is necessary to avoid critically overloading nodes: for example + // we may place shards in the same tenant together on the same pageserver if all other pageservers are + // overloaded. + let non_overloaded_scores = scores + .iter() + .filter(|i| !PageserverUtilization::is_overloaded(i.2)) + .copied() + .collect::>(); + if !non_overloaded_scores.is_empty() { + scores = non_overloaded_scores; + } + // Sort by, in order of precedence: // 1st: Affinity score. We should never pick a higher-score node if a lower-score node is available - // 2nd: Attached shard count. Within nodes with the same affinity, we always pick the node with - // the least number of attached shards. - // 3rd: Total shard count. Within nodes with the same affinity and attached shard count, use nodes - // with the lower total shard count. + // 2nd: Utilization score (this combines shard count and disk utilization) + // 3rd: Attached shard count. When nodes have identical utilization (e.g. when populating some + // empty nodes), this acts as an anti-affinity between attached shards. // 4th: Node ID. This is a convenience to make selection deterministic in tests and empty systems. - scores.sort_by_key(|i| (i.1, i.3, i.2, i.0)); + scores.sort_by_key(|i| (i.1, i.2, i.3, i.0)); if scores.is_empty() { // After applying constraints, no pageservers were left. @@ -429,7 +472,7 @@ impl Scheduler { for (node_id, node) in &self.nodes { tracing::info!( "Node {node_id}: may_schedule={} shards={}", - node.may_schedule != MaySchedule::No, + !matches!(node.may_schedule, MaySchedule::No), node.shard_count ); } @@ -469,7 +512,7 @@ impl Scheduler { pub(crate) mod test_utils { use crate::node::Node; - use pageserver_api::controller_api::{NodeAvailability, UtilizationScore}; + use pageserver_api::{controller_api::NodeAvailability, models::utilization::test_utilization}; use std::collections::HashMap; use utils::id::NodeId; /// Test helper: synthesize the requested number of nodes, all in active state. 
@@ -486,7 +529,7 @@ pub(crate) mod test_utils { format!("pghost-{i}"), 5432 + i as u16, ); - node.set_availability(NodeAvailability::Active(UtilizationScore::worst())); + node.set_availability(NodeAvailability::Active(test_utilization::simple(0, 0))); assert!(node.is_available()); node }) @@ -497,6 +540,8 @@ pub(crate) mod test_utils { #[cfg(test)] mod tests { + use pageserver_api::{controller_api::NodeAvailability, models::utilization::test_utilization}; + use super::*; use crate::tenant_shard::IntentState; @@ -557,4 +602,130 @@ mod tests { Ok(()) } + + #[test] + /// Test the PageserverUtilization's contribution to scheduling algorithm + fn scheduler_utilization() { + let mut nodes = test_utils::make_test_nodes(3); + let mut scheduler = Scheduler::new(nodes.values()); + + // Need to keep these alive because they contribute to shard counts via RAII + let mut scheduled_intents = Vec::new(); + + let empty_context = ScheduleContext::default(); + + fn assert_scheduler_chooses( + expect_node: NodeId, + scheduled_intents: &mut Vec, + scheduler: &mut Scheduler, + context: &ScheduleContext, + ) { + let scheduled = scheduler.schedule_shard(&[], context).unwrap(); + let mut intent = IntentState::new(); + intent.set_attached(scheduler, Some(scheduled)); + scheduled_intents.push(intent); + assert_eq!(scheduled, expect_node); + } + + // Independent schedule calls onto empty nodes should round-robin, because each node's + // utilization's shard count is updated inline. The order is determinsitic because when all other factors are + // equal, we order by node ID. + assert_scheduler_chooses( + NodeId(1), + &mut scheduled_intents, + &mut scheduler, + &empty_context, + ); + assert_scheduler_chooses( + NodeId(2), + &mut scheduled_intents, + &mut scheduler, + &empty_context, + ); + assert_scheduler_chooses( + NodeId(3), + &mut scheduled_intents, + &mut scheduler, + &empty_context, + ); + + // Manually setting utilization higher should cause schedule calls to round-robin the other nodes + // which have equal utilization. 
+ nodes + .get_mut(&NodeId(1)) + .unwrap() + .set_availability(NodeAvailability::Active(test_utilization::simple( + 10, + 1024 * 1024 * 1024, + ))); + scheduler.node_upsert(nodes.get(&NodeId(1)).unwrap()); + + assert_scheduler_chooses( + NodeId(2), + &mut scheduled_intents, + &mut scheduler, + &empty_context, + ); + assert_scheduler_chooses( + NodeId(3), + &mut scheduled_intents, + &mut scheduler, + &empty_context, + ); + assert_scheduler_chooses( + NodeId(2), + &mut scheduled_intents, + &mut scheduler, + &empty_context, + ); + assert_scheduler_chooses( + NodeId(3), + &mut scheduled_intents, + &mut scheduler, + &empty_context, + ); + + // The scheduler should prefer nodes with lower affinity score, + // even if they have higher utilization (as long as they aren't utilized at >100%) + let mut context_prefer_node1 = ScheduleContext::default(); + context_prefer_node1.avoid(&[NodeId(2), NodeId(3)]); + assert_scheduler_chooses( + NodeId(1), + &mut scheduled_intents, + &mut scheduler, + &context_prefer_node1, + ); + assert_scheduler_chooses( + NodeId(1), + &mut scheduled_intents, + &mut scheduler, + &context_prefer_node1, + ); + + // If a node is over-utilized, it will not be used even if affinity scores prefer it + nodes + .get_mut(&NodeId(1)) + .unwrap() + .set_availability(NodeAvailability::Active(test_utilization::simple( + 20000, + 1024 * 1024 * 1024, + ))); + scheduler.node_upsert(nodes.get(&NodeId(1)).unwrap()); + assert_scheduler_chooses( + NodeId(2), + &mut scheduled_intents, + &mut scheduler, + &context_prefer_node1, + ); + assert_scheduler_chooses( + NodeId(3), + &mut scheduled_intents, + &mut scheduler, + &context_prefer_node1, + ); + + for mut intent in scheduled_intents { + intent.clear(&mut scheduler); + } + } } diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 3459b44774..7daa1e4f5f 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -17,8 +17,9 @@ use crate::{ compute_hook::NotifyError, drain_utils::{self, TenantShardDrain, TenantShardIterator}, id_lock_map::{trace_exclusive_lock, trace_shared_lock, IdLockMap, TracingExclusiveGuard}, + leadership::Leadership, metrics, - peer_client::{GlobalObservedState, PeerClient}, + peer_client::GlobalObservedState, persistence::{ AbortShardSplitStatus, ControllerPersistence, DatabaseResult, MetadataHealthPersistence, TenantFilter, @@ -43,7 +44,7 @@ use pageserver_api::{ NodeSchedulingPolicy, PlacementPolicy, ShardSchedulingPolicy, TenantCreateRequest, TenantCreateResponse, TenantCreateResponseShard, TenantDescribeResponse, TenantDescribeResponseShard, TenantLocateResponse, TenantPolicyRequest, - TenantShardMigrateRequest, TenantShardMigrateResponse, UtilizationScore, + TenantShardMigrateRequest, TenantShardMigrateResponse, }, models::{SecondaryProgress, TenantConfigRequest, TopTenantShardsRequest}, }; @@ -287,6 +288,9 @@ pub struct Config { // This JWT token will be used to authenticate this service to the control plane. pub control_plane_jwt_token: Option, + // This JWT token will be used to authenticate with other storage controller instances + pub peer_jwt_token: Option, + /// Where the compute hook should send notifications of pageserver attachment locations /// (this URL points to the control plane in prod). If this is None, the compute hook will /// assume it is running in a test environment and try to update neon_local. 
@@ -333,7 +337,7 @@ impl From for ApiError { DatabaseError::Connection(_) | DatabaseError::ConnectionPool(_) => { ApiError::ShuttingDown } - DatabaseError::Logical(reason) => { + DatabaseError::Logical(reason) | DatabaseError::Migration(reason) => { ApiError::InternalServerError(anyhow::anyhow!(reason)) } } @@ -538,7 +542,7 @@ impl Service { let locked = self.inner.read().unwrap(); locked.nodes.clone() }; - let nodes_online = self.initial_heartbeat_round(all_nodes.keys()).await; + let mut nodes_online = self.initial_heartbeat_round(all_nodes.keys()).await; // List of tenants for which we will attempt to notify compute of their location at startup let mut compute_notifications = Vec::new(); @@ -552,10 +556,8 @@ impl Service { // Mark nodes online if they responded to us: nodes are offline by default after a restart. let mut new_nodes = (**nodes).clone(); for (node_id, node) in new_nodes.iter_mut() { - if let Some(utilization) = nodes_online.get(node_id) { - node.set_availability(NodeAvailability::Active(UtilizationScore( - utilization.utilization_score, - ))); + if let Some(utilization) = nodes_online.remove(node_id) { + node.set_availability(NodeAvailability::Active(utilization)); scheduler.node_upsert(node); } } @@ -606,22 +608,15 @@ impl Service { // Before making any obeservable changes to the cluster, persist self // as leader in database and memory. - if let Some(address_for_peers) = &self.config.address_for_peers { - // TODO: `address-for-peers` can become a mandatory cli arg - // after we update the k8s setup - let proposed_leader = ControllerPersistence { - address: address_for_peers.to_string(), - started_at: chrono::Utc::now(), - }; + let leadership = Leadership::new( + self.persistence.clone(), + self.config.clone(), + self.cancel.child_token(), + ); - if let Err(err) = self - .persistence - .update_leader(current_leader, proposed_leader) - .await - { - tracing::error!("Failed to persist self as leader: {err}. Aborting start-up ..."); - std::process::exit(1); - } + if let Err(e) = leadership.become_leader(current_leader).await { + tracing::error!("Failed to persist self as leader: {e}. Aborting start-up ..."); + std::process::exit(1); } self.inner.write().unwrap().become_leader(); @@ -928,9 +923,9 @@ impl Service { if let Ok(deltas) = res { for (node_id, state) in deltas.0 { let new_availability = match state { - PageserverState::Available { utilization, .. } => NodeAvailability::Active( - UtilizationScore(utilization.utilization_score), - ), + PageserverState::Available { utilization, .. } => { + NodeAvailability::Active(utilization) + } PageserverState::WarmingUp { started_at } => { NodeAvailability::WarmingUp(started_at) } @@ -939,14 +934,17 @@ impl Service { // while the heartbeat round was on-going. Hence, filter out // offline transitions for WarmingUp nodes that are still within // their grace period. 
- if let Ok(NodeAvailability::WarmingUp(started_at)) = - self.get_node(node_id).await.map(|n| n.get_availability()) + if let Ok(NodeAvailability::WarmingUp(started_at)) = self + .get_node(node_id) + .await + .as_ref() + .map(|n| n.get_availability()) { let now = Instant::now(); - if now - started_at >= self.config.max_warming_up_interval { + if now - *started_at >= self.config.max_warming_up_interval { NodeAvailability::Offline } else { - NodeAvailability::WarmingUp(started_at) + NodeAvailability::WarmingUp(*started_at) } } else { NodeAvailability::Offline @@ -1159,6 +1157,16 @@ impl Service { let (result_tx, result_rx) = tokio::sync::mpsc::unbounded_channel(); let (abort_tx, abort_rx) = tokio::sync::mpsc::unbounded_channel(); + let leadership_cancel = CancellationToken::new(); + let leadership = Leadership::new(persistence.clone(), config.clone(), leadership_cancel); + let (leader, leader_step_down_state) = leadership.step_down_current_leader().await?; + + // Apply the migrations **after** the current leader has stepped down + // (or we've given up waiting for it), but **before** reading from the + // database. The only exception is reading the current leader before + // migrating. + persistence.migration_run().await?; + tracing::info!("Loading nodes from database..."); let nodes = persistence .list_nodes() @@ -1376,32 +1384,6 @@ impl Service { return; }; - let leadership_status = this.inner.read().unwrap().get_leadership_status(); - let leader = match this.get_leader().await { - Ok(ok) => ok, - Err(err) => { - tracing::error!( - "Failed to query database for current leader: {err}. Aborting start-up ..." - ); - std::process::exit(1); - } - }; - - let leader_step_down_state = match leadership_status { - LeadershipStatus::Candidate => { - if let Some(ref leader) = leader { - this.request_step_down(leader).await - } else { - tracing::info!( - "No leader found to request step down from. Will build observed state." - ); - None - } - } - LeadershipStatus::Leader => None, - LeadershipStatus::SteppedDown => unreachable!(), - }; - this.startup_reconcile(leader, leader_step_down_state, bg_compute_notify_result_tx) .await; @@ -1644,7 +1626,7 @@ impl Service { // This Node is a mutable local copy: we will set it active so that we can use its // API client to reconcile with the node. The Node in [`Self::nodes`] will get updated // later. - node.set_availability(NodeAvailability::Active(UtilizationScore::worst())); + node.set_availability(NodeAvailability::Active(PageserverUtilization::full())); let configs = match node .with_client_retries( @@ -2492,7 +2474,7 @@ impl Service { .await; let node = { - let locked = self.inner.read().unwrap(); + let mut locked = self.inner.write().unwrap(); // Just a sanity check to prevent misuse: the API expects that the tenant is fully // detached everywhere, and nothing writes to S3 storage. 
Here, we verify that, // but only at the start of the process, so it's really just to prevent operator @@ -2519,7 +2501,7 @@ impl Service { return Err(ApiError::InternalServerError(anyhow::anyhow!("We observed attached={mode:?} tenant in node_id={node_id} shard with tenant_shard_id={shard_id}"))); } } - let scheduler = &locked.scheduler; + let scheduler = &mut locked.scheduler; // Right now we only perform the operation on a single node without parallelization // TODO fan out the operation to multiple nodes for better performance let node_id = scheduler.schedule_shard(&[], &ScheduleContext::default())?; @@ -2872,82 +2854,67 @@ impl Service { .await; failpoint_support::sleep_millis_async!("tenant-create-timeline-shared-lock"); - self.ensure_attached_wait(tenant_id).await?; + self.tenant_remote_mutation(tenant_id, move |mut targets| async move { + if targets.is_empty() { + return Err(ApiError::NotFound( + anyhow::anyhow!("Tenant not found").into(), + )); + }; + let shard_zero = targets.remove(0); - let mut targets = { - let locked = self.inner.read().unwrap(); - let mut targets = Vec::new(); + async fn create_one( + tenant_shard_id: TenantShardId, + node: Node, + jwt: Option, + create_req: TimelineCreateRequest, + ) -> Result { + tracing::info!( + "Creating timeline on shard {}/{}, attached to node {node}", + tenant_shard_id, + create_req.new_timeline_id, + ); + let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref()); - for (tenant_shard_id, shard) in - locked.tenants.range(TenantShardId::tenant_range(tenant_id)) - { - let node_id = shard.intent.get_attached().ok_or_else(|| { - ApiError::InternalServerError(anyhow::anyhow!("Shard not scheduled")) - })?; - let node = locked - .nodes - .get(&node_id) - .expect("Pageservers may not be deleted while referenced"); - - targets.push((*tenant_shard_id, node.clone())); + client + .timeline_create(tenant_shard_id, &create_req) + .await + .map_err(|e| passthrough_api_error(&node, e)) } - targets - }; - if targets.is_empty() { - return Err(ApiError::NotFound( - anyhow::anyhow!("Tenant not found").into(), - )); - }; - let shard_zero = targets.remove(0); - - async fn create_one( - tenant_shard_id: TenantShardId, - node: Node, - jwt: Option, - create_req: TimelineCreateRequest, - ) -> Result { - tracing::info!( - "Creating timeline on shard {}/{}, attached to node {node}", - tenant_shard_id, - create_req.new_timeline_id, - ); - let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref()); - - client - .timeline_create(tenant_shard_id, &create_req) - .await - .map_err(|e| passthrough_api_error(&node, e)) - } - - // Because the caller might not provide an explicit LSN, we must do the creation first on a single shard, and then - // use whatever LSN that shard picked when creating on subsequent shards. We arbitrarily use shard zero as the shard - // that will get the first creation request, and propagate the LSN to all the >0 shards. - let timeline_info = create_one( - shard_zero.0, - shard_zero.1, - self.config.jwt_token.clone(), - create_req.clone(), - ) - .await?; - - // Propagate the LSN that shard zero picked, if caller didn't provide one - if create_req.ancestor_timeline_id.is_some() && create_req.ancestor_start_lsn.is_none() { - create_req.ancestor_start_lsn = timeline_info.ancestor_lsn; - } - - // Create timeline on remaining shards with number >0 - if !targets.is_empty() { - // If we had multiple shards, issue requests for the remainder now. 
- let jwt = &self.config.jwt_token; - self.tenant_for_shards(targets, |tenant_shard_id: TenantShardId, node: Node| { - let create_req = create_req.clone(); - Box::pin(create_one(tenant_shard_id, node, jwt.clone(), create_req)) - }) + // Because the caller might not provide an explicit LSN, we must do the creation first on a single shard, and then + // use whatever LSN that shard picked when creating on subsequent shards. We arbitrarily use shard zero as the shard + // that will get the first creation request, and propagate the LSN to all the >0 shards. + let timeline_info = create_one( + shard_zero.0, + shard_zero.1, + self.config.jwt_token.clone(), + create_req.clone(), + ) .await?; - } - Ok(timeline_info) + // Propagate the LSN that shard zero picked, if caller didn't provide one + if create_req.ancestor_timeline_id.is_some() && create_req.ancestor_start_lsn.is_none() + { + create_req.ancestor_start_lsn = timeline_info.ancestor_lsn; + } + + // Create timeline on remaining shards with number >0 + if !targets.is_empty() { + // If we had multiple shards, issue requests for the remainder now. + let jwt = &self.config.jwt_token; + self.tenant_for_shards( + targets.iter().map(|t| (t.0, t.1.clone())).collect(), + |tenant_shard_id: TenantShardId, node: Node| { + let create_req = create_req.clone(); + Box::pin(create_one(tenant_shard_id, node, jwt.clone(), create_req)) + }, + ) + .await?; + } + + Ok(timeline_info) + }) + .await? } pub(crate) async fn tenant_timeline_detach_ancestor( @@ -2964,107 +2931,87 @@ impl Service { ) .await; - self.ensure_attached_wait(tenant_id).await?; - - let targets = { - let locked = self.inner.read().unwrap(); - let mut targets = Vec::new(); - - for (tenant_shard_id, shard) in - locked.tenants.range(TenantShardId::tenant_range(tenant_id)) - { - let node_id = shard.intent.get_attached().ok_or_else(|| { - ApiError::InternalServerError(anyhow::anyhow!("Shard not scheduled")) - })?; - let node = locked - .nodes - .get(&node_id) - .expect("Pageservers may not be deleted while referenced"); - - targets.push((*tenant_shard_id, node.clone())); + self.tenant_remote_mutation(tenant_id, move |targets| async move { + if targets.is_empty() { + return Err(ApiError::NotFound( + anyhow::anyhow!("Tenant not found").into(), + )); } - targets - }; - if targets.is_empty() { - return Err(ApiError::NotFound( - anyhow::anyhow!("Tenant not found").into(), - )); - } + async fn detach_one( + tenant_shard_id: TenantShardId, + timeline_id: TimelineId, + node: Node, + jwt: Option, + ) -> Result<(ShardNumber, models::detach_ancestor::AncestorDetached), ApiError> { + tracing::info!( + "Detaching timeline on shard {tenant_shard_id}/{timeline_id}, attached to node {node}", + ); - async fn detach_one( - tenant_shard_id: TenantShardId, - timeline_id: TimelineId, - node: Node, - jwt: Option, - ) -> Result<(ShardNumber, models::detach_ancestor::AncestorDetached), ApiError> { - tracing::info!( - "Detaching timeline on shard {tenant_shard_id}/{timeline_id}, attached to node {node}", - ); + let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref()); - let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref()); + client + .timeline_detach_ancestor(tenant_shard_id, timeline_id) + .await + .map_err(|e| { + use mgmt_api::Error; - client - .timeline_detach_ancestor(tenant_shard_id, timeline_id) - .await - .map_err(|e| { - use mgmt_api::Error; - - match e { - // no ancestor (ever) - Error::ApiError(StatusCode::CONFLICT, msg) => ApiError::Conflict(format!( - 
"{node}: {}", - msg.strip_prefix("Conflict: ").unwrap_or(&msg) - )), - // too many ancestors - Error::ApiError(StatusCode::BAD_REQUEST, msg) => { - ApiError::BadRequest(anyhow::anyhow!("{node}: {msg}")) + match e { + // no ancestor (ever) + Error::ApiError(StatusCode::CONFLICT, msg) => ApiError::Conflict(format!( + "{node}: {}", + msg.strip_prefix("Conflict: ").unwrap_or(&msg) + )), + // too many ancestors + Error::ApiError(StatusCode::BAD_REQUEST, msg) => { + ApiError::BadRequest(anyhow::anyhow!("{node}: {msg}")) + } + Error::ApiError(StatusCode::INTERNAL_SERVER_ERROR, msg) => { + // avoid turning these into conflicts to remain compatible with + // pageservers, 500 errors are sadly retryable with timeline ancestor + // detach + ApiError::InternalServerError(anyhow::anyhow!("{node}: {msg}")) + } + // rest can be mapped as usual + other => passthrough_api_error(&node, other), } - Error::ApiError(StatusCode::INTERNAL_SERVER_ERROR, msg) => { - // avoid turning these into conflicts to remain compatible with - // pageservers, 500 errors are sadly retryable with timeline ancestor - // detach - ApiError::InternalServerError(anyhow::anyhow!("{node}: {msg}")) - } - // rest can be mapped as usual - other => passthrough_api_error(&node, other), - } + }) + .map(|res| (tenant_shard_id.shard_number, res)) + } + + // no shard needs to go first/last; the operation should be idempotent + let mut results = self + .tenant_for_shards(targets, |tenant_shard_id, node| { + futures::FutureExt::boxed(detach_one( + tenant_shard_id, + timeline_id, + node, + self.config.jwt_token.clone(), + )) }) - .map(|res| (tenant_shard_id.shard_number, res)) - } + .await?; - // no shard needs to go first/last; the operation should be idempotent - let mut results = self - .tenant_for_shards(targets, |tenant_shard_id, node| { - futures::FutureExt::boxed(detach_one( - tenant_shard_id, - timeline_id, - node, - self.config.jwt_token.clone(), - )) - }) - .await?; + let any = results.pop().expect("we must have at least one response"); - let any = results.pop().expect("we must have at least one response"); + let mismatching = results + .iter() + .filter(|(_, res)| res != &any.1) + .collect::>(); + if !mismatching.is_empty() { + // this can be hit by races which should not happen because operation lock on cplane + let matching = results.len() - mismatching.len(); + tracing::error!( + matching, + compared_against=?any, + ?mismatching, + "shards returned different results" + ); - let mismatching = results - .iter() - .filter(|(_, res)| res != &any.1) - .collect::>(); - if !mismatching.is_empty() { - // this can be hit by races which should not happen because operation lock on cplane - let matching = results.len() - mismatching.len(); - tracing::error!( - matching, - compared_against=?any, - ?mismatching, - "shards returned different results" - ); + return Err(ApiError::InternalServerError(anyhow::anyhow!("pageservers returned mixed results for ancestor detach; manual intervention is required."))); + } - return Err(ApiError::InternalServerError(anyhow::anyhow!("pageservers returned mixed results for ancestor detach; manual intervention is required."))); - } - - Ok(any.1) + Ok(any.1) + }).await? } /// Helper for concurrently calling a pageserver API on a number of shards, such as timeline creation. 
@@ -3135,6 +3082,84 @@ impl Service { results } + /// Helper for safely working with the shards in a tenant remotely on pageservers, for example + /// when creating and deleting timelines: + /// - Makes sure shards are attached somewhere if they weren't already + /// - Looks up the shards and the nodes where they were most recently attached + /// - Guarantees that after the inner function returns, the shards' generations haven't moved on: this + /// ensures that the remote operation acted on the most recent generation, and is therefore durable. + async fn tenant_remote_mutation( + &self, + tenant_id: TenantId, + op: O, + ) -> Result + where + O: FnOnce(Vec<(TenantShardId, Node)>) -> F, + F: std::future::Future, + { + let target_gens = { + let mut targets = Vec::new(); + + // Load the currently attached pageservers for the latest generation of each shard. This can + // run concurrently with reconciliations, and it is not guaranteed that the node we find here + // will still be the latest when we're done: we will check generations again at the end of + // this function to handle that. + let generations = self.persistence.peek_generations(tenant_id).await?; + let generations = if generations.iter().any(|i| i.1.is_none()) { + // One or more shards is not attached to anything: maybe this is a new tenant? Wait for + // it to reconcile. + self.ensure_attached_wait(tenant_id).await?; + self.persistence.peek_generations(tenant_id).await? + } else { + generations + }; + + let locked = self.inner.read().unwrap(); + for (tenant_shard_id, generation, generation_pageserver) in generations { + let node_id = generation_pageserver.ok_or(ApiError::Conflict( + "Tenant not currently attached".to_string(), + ))?; + let node = locked + .nodes + .get(&node_id) + .ok_or(ApiError::Conflict(format!( + "Raced with removal of node {node_id}" + )))?; + targets.push((tenant_shard_id, node.clone(), generation)); + } + + targets + }; + + let targets = target_gens.iter().map(|t| (t.0, t.1.clone())).collect(); + let result = op(targets).await; + + // Post-check: are all the generations of all the shards the same as they were initially? This proves that + // our remote operation executed on the latest generation and is therefore persistent. + { + let latest_generations = self.persistence.peek_generations(tenant_id).await?; + if latest_generations + .into_iter() + .map(|g| (g.0, g.1)) + .collect::>() + != target_gens + .into_iter() + .map(|i| (i.0, i.2)) + .collect::>() + { + // We raced with something that incremented the generation, and therefore cannot be + // confident that our actions are persistent (they might have hit an old generation). + // + // This is safe but requires a retry: ask the client to do that by giving them a 503 response. 
+ return Err(ApiError::ResourceUnavailable( + "Tenant attachment changed, please retry".into(), + )); + } + } + + Ok(result) + } + pub(crate) async fn tenant_timeline_delete( &self, tenant_id: TenantId, @@ -3148,83 +3173,62 @@ impl Service { ) .await; - self.ensure_attached_wait(tenant_id).await?; - - let mut targets = { - let locked = self.inner.read().unwrap(); - let mut targets = Vec::new(); - - for (tenant_shard_id, shard) in - locked.tenants.range(TenantShardId::tenant_range(tenant_id)) - { - let node_id = shard.intent.get_attached().ok_or_else(|| { - ApiError::InternalServerError(anyhow::anyhow!("Shard not scheduled")) - })?; - let node = locked - .nodes - .get(&node_id) - .expect("Pageservers may not be deleted while referenced"); - - targets.push((*tenant_shard_id, node.clone())); + self.tenant_remote_mutation(tenant_id, move |mut targets| async move { + if targets.is_empty() { + return Err(ApiError::NotFound( + anyhow::anyhow!("Tenant not found").into(), + )); } - targets - }; + let shard_zero = targets.remove(0); - if targets.is_empty() { - return Err(ApiError::NotFound( - anyhow::anyhow!("Tenant not found").into(), - )); - } - let shard_zero = targets.remove(0); + async fn delete_one( + tenant_shard_id: TenantShardId, + timeline_id: TimelineId, + node: Node, + jwt: Option, + ) -> Result { + tracing::info!( + "Deleting timeline on shard {tenant_shard_id}/{timeline_id}, attached to node {node}", + ); - async fn delete_one( - tenant_shard_id: TenantShardId, - timeline_id: TimelineId, - node: Node, - jwt: Option, - ) -> Result { - tracing::info!( - "Deleting timeline on shard {tenant_shard_id}/{timeline_id}, attached to node {node}", - ); + let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref()); + client + .timeline_delete(tenant_shard_id, timeline_id) + .await + .map_err(|e| { + ApiError::InternalServerError(anyhow::anyhow!( + "Error deleting timeline {timeline_id} on {tenant_shard_id} on node {node}: {e}", + )) + }) + } - let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref()); - client - .timeline_delete(tenant_shard_id, timeline_id) - .await - .map_err(|e| { - ApiError::InternalServerError(anyhow::anyhow!( - "Error deleting timeline {timeline_id} on {tenant_shard_id} on node {node}: {e}", + let statuses = self + .tenant_for_shards(targets, |tenant_shard_id: TenantShardId, node: Node| { + Box::pin(delete_one( + tenant_shard_id, + timeline_id, + node, + self.config.jwt_token.clone(), )) }) - } + .await?; - let statuses = self - .tenant_for_shards(targets, |tenant_shard_id: TenantShardId, node: Node| { - Box::pin(delete_one( - tenant_shard_id, - timeline_id, - node, - self.config.jwt_token.clone(), - )) - }) + // If any shards >0 haven't finished deletion yet, don't start deletion on shard zero + if statuses.iter().any(|s| s != &StatusCode::NOT_FOUND) { + return Ok(StatusCode::ACCEPTED); + } + + // Delete shard zero last: this is not strictly necessary, but since a caller's GET on a timeline will be routed + // to shard zero, it gives a more obvious behavior that a GET returns 404 once the deletion is done. 
+ let shard_zero_status = delete_one( + shard_zero.0, + timeline_id, + shard_zero.1, + self.config.jwt_token.clone(), + ) .await?; - - // If any shards >0 haven't finished deletion yet, don't start deletion on shard zero - if statuses.iter().any(|s| s != &StatusCode::NOT_FOUND) { - return Ok(StatusCode::ACCEPTED); - } - - // Delete shard zero last: this is not strictly necessary, but since a caller's GET on a timeline will be routed - // to shard zero, it gives a more obvious behavior that a GET returns 404 once the deletion is done. - let shard_zero_status = delete_one( - shard_zero.0, - timeline_id, - shard_zero.1, - self.config.jwt_token.clone(), - ) - .await?; - - Ok(shard_zero_status) + Ok(shard_zero_status) + }).await? } /// When you need to send an HTTP request to the pageserver that holds shard0 of a tenant, this @@ -4780,7 +4784,7 @@ impl Service { // // The transition we calculate here remains valid later in the function because we hold the op lock on the node: // nothing else can mutate its availability while we run. - let availability_transition = if let Some(input_availability) = availability { + let availability_transition = if let Some(input_availability) = availability.as_ref() { let (activate_node, availability_transition) = { let locked = self.inner.read().unwrap(); let Some(node) = locked.nodes.get(&node_id) else { @@ -4816,8 +4820,8 @@ impl Service { )); }; - if let Some(availability) = &availability { - node.set_availability(*availability); + if let Some(availability) = availability.as_ref() { + node.set_availability(availability.clone()); } if let Some(scheduling) = scheduling { @@ -6377,42 +6381,4 @@ impl Service { global_observed } - - /// Request step down from the currently registered leader in the database - /// - /// If such an entry is persisted, the success path returns the observed - /// state and details of the leader. Otherwise, None is returned indicating - /// there is no leader currently. - /// - /// On failures to query the database or step down error responses the process is killed - /// and we rely on k8s to retry. - async fn request_step_down( - &self, - leader: &ControllerPersistence, - ) -> Option { - tracing::info!("Sending step down request to {leader:?}"); - - // TODO: jwt token - let client = PeerClient::new( - Uri::try_from(leader.address.as_str()).expect("Failed to build leader URI"), - self.config.jwt_token.clone(), - ); - let state = client.step_down(&self.cancel).await; - match state { - Ok(state) => Some(state), - Err(err) => { - // TODO: Make leaders periodically update a timestamp field in the - // database and, if the leader is not reachable from the current instance, - // but inferred as alive from the timestamp, abort start-up. This avoids - // a potential scenario in which we have two controllers acting as leaders. 
- tracing::error!( - "Leader ({}) did not respond to step-down request: {}", - leader.address, - err - ); - - None - } - } - } } diff --git a/storage_controller/src/tenant_shard.rs b/storage_controller/src/tenant_shard.rs index 1fcc3c8547..30723a3b36 100644 --- a/storage_controller/src/tenant_shard.rs +++ b/storage_controller/src/tenant_shard.rs @@ -779,7 +779,7 @@ impl TenantShard { #[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()))] pub(crate) fn optimize_secondary( &self, - scheduler: &Scheduler, + scheduler: &mut Scheduler, schedule_context: &ScheduleContext, ) -> Option { if self.intent.secondary.is_empty() { @@ -1595,7 +1595,7 @@ pub(crate) mod tests { schedule_context.avoid(&shard_b.intent.all_pageservers()); schedule_context.push_attached(shard_b.intent.get_attached().unwrap()); - let optimization_a = shard_a.optimize_secondary(&scheduler, &schedule_context); + let optimization_a = shard_a.optimize_secondary(&mut scheduler, &schedule_context); // Since there is a node with no locations available, the node with two locations for the // same tenant should generate an optimization to move one away diff --git a/storage_scrubber/README.md b/storage_scrubber/README.md index 9fbd92feef..5be8541419 100644 --- a/storage_scrubber/README.md +++ b/storage_scrubber/README.md @@ -98,7 +98,7 @@ to list timelines and find their backup and start LSNs. If S3 state is altered first manually, pageserver in-memory state will contain wrong data about S3 state, and tenants/timelines may get recreated on S3 (due to any layer upload due to compaction, pageserver restart, etc.). So before proceeding, for tenants/timelines which are already deleted in the console, we must remove these from pageservers. -First, we need to group pageservers by buckets, `https:///admin/pageservers`` can be used for all env nodes, then `cat /storage/pageserver/data/pageserver.toml` on every node will show the bucket names and regions needed. +First, we need to group pageservers by buckets, `https:///admin/pageservers` can be used for all env nodes, then `cat /storage/pageserver/data/pageserver.toml` on every node will show the bucket names and regions needed. 
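For example, something like `grep -E 'bucket_(name|region)' /storage/pageserver/data/pageserver.toml` will usually surface both values on each node (the exact key names under the remote storage section may vary by deployment, so verify against the file itself).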
Per bucket, for every pageserver id related, find deleted tenants: diff --git a/storage_scrubber/src/checks.rs b/storage_scrubber/src/checks.rs index 35ec69fd50..08b0f06ebf 100644 --- a/storage_scrubber/src/checks.rs +++ b/storage_scrubber/src/checks.rs @@ -1,10 +1,10 @@ use std::collections::{HashMap, HashSet}; use anyhow::Context; -use aws_sdk_s3::Client; use pageserver::tenant::layer_map::LayerMap; use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata; use pageserver_api::shard::ShardIndex; +use tokio_util::sync::CancellationToken; use tracing::{error, info, warn}; use utils::generation::Generation; use utils::id::TimelineId; @@ -16,7 +16,7 @@ use futures_util::StreamExt; use pageserver::tenant::remote_timeline_client::{parse_remote_index_path, remote_layer_path}; use pageserver::tenant::storage_layer::LayerName; use pageserver::tenant::IndexPart; -use remote_storage::RemotePath; +use remote_storage::{GenericRemoteStorage, ListingObject, RemotePath}; pub(crate) struct TimelineAnalysis { /// Anomalies detected @@ -48,13 +48,12 @@ impl TimelineAnalysis { } pub(crate) async fn branch_cleanup_and_check_errors( - s3_client: &Client, - target: &RootTarget, + remote_client: &GenericRemoteStorage, id: &TenantShardTimelineId, tenant_objects: &mut TenantObjectListing, s3_active_branch: Option<&BranchData>, console_branch: Option, - s3_data: Option, + s3_data: Option, ) -> TimelineAnalysis { let mut result = TimelineAnalysis::new(); @@ -78,7 +77,9 @@ pub(crate) async fn branch_cleanup_and_check_errors( match s3_data { Some(s3_data) => { - result.garbage_keys.extend(s3_data.unknown_keys); + result + .garbage_keys + .extend(s3_data.unknown_keys.into_iter().map(|k| k.key.to_string())); match s3_data.blob_data { BlobDataParseResult::Parsed { @@ -143,16 +144,13 @@ pub(crate) async fn branch_cleanup_and_check_errors( // HEAD request used here to address a race condition when an index was uploaded concurrently // with our scan. We check if the object is uploaded to S3 after taking the listing snapshot. - let response = s3_client - .head_object() - .bucket(target.bucket_name()) - .key(path.get_path().as_str()) - .send() + let response = remote_client + .head_object(&path, &CancellationToken::new()) .await; if response.is_err() { // Object is not present. - let is_l0 = LayerMap::is_l0(layer.key_range()); + let is_l0 = LayerMap::is_l0(layer.key_range(), layer.is_delta()); let msg = format!( "index_part.json contains a layer {}{} (shard {}) that is not present in remote storage (layer_is_l0: {})", @@ -284,14 +282,14 @@ impl TenantObjectListing { } #[derive(Debug)] -pub(crate) struct S3TimelineBlobData { +pub(crate) struct RemoteTimelineBlobData { pub(crate) blob_data: BlobDataParseResult, // Index objects that were not used when loading `blob_data`, e.g. those from old generations - pub(crate) unused_index_keys: Vec, + pub(crate) unused_index_keys: Vec, // Objects whose keys were not recognized at all, i.e. 
not layer files, not indices - pub(crate) unknown_keys: Vec, + pub(crate) unknown_keys: Vec, } #[derive(Debug)] @@ -323,31 +321,37 @@ pub(crate) fn parse_layer_object_name(name: &str) -> Result<(LayerName, Generati } pub(crate) async fn list_timeline_blobs( - s3_client: &Client, + remote_client: &GenericRemoteStorage, id: TenantShardTimelineId, - s3_root: &RootTarget, -) -> anyhow::Result { + root_target: &RootTarget, +) -> anyhow::Result { let mut s3_layers = HashSet::new(); let mut errors = Vec::new(); let mut unknown_keys = Vec::new(); - let mut timeline_dir_target = s3_root.timeline_root(&id); + let mut timeline_dir_target = root_target.timeline_root(&id); timeline_dir_target.delimiter = String::new(); - let mut index_part_keys: Vec = Vec::new(); + let mut index_part_keys: Vec = Vec::new(); let mut initdb_archive: bool = false; - let mut stream = std::pin::pin!(stream_listing(s3_client, &timeline_dir_target)); - while let Some(obj) = stream.next().await { - let obj = obj?; - let key = obj.key(); + let prefix_str = &timeline_dir_target + .prefix_in_bucket + .strip_prefix("/") + .unwrap_or(&timeline_dir_target.prefix_in_bucket); - let blob_name = key.strip_prefix(&timeline_dir_target.prefix_in_bucket); + let mut stream = std::pin::pin!(stream_listing(remote_client, &timeline_dir_target)); + while let Some(obj) = stream.next().await { + let (key, Some(obj)) = obj? else { + panic!("ListingObject not specified"); + }; + + let blob_name = key.get_path().as_str().strip_prefix(prefix_str); match blob_name { Some(name) if name.starts_with("index_part.json") => { tracing::debug!("Index key {key}"); - index_part_keys.push(key.to_owned()) + index_part_keys.push(obj) } Some("initdb.tar.zst") => { tracing::debug!("initdb archive {key}"); @@ -358,7 +362,7 @@ pub(crate) async fn list_timeline_blobs( } Some(maybe_layer_name) => match parse_layer_object_name(maybe_layer_name) { Ok((new_layer, gen)) => { - tracing::debug!("Parsed layer key: {} {:?}", new_layer, gen); + tracing::debug!("Parsed layer key: {new_layer} {gen:?}"); s3_layers.insert((new_layer, gen)); } Err(e) => { @@ -366,13 +370,13 @@ pub(crate) async fn list_timeline_blobs( errors.push( format!("S3 list response got an object with key {key} that is not a layer name: {e}"), ); - unknown_keys.push(key.to_string()); + unknown_keys.push(obj); } }, None => { - tracing::warn!("Unknown key {}", key); + tracing::warn!("Unknown key {key}"); errors.push(format!("S3 list response got an object with odd key {key}")); - unknown_keys.push(key.to_string()); + unknown_keys.push(obj); } } } @@ -381,7 +385,7 @@ pub(crate) async fn list_timeline_blobs( tracing::debug!( "Timeline is empty apart from initdb archive: expected post-deletion state." ); - return Ok(S3TimelineBlobData { + return Ok(RemoteTimelineBlobData { blob_data: BlobDataParseResult::Relic, unused_index_keys: index_part_keys, unknown_keys: Vec::new(), @@ -395,13 +399,13 @@ pub(crate) async fn list_timeline_blobs( // Stripping the index key to the last part, because RemotePath doesn't // like absolute paths, and depending on prefix_in_bucket it's possible // for the keys we read back to start with a slash. 
- let basename = key.rsplit_once('/').unwrap().1; + let basename = key.key.get_path().as_str().rsplit_once('/').unwrap().1; parse_remote_index_path(RemotePath::from_string(basename).unwrap()).map(|g| (key, g)) }) .max_by_key(|i| i.1) .map(|(k, g)| (k.clone(), g)) { - Some((key, gen)) => (Some(key), gen), + Some((key, gen)) => (Some::(key.to_owned()), gen), None => { // Legacy/missing case: one or zero index parts, which did not have a generation (index_part_keys.pop(), Generation::none()) @@ -416,17 +420,14 @@ pub(crate) async fn list_timeline_blobs( } if let Some(index_part_object_key) = index_part_object.as_ref() { - let index_part_bytes = download_object_with_retries( - s3_client, - &timeline_dir_target.bucket_name, - index_part_object_key, - ) - .await - .context("index_part.json download")?; + let index_part_bytes = + download_object_with_retries(remote_client, &index_part_object_key.key) + .await + .context("index_part.json download")?; match serde_json::from_slice(&index_part_bytes) { Ok(index_part) => { - return Ok(S3TimelineBlobData { + return Ok(RemoteTimelineBlobData { blob_data: BlobDataParseResult::Parsed { index_part: Box::new(index_part), index_part_generation, @@ -448,7 +449,7 @@ pub(crate) async fn list_timeline_blobs( ); } - Ok(S3TimelineBlobData { + Ok(RemoteTimelineBlobData { blob_data: BlobDataParseResult::Incorrect { errors, s3_layers }, unused_index_keys: index_part_keys, unknown_keys, diff --git a/storage_scrubber/src/find_large_objects.rs b/storage_scrubber/src/find_large_objects.rs index f5bb7e088a..88e36af560 100644 --- a/storage_scrubber/src/find_large_objects.rs +++ b/storage_scrubber/src/find_large_objects.rs @@ -6,7 +6,7 @@ use remote_storage::ListingMode; use serde::{Deserialize, Serialize}; use crate::{ - checks::parse_layer_object_name, init_remote_generic, metadata_stream::stream_tenants_generic, + checks::parse_layer_object_name, init_remote, metadata_stream::stream_tenants, stream_objects_with_retries, BucketConfig, NodeKind, }; @@ -50,9 +50,8 @@ pub async fn find_large_objects( ignore_deltas: bool, concurrency: usize, ) -> anyhow::Result { - let (remote_client, target) = - init_remote_generic(bucket_config.clone(), NodeKind::Pageserver).await?; - let tenants = pin!(stream_tenants_generic(&remote_client, &target)); + let (remote_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?; + let tenants = pin!(stream_tenants(&remote_client, &target)); let objects_stream = tenants.map_ok(|tenant_shard_id| { let mut tenant_root = target.tenant_root(&tenant_shard_id); diff --git a/storage_scrubber/src/garbage.rs b/storage_scrubber/src/garbage.rs index d6a73bf366..3e22960f8d 100644 --- a/storage_scrubber/src/garbage.rs +++ b/storage_scrubber/src/garbage.rs @@ -19,8 +19,8 @@ use utils::id::TenantId; use crate::{ cloud_admin_api::{CloudAdminApiClient, MaybeDeleted, ProjectData}, - init_remote_generic, list_objects_with_retries_generic, - metadata_stream::{stream_tenant_timelines_generic, stream_tenants_generic}, + init_remote, list_objects_with_retries, + metadata_stream::{stream_tenant_timelines, stream_tenants}, BucketConfig, ConsoleConfig, NodeKind, TenantShardTimelineId, TraversingDepth, }; @@ -153,7 +153,7 @@ async fn find_garbage_inner( node_kind: NodeKind, ) -> anyhow::Result { // Construct clients for S3 and for Console API - let (remote_client, target) = init_remote_generic(bucket_config.clone(), node_kind).await?; + let (remote_client, target) = init_remote(bucket_config.clone(), node_kind).await?; let cloud_admin_api_client = 
Arc::new(CloudAdminApiClient::new(console_config)); // Build a set of console-known tenants, for quickly eliminating known-active tenants without having @@ -179,7 +179,7 @@ async fn find_garbage_inner( // Enumerate Tenants in S3, and check if each one exists in Console tracing::info!("Finding all tenants in bucket {}...", bucket_config.bucket); - let tenants = stream_tenants_generic(&remote_client, &target); + let tenants = stream_tenants(&remote_client, &target); let tenants_checked = tenants.map_ok(|t| { let api_client = cloud_admin_api_client.clone(); let console_cache = console_cache.clone(); @@ -237,14 +237,13 @@ async fn find_garbage_inner( // Special case: If it's missing in console, check for known bugs that would enable us to conclusively // identify it as purge-able anyway if console_result.is_none() { - let timelines = - stream_tenant_timelines_generic(&remote_client, &target, tenant_shard_id) - .await? - .collect::>() - .await; + let timelines = stream_tenant_timelines(&remote_client, &target, tenant_shard_id) + .await? + .collect::>() + .await; if timelines.is_empty() { // No timelines, but a heatmap: the deletion bug where we deleted everything but heatmaps - let tenant_objects = list_objects_with_retries_generic( + let tenant_objects = list_objects_with_retries( &remote_client, ListingMode::WithDelimiter, &target.tenant_root(&tenant_shard_id), @@ -265,7 +264,7 @@ async fn find_garbage_inner( for timeline_r in timelines { let timeline = timeline_r?; - let timeline_objects = list_objects_with_retries_generic( + let timeline_objects = list_objects_with_retries( &remote_client, ListingMode::WithDelimiter, &target.timeline_root(&timeline), @@ -331,8 +330,7 @@ async fn find_garbage_inner( // Construct a stream of all timelines within active tenants let active_tenants = tokio_stream::iter(active_tenants.iter().map(Ok)); - let timelines = - active_tenants.map_ok(|t| stream_tenant_timelines_generic(&remote_client, &target, *t)); + let timelines = active_tenants.map_ok(|t| stream_tenant_timelines(&remote_client, &target, *t)); let timelines = timelines.try_buffer_unordered(S3_CONCURRENCY); let timelines = timelines.try_flatten(); @@ -507,7 +505,7 @@ pub async fn purge_garbage( ); let (remote_client, _target) = - init_remote_generic(garbage_list.bucket_config.clone(), garbage_list.node_kind).await?; + init_remote(garbage_list.bucket_config.clone(), garbage_list.node_kind).await?; assert_eq!( &garbage_list.bucket_config.bucket, diff --git a/storage_scrubber/src/lib.rs b/storage_scrubber/src/lib.rs index 1fc94cc174..112f052e07 100644 --- a/storage_scrubber/src/lib.rs +++ b/storage_scrubber/src/lib.rs @@ -15,7 +15,7 @@ use std::fmt::Display; use std::sync::Arc; use std::time::Duration; -use anyhow::{anyhow, Context}; +use anyhow::Context; use aws_config::retry::{RetryConfigBuilder, RetryMode}; use aws_sdk_s3::config::Region; use aws_sdk_s3::error::DisplayErrorContext; @@ -352,7 +352,7 @@ fn make_root_target( } } -async fn init_remote( +async fn init_remote_s3( bucket_config: BucketConfig, node_kind: NodeKind, ) -> anyhow::Result<(Arc, RootTarget)> { @@ -369,7 +369,7 @@ async fn init_remote( Ok((s3_client, s3_root)) } -async fn init_remote_generic( +async fn init_remote( bucket_config: BucketConfig, node_kind: NodeKind, ) -> anyhow::Result<(GenericRemoteStorage, RootTarget)> { @@ -394,45 +394,10 @@ async fn init_remote_generic( // We already pass the prefix to the remote client above let prefix_in_root_target = String::new(); - let s3_root = make_root_target(bucket_config.bucket, 
prefix_in_root_target, node_kind); + let root_target = make_root_target(bucket_config.bucket, prefix_in_root_target, node_kind); let client = GenericRemoteStorage::from_config(&storage_config).await?; - Ok((client, s3_root)) -} - -async fn list_objects_with_retries( - s3_client: &Client, - s3_target: &S3Target, - continuation_token: Option, -) -> anyhow::Result { - for trial in 0..MAX_RETRIES { - match s3_client - .list_objects_v2() - .bucket(&s3_target.bucket_name) - .prefix(&s3_target.prefix_in_bucket) - .delimiter(&s3_target.delimiter) - .set_continuation_token(continuation_token.clone()) - .send() - .await - { - Ok(response) => return Ok(response), - Err(e) => { - if trial == MAX_RETRIES - 1 { - return Err(e) - .with_context(|| format!("Failed to list objects {MAX_RETRIES} times")); - } - error!( - "list_objects_v2 query failed: bucket_name={}, prefix={}, delimiter={}, error={}", - s3_target.bucket_name, - s3_target.prefix_in_bucket, - s3_target.delimiter, - DisplayErrorContext(e), - ); - tokio::time::sleep(Duration::from_secs(1)).await; - } - } - } - Err(anyhow!("unreachable unless MAX_RETRIES==0")) + Ok((client, root_target)) } /// Listing possibly large amounts of keys in a streaming fashion. @@ -452,23 +417,26 @@ fn stream_objects_with_retries<'a>( let mut list_stream = storage_client.list_streaming(Some(&prefix), listing_mode, None, &cancel); while let Some(res) = list_stream.next().await { - if let Err(err) = res { - let yield_err = if err.is_permanent() { - true - } else { - let backoff_time = 1 << trial.max(5); - tokio::time::sleep(Duration::from_secs(backoff_time)).await; - trial += 1; - trial == MAX_RETRIES - 1 - }; - if yield_err { - yield Err(err) - .with_context(|| format!("Failed to list objects {MAX_RETRIES} times")); - break; + match res { + Err(err) => { + let yield_err = if err.is_permanent() { + true + } else { + let backoff_time = 1 << trial.max(5); + tokio::time::sleep(Duration::from_secs(backoff_time)).await; + trial += 1; + trial == MAX_RETRIES - 1 + }; + if yield_err { + yield Err(err) + .with_context(|| format!("Failed to list objects {MAX_RETRIES} times")); + break; + } + } + Ok(res) => { + trial = 0; + yield Ok(res); } - } else { - trial = 0; - yield res.map_err(anyhow::Error::from); } } } @@ -476,7 +444,7 @@ fn stream_objects_with_retries<'a>( /// If you want to list a bounded amount of prefixes or keys. For larger numbers of keys/prefixes, /// use [`stream_objects_with_retries`] instead. 
-async fn list_objects_with_retries_generic( +async fn list_objects_with_retries( remote_client: &GenericRemoteStorage, listing_mode: ListingMode, s3_target: &S3Target, @@ -514,40 +482,34 @@ async fn list_objects_with_retries_generic( } async fn download_object_with_retries( - s3_client: &Client, - bucket_name: &str, - key: &str, + remote_client: &GenericRemoteStorage, + key: &RemotePath, ) -> anyhow::Result> { - for _ in 0..MAX_RETRIES { - let mut body_buf = Vec::new(); - let response_stream = match s3_client - .get_object() - .bucket(bucket_name) - .key(key) - .send() - .await - { + let cancel = CancellationToken::new(); + for trial in 0..MAX_RETRIES { + let mut buf = Vec::new(); + let download = match remote_client.download(key, &cancel).await { Ok(response) => response, Err(e) => { error!("Failed to download object for key {key}: {e}"); - tokio::time::sleep(Duration::from_secs(1)).await; + let backoff_time = 1 << trial.max(5); + tokio::time::sleep(Duration::from_secs(backoff_time)).await; continue; } }; - match response_stream - .body - .into_async_read() - .read_to_end(&mut body_buf) + match tokio_util::io::StreamReader::new(download.download_stream) + .read_to_end(&mut buf) .await { Ok(bytes_read) => { tracing::debug!("Downloaded {bytes_read} bytes for object {key}"); - return Ok(body_buf); + return Ok(buf); } Err(e) => { error!("Failed to stream object body for key {key}: {e}"); - tokio::time::sleep(Duration::from_secs(1)).await; + let backoff_time = 1 << trial.max(5); + tokio::time::sleep(Duration::from_secs(backoff_time)).await; } } } @@ -555,7 +517,7 @@ async fn download_object_with_retries( anyhow::bail!("Failed to download objects with key {key} {MAX_RETRIES} times") } -async fn download_object_to_file( +async fn download_object_to_file_s3( s3_client: &Client, bucket_name: &str, key: &str, diff --git a/storage_scrubber/src/main.rs b/storage_scrubber/src/main.rs index cbc836755a..3935e513e3 100644 --- a/storage_scrubber/src/main.rs +++ b/storage_scrubber/src/main.rs @@ -3,9 +3,10 @@ use camino::Utf8PathBuf; use pageserver_api::controller_api::{MetadataHealthUpdateRequest, MetadataHealthUpdateResponse}; use pageserver_api::shard::TenantShardId; use reqwest::{Method, Url}; +use storage_controller_client::control_api; use storage_scrubber::garbage::{find_garbage, purge_garbage, PurgeMode}; use storage_scrubber::pageserver_physical_gc::GcMode; -use storage_scrubber::scan_pageserver_metadata::scan_metadata; +use storage_scrubber::scan_pageserver_metadata::scan_pageserver_metadata; use storage_scrubber::tenant_snapshot::SnapshotDownloader; use storage_scrubber::{find_large_objects, ControllerClientConfig}; use storage_scrubber::{ @@ -68,7 +69,7 @@ enum Command { #[arg(long = "tenant-id", num_args = 0..)] tenant_ids: Vec, #[arg(long = "post", default_value_t = false)] - post_to_storage_controller: bool, + post_to_storcon: bool, #[arg(long, default_value = None)] /// For safekeeper node_kind only, points to db with debug dump dump_db_connstr: Option, @@ -100,6 +101,16 @@ enum Command { #[arg(long = "concurrency", short = 'j', default_value_t = 64)] concurrency: usize, }, + CronJob { + // PageserverPhysicalGc + #[arg(long = "min-age")] + gc_min_age: humantime::Duration, + #[arg(short, long, default_value_t = GcMode::IndicesOnly)] + gc_mode: GcMode, + // ScanMetadata + #[arg(long = "post", default_value_t = false)] + post_to_storcon: bool, + }, } #[tokio::main] @@ -117,6 +128,7 @@ async fn main() -> anyhow::Result<()> { Command::TenantSnapshot { .. 
} => "tenant-snapshot", Command::PageserverPhysicalGc { .. } => "pageserver-physical-gc", Command::FindLargeObjects { .. } => "find-large-objects", + Command::CronJob { .. } => "cron-job", }; let _guard = init_logging(&format!( "{}_{}_{}_{}.log", @@ -126,12 +138,13 @@ async fn main() -> anyhow::Result<()> { chrono::Utc::now().format("%Y_%m_%d__%H_%M_%S") )); - let controller_client_conf = cli.controller_api.map(|controller_api| { + let controller_client = cli.controller_api.map(|controller_api| { ControllerClientConfig { controller_api, // Default to no key: this is a convenience when working in a development environment controller_jwt: cli.controller_jwt.unwrap_or("".to_owned()), } + .build_client() }); match cli.command { @@ -139,7 +152,7 @@ async fn main() -> anyhow::Result<()> { json, tenant_ids, node_kind, - post_to_storage_controller, + post_to_storcon, dump_db_connstr, dump_db_table, } => { @@ -178,53 +191,14 @@ async fn main() -> anyhow::Result<()> { } Ok(()) } else { - if controller_client_conf.is_none() && post_to_storage_controller { - return Err(anyhow!("Posting pageserver scan health status to storage controller requires `--controller-api` and `--controller-jwt` to run")); - } - match scan_metadata(bucket_config.clone(), tenant_ids).await { - Err(e) => { - tracing::error!("Failed: {e}"); - Err(e) - } - Ok(summary) => { - if json { - println!("{}", serde_json::to_string(&summary).unwrap()) - } else { - println!("{}", summary.summary_string()); - } - - if post_to_storage_controller { - if let Some(conf) = controller_client_conf { - let controller_client = conf.build_client(); - let body = summary.build_health_update_request(); - controller_client - .dispatch::( - Method::POST, - "control/v1/metadata_health/update".to_string(), - Some(body), - ) - .await?; - } - } - - if summary.is_fatal() { - tracing::error!("Fatal scrub errors detected"); - } else if summary.is_empty() { - // Strictly speaking an empty bucket is a valid bucket, but if someone ran the - // scrubber they were likely expecting to scan something, and if we see no timelines - // at all then it's likely due to some configuration issues like a bad prefix - tracing::error!( - "No timelines found in bucket {} prefix {}", - bucket_config.bucket, - bucket_config - .prefix_in_bucket - .unwrap_or("".to_string()) - ); - } - - Ok(()) - } - } + scan_pageserver_metadata_cmd( + bucket_config, + controller_client.as_ref(), + tenant_ids, + json, + post_to_storcon, + ) + .await } } Command::FindGarbage { @@ -254,31 +228,14 @@ async fn main() -> anyhow::Result<()> { min_age, mode, } => { - match (&controller_client_conf, mode) { - (Some(_), _) => { - // Any mode may run when controller API is set - } - (None, GcMode::Full) => { - // The part of physical GC where we erase ancestor layers cannot be done safely without - // confirming the most recent complete shard split with the controller. Refuse to run, rather - // than doing it unsafely. - return Err(anyhow!("Full physical GC requires `--controller-api` and `--controller-jwt` to run")); - } - (None, GcMode::DryRun | GcMode::IndicesOnly) => { - // These GcModes do not require the controller to run. 
- } - } - - let summary = pageserver_physical_gc( - bucket_config, - controller_client_conf, + pageserver_physical_gc_cmd( + &bucket_config, + controller_client.as_ref(), tenant_ids, - min_age.into(), + min_age, mode, ) - .await?; - println!("{}", serde_json::to_string(&summary).unwrap()); - Ok(()) + .await } Command::FindLargeObjects { min_size, @@ -295,5 +252,142 @@ async fn main() -> anyhow::Result<()> { println!("{}", serde_json::to_string(&summary).unwrap()); Ok(()) } + Command::CronJob { + gc_min_age, + gc_mode, + post_to_storcon, + } => { + run_cron_job( + bucket_config, + controller_client.as_ref(), + gc_min_age, + gc_mode, + post_to_storcon, + ) + .await + } + } +} + +/// Runs the scrubber cron job. +/// 1. Do pageserver physical gc +/// 2. Scan pageserver metadata +pub async fn run_cron_job( + bucket_config: BucketConfig, + controller_client: Option<&control_api::Client>, + gc_min_age: humantime::Duration, + gc_mode: GcMode, + post_to_storcon: bool, +) -> anyhow::Result<()> { + tracing::info!(%gc_min_age, %gc_mode, "Running pageserver-physical-gc"); + pageserver_physical_gc_cmd( + &bucket_config, + controller_client, + Vec::new(), + gc_min_age, + gc_mode, + ) + .await?; + tracing::info!(%post_to_storcon, node_kind = %NodeKind::Pageserver, "Running scan-metadata"); + scan_pageserver_metadata_cmd( + bucket_config, + controller_client, + Vec::new(), + true, + post_to_storcon, + ) + .await?; + + Ok(()) +} + +pub async fn pageserver_physical_gc_cmd( + bucket_config: &BucketConfig, + controller_client: Option<&control_api::Client>, + tenant_shard_ids: Vec, + min_age: humantime::Duration, + mode: GcMode, +) -> anyhow::Result<()> { + match (controller_client, mode) { + (Some(_), _) => { + // Any mode may run when controller API is set + } + (None, GcMode::Full) => { + // The part of physical GC where we erase ancestor layers cannot be done safely without + // confirming the most recent complete shard split with the controller. Refuse to run, rather + // than doing it unsafely. + return Err(anyhow!( + "Full physical GC requires `--controller-api` and `--controller-jwt` to run" + )); + } + (None, GcMode::DryRun | GcMode::IndicesOnly) => { + // These GcModes do not require the controller to run. 
+ } + } + + let summary = pageserver_physical_gc( + bucket_config, + controller_client, + tenant_shard_ids, + min_age.into(), + mode, + ) + .await?; + println!("{}", serde_json::to_string(&summary).unwrap()); + Ok(()) +} + +pub async fn scan_pageserver_metadata_cmd( + bucket_config: BucketConfig, + controller_client: Option<&control_api::Client>, + tenant_shard_ids: Vec, + json: bool, + post_to_storcon: bool, +) -> anyhow::Result<()> { + if controller_client.is_none() && post_to_storcon { + return Err(anyhow!("Posting pageserver scan health status to storage controller requires `--controller-api` and `--controller-jwt` to run")); + } + match scan_pageserver_metadata(bucket_config.clone(), tenant_shard_ids).await { + Err(e) => { + tracing::error!("Failed: {e}"); + Err(e) + } + Ok(summary) => { + if json { + println!("{}", serde_json::to_string(&summary).unwrap()) + } else { + println!("{}", summary.summary_string()); + } + + if post_to_storcon { + if let Some(client) = controller_client { + let body = summary.build_health_update_request(); + client + .dispatch::( + Method::POST, + "control/v1/metadata_health/update".to_string(), + Some(body), + ) + .await?; + } + } + + if summary.is_fatal() { + tracing::error!("Fatal scrub errors detected"); + } else if summary.is_empty() { + // Strictly speaking an empty bucket is a valid bucket, but if someone ran the + // scrubber they were likely expecting to scan something, and if we see no timelines + // at all then it's likely due to some configuration issues like a bad prefix + tracing::error!( + "No timelines found in bucket {} prefix {}", + bucket_config.bucket, + bucket_config + .prefix_in_bucket + .unwrap_or("".to_string()) + ); + } + + Ok(()) + } } } diff --git a/storage_scrubber/src/metadata_stream.rs b/storage_scrubber/src/metadata_stream.rs index 54812ffc94..10d77937f1 100644 --- a/storage_scrubber/src/metadata_stream.rs +++ b/storage_scrubber/src/metadata_stream.rs @@ -2,7 +2,6 @@ use std::str::FromStr; use anyhow::{anyhow, Context}; use async_stream::{stream, try_stream}; -use aws_sdk_s3::{types::ObjectIdentifier, Client}; use futures::StreamExt; use remote_storage::{GenericRemoteStorage, ListingMode, ListingObject, RemotePath}; use tokio_stream::Stream; @@ -15,7 +14,7 @@ use pageserver_api::shard::TenantShardId; use utils::id::{TenantId, TimelineId}; /// Given a remote storage and a target, output a stream of TenantIds discovered via listing prefixes -pub fn stream_tenants_generic<'a>( +pub fn stream_tenants<'a>( remote_client: &'a GenericRemoteStorage, target: &'a RootTarget, ) -> impl Stream> + 'a { @@ -36,92 +35,36 @@ pub fn stream_tenants_generic<'a>( } } -/// Given an S3 bucket, output a stream of TenantIds discovered via ListObjectsv2 -pub fn stream_tenants<'a>( - s3_client: &'a Client, - target: &'a RootTarget, -) -> impl Stream> + 'a { - try_stream! { - let mut continuation_token = None; - let tenants_target = target.tenants_root(); - loop { - let fetch_response = - list_objects_with_retries(s3_client, &tenants_target, continuation_token.clone()).await?; - - let new_entry_ids = fetch_response - .common_prefixes() - .iter() - .filter_map(|prefix| prefix.prefix()) - .filter_map(|prefix| -> Option<&str> { - prefix - .strip_prefix(&tenants_target.prefix_in_bucket)? 
- .strip_suffix('/') - }).map(|entry_id_str| { - entry_id_str - .parse() - .with_context(|| format!("Incorrect entry id str: {entry_id_str}")) - }); - - for i in new_entry_ids { - yield i?; - } - - match fetch_response.next_continuation_token { - Some(new_token) => continuation_token = Some(new_token), - None => break, - } - } - } -} - pub async fn stream_tenant_shards<'a>( - s3_client: &'a Client, + remote_client: &'a GenericRemoteStorage, target: &'a RootTarget, tenant_id: TenantId, ) -> anyhow::Result> + 'a> { - let mut tenant_shard_ids: Vec> = Vec::new(); - let mut continuation_token = None; let shards_target = target.tenant_shards_prefix(&tenant_id); - loop { - tracing::info!("Listing in {}", shards_target.prefix_in_bucket); - let fetch_response = - list_objects_with_retries(s3_client, &shards_target, continuation_token.clone()).await; - let fetch_response = match fetch_response { - Err(e) => { - tenant_shard_ids.push(Err(e)); - break; - } - Ok(r) => r, - }; + let strip_prefix = target.tenants_root().prefix_in_bucket; + let prefix_str = &strip_prefix.strip_prefix("/").unwrap_or(&strip_prefix); - let new_entry_ids = fetch_response - .common_prefixes() - .iter() - .filter_map(|prefix| prefix.prefix()) - .filter_map(|prefix| -> Option<&str> { - prefix - .strip_prefix(&target.tenants_root().prefix_in_bucket)? - .strip_suffix('/') - }) - .map(|entry_id_str| { - let first_part = entry_id_str.split('/').next().unwrap(); + tracing::info!("Listing shards in {}", shards_target.prefix_in_bucket); + let listing = + list_objects_with_retries(remote_client, ListingMode::WithDelimiter, &shards_target) + .await?; - first_part - .parse::() - .with_context(|| format!("Incorrect entry id str: {first_part}")) - }); + let tenant_shard_ids = listing + .prefixes + .iter() + .map(|prefix| prefix.get_path().as_str()) + .filter_map(|prefix| -> Option<&str> { prefix.strip_prefix(prefix_str) }) + .map(|entry_id_str| { + let first_part = entry_id_str.split('/').next().unwrap(); - for i in new_entry_ids { - tenant_shard_ids.push(i); - } - - match fetch_response.next_continuation_token { - Some(new_token) => continuation_token = Some(new_token), - None => break, - } - } + first_part + .parse::() + .with_context(|| format!("Incorrect entry id str: {first_part}")) + }) + .collect::>(); + tracing::debug!("Yielding {} shards for {tenant_id}", tenant_shard_ids.len()); Ok(stream! { for i in tenant_shard_ids { let id = i?; @@ -130,69 +73,10 @@ pub async fn stream_tenant_shards<'a>( }) } -/// Given a TenantShardId, output a stream of the timelines within that tenant, discovered -/// using ListObjectsv2. The listing is done before the stream is built, so that this -/// function can be used to generate concurrency on a stream using buffer_unordered. 
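Whether it goes through the raw S3 SDK (removed here) or through GenericRemoteStorage, the ID discovery boils down to string surgery on listing prefixes: strip the listing root, drop any trailing `/` the backend leaves on a common prefix, then parse what is left. A minimal std-only sketch of that step, with made-up prefix values:

    fn extract_id<'a>(listing_prefix: &'a str, root: &str) -> Option<&'a str> {
        // "tenants/<tenant>/timelines/<timeline>/" -> "<timeline>"
        listing_prefix.strip_prefix(root)?.strip_suffix('/')
    }

    fn main() {
        let root = "tenants/1234/timelines/";
        let entry = "tenants/1234/timelines/deadbeef/";
        assert_eq!(extract_id(entry, root), Some("deadbeef"));
    }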
-pub async fn stream_tenant_timelines<'a>( - s3_client: &'a Client, - target: &'a RootTarget, - tenant: TenantShardId, -) -> anyhow::Result> + 'a> { - let mut timeline_ids: Vec> = Vec::new(); - let mut continuation_token = None; - let timelines_target = target.timelines_root(&tenant); - - loop { - tracing::debug!("Listing in {}", tenant); - let fetch_response = - list_objects_with_retries(s3_client, &timelines_target, continuation_token.clone()) - .await; - let fetch_response = match fetch_response { - Err(e) => { - timeline_ids.push(Err(e)); - break; - } - Ok(r) => r, - }; - - let new_entry_ids = fetch_response - .common_prefixes() - .iter() - .filter_map(|prefix| prefix.prefix()) - .filter_map(|prefix| -> Option<&str> { - prefix - .strip_prefix(&timelines_target.prefix_in_bucket)? - .strip_suffix('/') - }) - .map(|entry_id_str| { - entry_id_str - .parse::() - .with_context(|| format!("Incorrect entry id str: {entry_id_str}")) - }); - - for i in new_entry_ids { - timeline_ids.push(i); - } - - match fetch_response.next_continuation_token { - Some(new_token) => continuation_token = Some(new_token), - None => break, - } - } - - tracing::debug!("Yielding for {}", tenant); - Ok(stream! { - for i in timeline_ids { - let id = i?; - yield Ok(TenantShardTimelineId::new(tenant, id)); - } - }) -} - /// Given a `TenantShardId`, output a stream of the timelines within that tenant, discovered /// using a listing. The listing is done before the stream is built, so that this /// function can be used to generate concurrency on a stream using buffer_unordered. -pub async fn stream_tenant_timelines_generic<'a>( +pub async fn stream_tenant_timelines<'a>( remote_client: &'a GenericRemoteStorage, target: &'a RootTarget, tenant: TenantShardId, @@ -200,6 +84,11 @@ pub async fn stream_tenant_timelines_generic<'a>( let mut timeline_ids: Vec> = Vec::new(); let timelines_target = target.timelines_root(&tenant); + let prefix_str = &timelines_target + .prefix_in_bucket + .strip_prefix("/") + .unwrap_or(&timelines_target.prefix_in_bucket); + let mut objects_stream = std::pin::pin!(stream_objects_with_retries( remote_client, ListingMode::WithDelimiter, @@ -220,11 +109,7 @@ pub async fn stream_tenant_timelines_generic<'a>( .prefixes .iter() .filter_map(|prefix| -> Option<&str> { - prefix - .get_path() - .as_str() - .strip_prefix(&timelines_target.prefix_in_bucket)? - .strip_suffix('/') + prefix.get_path().as_str().strip_prefix(prefix_str) }) .map(|entry_id_str| { entry_id_str @@ -237,7 +122,7 @@ pub async fn stream_tenant_timelines_generic<'a>( } } - tracing::debug!("Yielding for {}", tenant); + tracing::debug!("Yielding {} timelines for {}", timeline_ids.len(), tenant); Ok(stream! { for i in timeline_ids { let id = i?; @@ -247,37 +132,6 @@ pub async fn stream_tenant_timelines_generic<'a>( } pub(crate) fn stream_listing<'a>( - s3_client: &'a Client, - target: &'a S3Target, -) -> impl Stream> + 'a { - try_stream! 
{ - let mut continuation_token = None; - loop { - let fetch_response = - list_objects_with_retries(s3_client, target, continuation_token.clone()).await?; - - if target.delimiter.is_empty() { - for object_key in fetch_response.contents().iter().filter_map(|object| object.key()) - { - let object_id = ObjectIdentifier::builder().key(object_key).build()?; - yield object_id; - } - } else { - for prefix in fetch_response.common_prefixes().iter().filter_map(|p| p.prefix()) { - let object_id = ObjectIdentifier::builder().key(prefix).build()?; - yield object_id; - } - } - - match fetch_response.next_continuation_token { - Some(new_token) => continuation_token = Some(new_token), - None => break, - } - } - } -} - -pub(crate) fn stream_listing_generic<'a>( remote_client: &'a GenericRemoteStorage, target: &'a S3Target, ) -> impl Stream)>> + 'a { diff --git a/storage_scrubber/src/pageserver_physical_gc.rs b/storage_scrubber/src/pageserver_physical_gc.rs index c8b1ed49f4..88681e38c2 100644 --- a/storage_scrubber/src/pageserver_physical_gc.rs +++ b/storage_scrubber/src/pageserver_physical_gc.rs @@ -1,13 +1,10 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use crate::checks::{list_timeline_blobs, BlobDataParseResult}; use crate::metadata_stream::{stream_tenant_timelines, stream_tenants}; -use crate::{ - init_remote, BucketConfig, ControllerClientConfig, NodeKind, RootTarget, TenantShardTimelineId, -}; -use aws_sdk_s3::Client; +use crate::{init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId}; use futures_util::{StreamExt, TryStreamExt}; use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata; use pageserver::tenant::remote_timeline_client::{parse_remote_index_path, remote_layer_path}; @@ -15,10 +12,11 @@ use pageserver::tenant::storage_layer::LayerName; use pageserver::tenant::IndexPart; use pageserver_api::controller_api::TenantDescribeResponse; use pageserver_api::shard::{ShardIndex, TenantShardId}; -use remote_storage::RemotePath; +use remote_storage::{GenericRemoteStorage, ListingObject, RemotePath}; use reqwest::Method; use serde::Serialize; use storage_controller_client::control_api; +use tokio_util::sync::CancellationToken; use tracing::{info_span, Instrument}; use utils::generation::Generation; use utils::id::{TenantId, TenantTimelineId}; @@ -242,38 +240,13 @@ impl TenantRefAccumulator { } } -async fn is_old_enough( - s3_client: &Client, - bucket_config: &BucketConfig, - min_age: &Duration, - key: &str, - summary: &mut GcSummary, -) -> bool { +fn is_old_enough(min_age: &Duration, key: &ListingObject, summary: &mut GcSummary) -> bool { // Validation: we will only GC indices & layers after a time threshold (e.g. one week) so that during an incident // it is easier to read old data for analysis, and easier to roll back shard splits without having to un-delete any objects. 
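The rewritten check below relies on the `last_modified` timestamp already carried by the `ListingObject` (or fetched once via `head_object`), rather than issuing an extra HEAD request per key. A self-contained sketch of the age test, with an illustrative one-week threshold:

    use std::time::{Duration, SystemTime};

    fn old_enough(last_modified: SystemTime, min_age: Duration) -> bool {
        match last_modified.elapsed() {
            Ok(age) => age >= min_age,
            // Clock skew can place last_modified in the future; treat that as
            // "too new" rather than deleting something we cannot date reliably.
            Err(_) => false,
        }
    }

    fn main() {
        let one_week = Duration::from_secs(7 * 24 * 3600);
        assert!(!old_enough(SystemTime::now(), one_week));
        assert!(old_enough(SystemTime::UNIX_EPOCH, one_week));
    }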
- let age: Duration = match s3_client - .head_object() - .bucket(&bucket_config.bucket) - .key(key) - .send() - .await - { - Ok(response) => match response.last_modified { - None => { - tracing::warn!("Missing last_modified"); - summary.remote_storage_errors += 1; - return false; - } - Some(last_modified) => match SystemTime::try_from(last_modified).map(|t| t.elapsed()) { - Ok(Ok(e)) => e, - Err(_) | Ok(Err(_)) => { - tracing::warn!("Bad last_modified time: {last_modified:?}"); - return false; - } - }, - }, - Err(e) => { - tracing::warn!("Failed to HEAD {key}: {e}"); + let age = match key.last_modified.elapsed() { + Ok(e) => e, + Err(_) => { + tracing::warn!("Bad last_modified time: {:?}", key.last_modified); summary.remote_storage_errors += 1; return false; } @@ -291,17 +264,30 @@ async fn is_old_enough( old_enough } +/// Same as [`is_old_enough`], but doesn't require a [`ListingObject`] passed to it. +async fn check_is_old_enough( + remote_client: &GenericRemoteStorage, + key: &RemotePath, + min_age: &Duration, + summary: &mut GcSummary, +) -> Option { + let listing_object = remote_client + .head_object(key, &CancellationToken::new()) + .await + .ok()?; + Some(is_old_enough(min_age, &listing_object, summary)) +} + async fn maybe_delete_index( - s3_client: &Client, - bucket_config: &BucketConfig, + remote_client: &GenericRemoteStorage, min_age: &Duration, latest_gen: Generation, - key: &str, + obj: &ListingObject, mode: GcMode, summary: &mut GcSummary, ) { // Validation: we will only delete things that parse cleanly - let basename = key.rsplit_once('/').unwrap().1; + let basename = obj.key.get_path().file_name().unwrap(); let candidate_generation = match parse_remote_index_path(RemotePath::from_string(basename).unwrap()) { Some(g) => g, @@ -330,7 +316,7 @@ async fn maybe_delete_index( return; } - if !is_old_enough(s3_client, bucket_config, min_age, key, summary).await { + if !is_old_enough(min_age, obj, summary) { return; } @@ -340,11 +326,8 @@ async fn maybe_delete_index( } // All validations passed: erase the object - match s3_client - .delete_object() - .bucket(&bucket_config.bucket) - .key(key) - .send() + match remote_client + .delete(&obj.key, &CancellationToken::new()) .await { Ok(_) => { @@ -360,8 +343,7 @@ async fn maybe_delete_index( #[allow(clippy::too_many_arguments)] async fn gc_ancestor( - s3_client: &Client, - bucket_config: &BucketConfig, + remote_client: &GenericRemoteStorage, root_target: &RootTarget, min_age: &Duration, ancestor: TenantShardId, @@ -370,7 +352,7 @@ async fn gc_ancestor( summary: &mut GcSummary, ) -> anyhow::Result<()> { // Scan timelines in the ancestor - let timelines = stream_tenant_timelines(s3_client, root_target, ancestor).await?; + let timelines = stream_tenant_timelines(remote_client, root_target, ancestor).await?; let mut timelines = std::pin::pin!(timelines); // Build a list of keys to retain @@ -378,7 +360,7 @@ async fn gc_ancestor( while let Some(ttid) = timelines.next().await { let ttid = ttid?; - let data = list_timeline_blobs(s3_client, ttid, root_target).await?; + let data = list_timeline_blobs(remote_client, ttid, root_target).await?; let s3_layers = match data.blob_data { BlobDataParseResult::Parsed { @@ -429,7 +411,8 @@ async fn gc_ancestor( // We apply a time threshold to GCing objects that are un-referenced: this preserves our ability // to roll back a shard split if we have to, by avoiding deleting ancestor layers right away - if !is_old_enough(s3_client, bucket_config, min_age, &key, summary).await { + let path = 
RemotePath::from_string(key.strip_prefix("/").unwrap_or(&key)).unwrap(); + if check_is_old_enough(remote_client, &path, min_age, summary).await != Some(true) { continue; } @@ -439,13 +422,7 @@ async fn gc_ancestor( } // All validations passed: erase the object - match s3_client - .delete_object() - .bucket(&bucket_config.bucket) - .key(&key) - .send() - .await - { + match remote_client.delete(&path, &CancellationToken::new()).await { Ok(_) => { tracing::info!("Successfully deleted unreferenced ancestor layer {key}"); summary.ancestor_layers_deleted += 1; @@ -473,16 +450,16 @@ async fn gc_ancestor( /// This type of GC is not necessary for correctness: rather it serves to reduce wasted storage capacity, and /// make sure that object listings don't get slowed down by large numbers of garbage objects. pub async fn pageserver_physical_gc( - bucket_config: BucketConfig, - controller_client_conf: Option, + bucket_config: &BucketConfig, + controller_client: Option<&control_api::Client>, tenant_shard_ids: Vec, min_age: Duration, mode: GcMode, ) -> anyhow::Result { - let (s3_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?; + let (remote_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?; let tenants = if tenant_shard_ids.is_empty() { - futures::future::Either::Left(stream_tenants(&s3_client, &target)) + futures::future::Either::Left(stream_tenants(&remote_client, &target)) } else { futures::future::Either::Right(futures::stream::iter(tenant_shard_ids.into_iter().map(Ok))) }; @@ -495,14 +472,13 @@ pub async fn pageserver_physical_gc( let accumulator = Arc::new(std::sync::Mutex::new(TenantRefAccumulator::default())); // Generate a stream of TenantTimelineId - let timelines = tenants.map_ok(|t| stream_tenant_timelines(&s3_client, &target, t)); + let timelines = tenants.map_ok(|t| stream_tenant_timelines(&remote_client, &target, t)); let timelines = timelines.try_buffered(CONCURRENCY); let timelines = timelines.try_flatten(); // Generate a stream of S3TimelineBlobData async fn gc_timeline( - s3_client: &Client, - bucket_config: &BucketConfig, + remote_client: &GenericRemoteStorage, min_age: &Duration, target: &RootTarget, mode: GcMode, @@ -510,7 +486,7 @@ pub async fn pageserver_physical_gc( accumulator: &Arc>, ) -> anyhow::Result { let mut summary = GcSummary::default(); - let data = list_timeline_blobs(s3_client, ttid, target).await?; + let data = list_timeline_blobs(remote_client, ttid, target).await?; let (index_part, latest_gen, candidates) = match &data.blob_data { BlobDataParseResult::Parsed { @@ -535,17 +511,9 @@ pub async fn pageserver_physical_gc( accumulator.lock().unwrap().update(ttid, index_part); for key in candidates { - maybe_delete_index( - s3_client, - bucket_config, - min_age, - latest_gen, - &key, - mode, - &mut summary, - ) - .instrument(info_span!("maybe_delete_index", %ttid, ?latest_gen, key)) - .await; + maybe_delete_index(remote_client, min_age, latest_gen, &key, mode, &mut summary) + .instrument(info_span!("maybe_delete_index", %ttid, ?latest_gen, %key.key)) + .await; } Ok(summary) @@ -556,15 +524,7 @@ pub async fn pageserver_physical_gc( // Drain futures for per-shard GC, populating accumulator as a side effect { let timelines = timelines.map_ok(|ttid| { - gc_timeline( - &s3_client, - &bucket_config, - &min_age, - &target, - mode, - ttid, - &accumulator, - ) + gc_timeline(&remote_client, &min_age, &target, mode, ttid, &accumulator) }); let mut timelines = std::pin::pin!(timelines.try_buffered(CONCURRENCY)); @@ 
-574,7 +534,7 @@ pub async fn pageserver_physical_gc( } // Execute cross-shard GC, using the accumulator's full view of all the shards built in the per-shard GC - let Some(controller_client) = controller_client_conf.map(|c| c.build_client()) else { + let Some(client) = controller_client else { tracing::info!("Skipping ancestor layer GC, because no `--controller-api` was specified"); return Ok(summary); }; @@ -583,13 +543,12 @@ pub async fn pageserver_physical_gc( .unwrap() .into_inner() .unwrap() - .into_gc_ancestors(&controller_client, &mut summary) + .into_gc_ancestors(client, &mut summary) .await; for ancestor_shard in ancestor_shards { gc_ancestor( - &s3_client, - &bucket_config, + &remote_client, &target, &min_age, ancestor_shard, diff --git a/storage_scrubber/src/scan_pageserver_metadata.rs b/storage_scrubber/src/scan_pageserver_metadata.rs index b9630056e1..151ef27672 100644 --- a/storage_scrubber/src/scan_pageserver_metadata.rs +++ b/storage_scrubber/src/scan_pageserver_metadata.rs @@ -1,16 +1,16 @@ use std::collections::{HashMap, HashSet}; use crate::checks::{ - branch_cleanup_and_check_errors, list_timeline_blobs, BlobDataParseResult, S3TimelineBlobData, - TenantObjectListing, TimelineAnalysis, + branch_cleanup_and_check_errors, list_timeline_blobs, BlobDataParseResult, + RemoteTimelineBlobData, TenantObjectListing, TimelineAnalysis, }; use crate::metadata_stream::{stream_tenant_timelines, stream_tenants}; use crate::{init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId}; -use aws_sdk_s3::Client; use futures_util::{StreamExt, TryStreamExt}; use pageserver::tenant::remote_timeline_client::remote_layer_path; use pageserver_api::controller_api::MetadataHealthUpdateRequest; use pageserver_api::shard::TenantShardId; +use remote_storage::GenericRemoteStorage; use serde::Serialize; use utils::id::TenantId; use utils::shard::ShardCount; @@ -36,7 +36,7 @@ impl MetadataSummary { Self::default() } - fn update_data(&mut self, data: &S3TimelineBlobData) { + fn update_data(&mut self, data: &RemoteTimelineBlobData) { self.timeline_shard_count += 1; if let BlobDataParseResult::Parsed { index_part, @@ -116,14 +116,14 @@ Index versions: {version_summary} } /// Scan the pageserver metadata in an S3 bucket, reporting errors and statistics. 
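Both the per-shard GC above and the renamed `scan_pageserver_metadata` below fan out per-timeline work with the same `map_ok` + `try_buffered` combination: the stream stays a stream of `Result`s, at most `CONCURRENCY` listings are in flight, and the first error short-circuits. A minimal sketch of that pattern, assuming the `futures`, `tokio` and `anyhow` crates:

    use futures::stream::{self, TryStreamExt};

    async fn inspect(id: u32) -> anyhow::Result<u32> {
        Ok(id * 2) // stand-in for a per-timeline listing
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        const CONCURRENCY: usize = 32;
        // map_ok wraps each Ok item in a future; try_buffered drives up to
        // CONCURRENCY of those futures at once, yielding results in input order.
        let results: Vec<u32> = stream::iter((0..8u32).map(anyhow::Ok))
            .map_ok(inspect)
            .try_buffered(CONCURRENCY)
            .try_collect()
            .await?;
        assert_eq!(results, vec![0, 2, 4, 6, 8, 10, 12, 14]);
        Ok(())
    }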
-pub async fn scan_metadata( +pub async fn scan_pageserver_metadata( bucket_config: BucketConfig, tenant_ids: Vec, ) -> anyhow::Result { - let (s3_client, target) = init_remote(bucket_config, NodeKind::Pageserver).await?; + let (remote_client, target) = init_remote(bucket_config, NodeKind::Pageserver).await?; let tenants = if tenant_ids.is_empty() { - futures::future::Either::Left(stream_tenants(&s3_client, &target)) + futures::future::Either::Left(stream_tenants(&remote_client, &target)) } else { futures::future::Either::Right(futures::stream::iter(tenant_ids.into_iter().map(Ok))) }; @@ -133,20 +133,20 @@ pub async fn scan_metadata( const CONCURRENCY: usize = 32; // Generate a stream of TenantTimelineId - let timelines = tenants.map_ok(|t| stream_tenant_timelines(&s3_client, &target, t)); + let timelines = tenants.map_ok(|t| stream_tenant_timelines(&remote_client, &target, t)); let timelines = timelines.try_buffered(CONCURRENCY); let timelines = timelines.try_flatten(); // Generate a stream of S3TimelineBlobData async fn report_on_timeline( - s3_client: &Client, + remote_client: &GenericRemoteStorage, target: &RootTarget, ttid: TenantShardTimelineId, - ) -> anyhow::Result<(TenantShardTimelineId, S3TimelineBlobData)> { - let data = list_timeline_blobs(s3_client, ttid, target).await?; + ) -> anyhow::Result<(TenantShardTimelineId, RemoteTimelineBlobData)> { + let data = list_timeline_blobs(remote_client, ttid, target).await?; Ok((ttid, data)) } - let timelines = timelines.map_ok(|ttid| report_on_timeline(&s3_client, &target, ttid)); + let timelines = timelines.map_ok(|ttid| report_on_timeline(&remote_client, &target, ttid)); let mut timelines = std::pin::pin!(timelines.try_buffered(CONCURRENCY)); // We must gather all the TenantShardTimelineId->S3TimelineBlobData for each tenant, because different @@ -157,12 +157,11 @@ pub async fn scan_metadata( let mut tenant_timeline_results = Vec::new(); async fn analyze_tenant( - s3_client: &Client, - target: &RootTarget, + remote_client: &GenericRemoteStorage, tenant_id: TenantId, summary: &mut MetadataSummary, mut tenant_objects: TenantObjectListing, - timelines: Vec<(TenantShardTimelineId, S3TimelineBlobData)>, + timelines: Vec<(TenantShardTimelineId, RemoteTimelineBlobData)>, highest_shard_count: ShardCount, ) { summary.tenant_count += 1; @@ -191,8 +190,7 @@ pub async fn scan_metadata( // Apply checks to this timeline shard's metadata, and in the process update `tenant_objects` // reference counts for layers across the tenant. 
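`tenant_objects` here is essentially a tally of how many shard indices reference each listed layer key; whatever is still unreferenced once every timeline of the tenant has been checked feeds the orphan report. A toy, std-only illustration of that bookkeeping (not the actual `TenantObjectListing` implementation, and the keys are made up):

    use std::collections::BTreeMap;

    fn unreferenced(listed: &[&str], referenced: &[&str]) -> Vec<String> {
        // Every key seen in the bucket starts at zero references,
        // and each index that mentions it bumps the count.
        let mut refs: BTreeMap<&str, usize> = listed.iter().map(|k| (*k, 0)).collect();
        for key in referenced {
            if let Some(n) = refs.get_mut(key) {
                *n += 1;
            }
        }
        refs.into_iter()
            .filter(|(_, n)| *n == 0)
            .map(|(k, _)| k.to_string())
            .collect()
    }

    fn main() {
        assert_eq!(
            unreferenced(&["l1", "l2", "l3"], &["l1", "l3"]),
            vec!["l2".to_string()]
        );
    }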
let analysis = branch_cleanup_and_check_errors( - s3_client, - target, + remote_client, &ttid, &mut tenant_objects, None, @@ -273,8 +271,7 @@ pub async fn scan_metadata( let tenant_objects = std::mem::take(&mut tenant_objects); let timelines = std::mem::take(&mut tenant_timeline_results); analyze_tenant( - &s3_client, - &target, + &remote_client, prev_tenant_id, &mut summary, tenant_objects, @@ -311,8 +308,7 @@ pub async fn scan_metadata( if !tenant_timeline_results.is_empty() { analyze_tenant( - &s3_client, - &target, + &remote_client, tenant_id.expect("Must be set if results are present"), &mut summary, tenant_objects, diff --git a/storage_scrubber/src/scan_safekeeper_metadata.rs b/storage_scrubber/src/scan_safekeeper_metadata.rs index 08a4541c5c..1a9f3d0ef5 100644 --- a/storage_scrubber/src/scan_safekeeper_metadata.rs +++ b/storage_scrubber/src/scan_safekeeper_metadata.rs @@ -14,9 +14,8 @@ use utils::{ }; use crate::{ - cloud_admin_api::CloudAdminApiClient, init_remote_generic, - metadata_stream::stream_listing_generic, BucketConfig, ConsoleConfig, NodeKind, RootTarget, - TenantShardTimelineId, + cloud_admin_api::CloudAdminApiClient, init_remote, metadata_stream::stream_listing, + BucketConfig, ConsoleConfig, NodeKind, RootTarget, TenantShardTimelineId, }; /// Generally we should ask safekeepers, but so far we use everywhere default 16MB. @@ -107,7 +106,7 @@ pub async fn scan_safekeeper_metadata( let timelines = client.query(&query, &[]).await?; info!("loaded {} timelines", timelines.len()); - let (remote_client, target) = init_remote_generic(bucket_config, NodeKind::Safekeeper).await?; + let (remote_client, target) = init_remote(bucket_config, NodeKind::Safekeeper).await?; let console_config = ConsoleConfig::from_env()?; let cloud_admin_api_client = CloudAdminApiClient::new(console_config); @@ -188,14 +187,19 @@ async fn check_timeline( // we need files, so unset it. 
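Clearing the delimiter is what switches the listing from "directory" mode to "flat" mode: with a delimiter set, the backend returns common prefixes (one per timeline directory and so on), while with an empty delimiter it returns individual object keys, which is what the segment-file check needs. A small std-only illustration of the difference, with made-up keys:

    fn list(keys: &[&str], prefix: &str, delimiter: &str) -> Vec<String> {
        let mut out: Vec<String> = keys
            .iter()
            .filter_map(|k| k.strip_prefix(prefix))
            .map(|rest| match rest.split_once(delimiter) {
                // Delimited listing: collapse everything below the first
                // delimiter into a single "common prefix" entry.
                Some((dir, _)) if !delimiter.is_empty() => format!("{dir}{delimiter}"),
                // Empty delimiter (or no match): return the object key itself.
                _ => rest.to_string(),
            })
            .collect();
        out.sort();
        out.dedup();
        out
    }

    fn main() {
        let keys = ["wal/tl1/000001", "wal/tl1/000002", "wal/tl2/000001"];
        assert_eq!(list(&keys, "wal/", "/"), ["tl1/", "tl2/"]);
        assert_eq!(list(&keys, "wal/tl1/", ""), ["000001", "000002"]);
    }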
timeline_dir_target.delimiter = String::new(); - let mut stream = std::pin::pin!(stream_listing_generic(remote_client, &timeline_dir_target)); + let prefix_str = &timeline_dir_target + .prefix_in_bucket + .strip_prefix("/") + .unwrap_or(&timeline_dir_target.prefix_in_bucket); + + let mut stream = std::pin::pin!(stream_listing(remote_client, &timeline_dir_target)); while let Some(obj) = stream.next().await { let (key, _obj) = obj?; let seg_name = key .get_path() .as_str() - .strip_prefix(&timeline_dir_target.prefix_in_bucket) + .strip_prefix(prefix_str) .expect("failed to extract segment name"); expected_segfiles.remove(seg_name); } diff --git a/storage_scrubber/src/tenant_snapshot.rs b/storage_scrubber/src/tenant_snapshot.rs index 1866e6ec80..bb4079b5f4 100644 --- a/storage_scrubber/src/tenant_snapshot.rs +++ b/storage_scrubber/src/tenant_snapshot.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::checks::{list_timeline_blobs, BlobDataParseResult, S3TimelineBlobData}; +use crate::checks::{list_timeline_blobs, BlobDataParseResult, RemoteTimelineBlobData}; use crate::metadata_stream::{stream_tenant_shards, stream_tenant_timelines}; use crate::{ - download_object_to_file, init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId, + download_object_to_file_s3, init_remote, init_remote_s3, BucketConfig, NodeKind, RootTarget, + TenantShardTimelineId, }; use anyhow::Context; use async_stream::stream; @@ -15,6 +16,7 @@ use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata; use pageserver::tenant::storage_layer::LayerName; use pageserver::tenant::IndexPart; use pageserver_api::shard::TenantShardId; +use remote_storage::GenericRemoteStorage; use utils::generation::Generation; use utils::id::TenantId; @@ -34,7 +36,8 @@ impl SnapshotDownloader { output_path: Utf8PathBuf, concurrency: usize, ) -> anyhow::Result { - let (s3_client, s3_root) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?; + let (s3_client, s3_root) = + init_remote_s3(bucket_config.clone(), NodeKind::Pageserver).await?; Ok(Self { s3_client, s3_root, @@ -91,7 +94,7 @@ impl SnapshotDownloader { let Some(version) = versions.versions.as_ref().and_then(|v| v.first()) else { return Err(anyhow::anyhow!("No versions found for {remote_layer_path}")); }; - download_object_to_file( + download_object_to_file_s3( &self.s3_client, &self.bucket_config.bucket, &remote_layer_path, @@ -215,11 +218,11 @@ impl SnapshotDownloader { } pub async fn download(&self) -> anyhow::Result<()> { - let (s3_client, target) = + let (remote_client, target) = init_remote(self.bucket_config.clone(), NodeKind::Pageserver).await?; // Generate a stream of TenantShardId - let shards = stream_tenant_shards(&s3_client, &target, self.tenant_id).await?; + let shards = stream_tenant_shards(&remote_client, &target, self.tenant_id).await?; let shards: Vec = shards.try_collect().await?; // Only read from shards that have the highest count: avoids redundantly downloading @@ -237,18 +240,19 @@ impl SnapshotDownloader { for shard in shards.into_iter().filter(|s| s.shard_count == shard_count) { // Generate a stream of TenantTimelineId - let timelines = stream_tenant_timelines(&s3_client, &self.s3_root, shard).await?; + let timelines = stream_tenant_timelines(&remote_client, &target, shard).await?; // Generate a stream of S3TimelineBlobData async fn load_timeline_index( - s3_client: &Client, + remote_client: &GenericRemoteStorage, target: &RootTarget, ttid: TenantShardTimelineId, - ) -> 
anyhow::Result<(TenantShardTimelineId, S3TimelineBlobData)> { - let data = list_timeline_blobs(s3_client, ttid, target).await?; + ) -> anyhow::Result<(TenantShardTimelineId, RemoteTimelineBlobData)> { + let data = list_timeline_blobs(remote_client, ttid, target).await?; Ok((ttid, data)) } - let timelines = timelines.map_ok(|ttid| load_timeline_index(&s3_client, &target, ttid)); + let timelines = + timelines.map_ok(|ttid| load_timeline_index(&remote_client, &target, ttid)); let mut timelines = std::pin::pin!(timelines.try_buffered(8)); while let Some(i) = timelines.next().await { @@ -278,7 +282,7 @@ impl SnapshotDownloader { for (ttid, layers) in ancestor_layers.into_iter() { tracing::info!( - "Downloading {} layers from ancvestor timeline {ttid}...", + "Downloading {} layers from ancestor timeline {ttid}...", layers.len() ); diff --git a/test_runner/README.md b/test_runner/README.md index e2f26a19ce..73aa29d4bb 100644 --- a/test_runner/README.md +++ b/test_runner/README.md @@ -71,8 +71,7 @@ a subdirectory for each version with naming convention `v{PG_VERSION}/`. Inside that dir, a `bin/postgres` binary should be present. `DEFAULT_PG_VERSION`: The version of Postgres to use, This is used to construct full path to the postgres binaries. -Format is 2-digit major version nubmer, i.e. `DEFAULT_PG_VERSION="14"`. Alternatively, -you can use `--pg-version` argument. +Format is 2-digit major version nubmer, i.e. `DEFAULT_PG_VERSION=16` `TEST_OUTPUT`: Set the directory where test state and test output files should go. `TEST_SHARED_FIXTURES`: Try to re-use a single pageserver for all the tests. diff --git a/test_runner/fixtures/common_types.py b/test_runner/fixtures/common_types.py index b63dfd4e47..8eda19d1e2 100644 --- a/test_runner/fixtures/common_types.py +++ b/test_runner/fixtures/common_types.py @@ -1,7 +1,8 @@ import random from dataclasses import dataclass +from enum import Enum from functools import total_ordering -from typing import Any, Type, TypeVar, Union +from typing import Any, Dict, Type, TypeVar, Union T = TypeVar("T", bound="Id") @@ -147,6 +148,19 @@ class TimelineId(Id): return self.id.hex() +@dataclass +class TenantTimelineId: + tenant_id: TenantId + timeline_id: TimelineId + + @classmethod + def from_json(cls, d: Dict[str, Any]) -> "TenantTimelineId": + return TenantTimelineId( + tenant_id=TenantId(d["tenant_id"]), + timeline_id=TimelineId(d["timeline_id"]), + ) + + # Workaround for compat with python 3.9, which does not have `typing.Self` TTenantShardId = TypeVar("TTenantShardId", bound="TenantShardId") @@ -200,3 +214,9 @@ class TenantShardId: def __hash__(self) -> int: return hash(self._tuple()) + + +# TODO: Replace with `StrEnum` when we upgrade to python 3.11 +class TimelineArchivalState(str, Enum): + ARCHIVED = "Archived" + UNARCHIVED = "Unarchived" diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index ec5a83601e..92febfec9b 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -61,8 +61,6 @@ from fixtures.pageserver.common_types import IndexPartDump, LayerName, parse_lay from fixtures.pageserver.http import PageserverHttpClient from fixtures.pageserver.utils import ( wait_for_last_record_lsn, - wait_for_upload, - wait_for_upload_queue_empty, ) from fixtures.pg_version import PgVersion from fixtures.port_distributor import PortDistributor @@ -1162,7 +1160,6 @@ class NeonEnv: "listen_http_addr": f"localhost:{pageserver_port.http}", "pg_auth_type": pg_auth_type, "http_auth_type": 
http_auth_type, - "image_compression": "zstd", } if self.pageserver_virtual_file_io_engine is not None: ps_cfg["virtual_file_io_engine"] = self.pageserver_virtual_file_io_engine @@ -2287,7 +2284,7 @@ class NeonStorageController(MetricsGetter, LogUtils): self.allowed_errors, ) - def pageserver_api(self) -> PageserverHttpClient: + def pageserver_api(self, *args, **kwargs) -> PageserverHttpClient: """ The storage controller implements a subset of the pageserver REST API, for mapping per-tenant actions into per-shard actions (e.g. timeline creation). Tests should invoke those @@ -2296,7 +2293,7 @@ class NeonStorageController(MetricsGetter, LogUtils): auth_token = None if self.auth_enabled: auth_token = self.env.auth_keys.generate_token(scope=TokenScope.PAGE_SERVER_API) - return PageserverHttpClient(self.port, lambda: True, auth_token) + return PageserverHttpClient(self.port, lambda: True, auth_token, *args, **kwargs) def request(self, method, *args, **kwargs) -> requests.Response: resp = requests.request(method, *args, **kwargs) @@ -4644,6 +4641,7 @@ class StorageScrubber: ] args = base_args + args + log.info(f"Invoking scrubber command {args} with env: {env}") (output_path, stdout, status_code) = subprocess_capture( self.log_dir, args, @@ -5347,9 +5345,7 @@ def last_flush_lsn_upload( for tenant_shard_id, pageserver in shards: ps_http = pageserver.http_client(auth_token=auth_token) wait_for_last_record_lsn(ps_http, tenant_shard_id, timeline_id, last_flush_lsn) - # force a checkpoint to trigger upload - ps_http.timeline_checkpoint(tenant_shard_id, timeline_id) - wait_for_upload(ps_http, tenant_shard_id, timeline_id, last_flush_lsn) + ps_http.timeline_checkpoint(tenant_shard_id, timeline_id, wait_until_uploaded=True) return last_flush_lsn @@ -5434,9 +5430,5 @@ def generate_uploads_and_deletions( # ensures that the pageserver is in a fully idle state: there will be no more # background ingest, no more uploads pending, and therefore no non-determinism # in subsequent actions like pageserver restarts. - final_lsn = flush_ep_to_pageserver(env, endpoint, tenant_id, timeline_id, pageserver.id) - ps_http.timeline_checkpoint(tenant_id, timeline_id) - # Finish uploads - wait_for_upload(ps_http, tenant_id, timeline_id, final_lsn) - # Finish all remote writes (including deletions) - wait_for_upload_queue_empty(ps_http, tenant_id, timeline_id) + flush_ep_to_pageserver(env, endpoint, tenant_id, timeline_id, pageserver.id) + ps_http.timeline_checkpoint(tenant_id, timeline_id, wait_until_uploaded=True) diff --git a/test_runner/fixtures/pageserver/allowed_errors.py b/test_runner/fixtures/pageserver/allowed_errors.py index dff002bd4b..f8d9a51c91 100755 --- a/test_runner/fixtures/pageserver/allowed_errors.py +++ b/test_runner/fixtures/pageserver/allowed_errors.py @@ -52,9 +52,6 @@ DEFAULT_PAGESERVER_ALLOWED_ERRORS = ( ".*Error processing HTTP request: Forbidden", # intentional failpoints ".*failpoint ", - # FIXME: These need investigation - ".*manual_gc.*is_shutdown_requested\\(\\) called in an unexpected task or thread.*", - ".*tenant_list: timeline is not found in remote index while it is present in the tenants registry.*", # Tenant::delete_timeline() can cause any of the four following errors. 
# FIXME: we shouldn't be considering it an error: https://github.com/neondatabase/neon/issues/2946 ".*could not flush frozen layer.*queue is in state Stopped", # when schedule layer upload fails because queued got closed before compaction got killed diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index cd4261f1b8..582f9c0264 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -10,7 +10,7 @@ import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId +from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineArchivalState, TimelineId from fixtures.log_helper import log from fixtures.metrics import Metrics, MetricsGetter, parse_metrics from fixtures.pg_version import PgVersion @@ -621,6 +621,22 @@ class PageserverHttpClient(requests.Session, MetricsGetter): ) self.verbose_error(res) + def timeline_archival_config( + self, + tenant_id: Union[TenantId, TenantShardId], + timeline_id: TimelineId, + state: TimelineArchivalState, + ): + config = {"state": state.value} + log.info( + f"requesting timeline archival config {config} for tenant {tenant_id} and timeline {timeline_id}" + ) + res = self.post( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/archival_config", + json=config, + ) + self.verbose_error(res) + def timeline_get_lsn_by_timestamp( self, tenant_id: Union[TenantId, TenantShardId], diff --git a/test_runner/fixtures/pageserver/utils.py b/test_runner/fixtures/pageserver/utils.py index b75a480a63..a74fef6a60 100644 --- a/test_runner/fixtures/pageserver/utils.py +++ b/test_runner/fixtures/pageserver/utils.py @@ -430,12 +430,17 @@ def enable_remote_storage_versioning( return response -MANY_SMALL_LAYERS_TENANT_CONFIG = { - "gc_period": "0s", - "compaction_period": "0s", - "checkpoint_distance": 1024**2, - "image_creation_threshold": 100, -} +def many_small_layers_tenant_config() -> Dict[str, Any]: + """ + Create a new dict to avoid issues with deleting from the global value. + In python, the global is mutable. + """ + return { + "gc_period": "0s", + "compaction_period": "0s", + "checkpoint_distance": 1024**2, + "image_creation_threshold": 100, + } def poll_for_remote_storage_iterations(remote_storage_kind: RemoteStorageKind) -> int: diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index 0227285822..92c98763e3 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -1,6 +1,7 @@ import os from typing import Any, Dict, Optional +import allure import pytest import toml from _pytest.python import Metafunc @@ -91,3 +92,23 @@ def pytest_generate_tests(metafunc: Metafunc): and (platform := os.getenv("PLATFORM")) is not None ): metafunc.parametrize("platform", [platform.lower()]) + + +@pytest.hookimpl(hookwrapper=True, tryfirst=True) +def pytest_runtest_makereport(*args, **kwargs): + # Add test parameters to Allue report to distinguish the same tests with different parameters. + # Names has `__` prefix to avoid conflicts with `pytest.mark.parametrize` parameters + + # A mapping between `uname -m` and `RUNNER_ARCH` values. + # `RUNNER_ARCH` environment variable is set on GitHub Runners, + # possible values are X86, X64, ARM, or ARM64. 
+ # See https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables + uname_m = { + "aarch64": "ARM64", + "arm64": "ARM64", + "x86_64": "X64", + }.get(os.uname().machine, "UNKNOWN") + arch = os.getenv("RUNNER_ARCH", uname_m) + allure.dynamic.parameter("__arch", arch) + + yield diff --git a/test_runner/fixtures/pg_version.py b/test_runner/fixtures/pg_version.py index 941889a2f5..e12c8e5f4a 100644 --- a/test_runner/fixtures/pg_version.py +++ b/test_runner/fixtures/pg_version.py @@ -3,8 +3,6 @@ import os from typing import Optional import pytest -from _pytest.config import Config -from _pytest.config.argparsing import Parser """ This fixture is used to determine which version of Postgres to use for tests. @@ -52,7 +50,7 @@ class PgVersion(str, enum.Enum): return None -DEFAULT_VERSION: PgVersion = PgVersion.V15 +DEFAULT_VERSION: PgVersion = PgVersion.V16 def skip_on_postgres(version: PgVersion, reason: str): @@ -69,22 +67,8 @@ def xfail_on_postgres(version: PgVersion, reason: str): ) -def pytest_addoption(parser: Parser): - parser.addoption( - "--pg-version", - action="store", - type=PgVersion, - help="DEPRECATED: Postgres version to use for tests", - ) - - def run_only_on_default_postgres(reason: str): return pytest.mark.skipif( PgVersion(os.environ.get("DEFAULT_PG_VERSION", DEFAULT_VERSION)) is not DEFAULT_VERSION, reason=reason, ) - - -def pytest_configure(config: Config): - if config.getoption("--pg-version"): - raise Exception("--pg-version is deprecated, use DEFAULT_PG_VERSION env var instead") diff --git a/test_runner/fixtures/safekeeper/http.py b/test_runner/fixtures/safekeeper/http.py index a51b89744b..05b43cfb72 100644 --- a/test_runner/fixtures/safekeeper/http.py +++ b/test_runner/fixtures/safekeeper/http.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import pytest import requests -from fixtures.common_types import Lsn, TenantId, TimelineId +from fixtures.common_types import Lsn, TenantId, TenantTimelineId, TimelineId from fixtures.log_helper import log from fixtures.metrics import Metrics, MetricsGetter, parse_metrics @@ -65,6 +65,16 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): def check_status(self): self.get(f"http://localhost:{self.port}/v1/status").raise_for_status() + def get_metrics_str(self) -> str: + """You probably want to use get_metrics() instead.""" + request_result = self.get(f"http://localhost:{self.port}/metrics") + request_result.raise_for_status() + return request_result.text + + def get_metrics(self) -> SafekeeperMetrics: + res = self.get_metrics_str() + return SafekeeperMetrics(parse_metrics(res)) + def is_testing_enabled_or_skip(self): if not self.is_testing_enabled: pytest.skip("safekeeper was built without 'testing' feature") @@ -89,60 +99,18 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): assert res_json is None return res_json - def debug_dump(self, params: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - params = params or {} - res = self.get(f"http://localhost:{self.port}/v1/debug_dump", params=params) - res.raise_for_status() - res_json = json.loads(res.text) - assert isinstance(res_json, dict) - return res_json - - def patch_control_file( - self, - tenant_id: TenantId, - timeline_id: TimelineId, - patch: Dict[str, Any], - ) -> Dict[str, Any]: - res = self.patch( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/control_file", - json={ - "updates": patch, - "apply_fields": list(patch.keys()), - }, - ) + def 
tenant_delete_force(self, tenant_id: TenantId) -> Dict[Any, Any]: + res = self.delete(f"http://localhost:{self.port}/v1/tenant/{tenant_id}") res.raise_for_status() res_json = res.json() assert isinstance(res_json, dict) return res_json - def pull_timeline(self, body: Dict[str, Any]) -> Dict[str, Any]: - res = self.post(f"http://localhost:{self.port}/v1/pull_timeline", json=body) + def timeline_list(self) -> List[TenantTimelineId]: + res = self.get(f"http://localhost:{self.port}/v1/tenant/timeline") res.raise_for_status() - res_json = res.json() - assert isinstance(res_json, dict) - return res_json - - def copy_timeline(self, tenant_id: TenantId, timeline_id: TimelineId, body: Dict[str, Any]): - res = self.post( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/copy", - json=body, - ) - res.raise_for_status() - - def timeline_digest( - self, tenant_id: TenantId, timeline_id: TimelineId, from_lsn: Lsn, until_lsn: Lsn - ) -> Dict[str, Any]: - res = self.get( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/digest", - params={ - "from_lsn": str(from_lsn), - "until_lsn": str(until_lsn), - }, - ) - res.raise_for_status() - res_json = res.json() - assert isinstance(res_json, dict) - return res_json + resj = res.json() + return [TenantTimelineId.from_json(ttidj) for ttidj in resj] def timeline_create( self, @@ -183,20 +151,6 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): def get_commit_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn: return self.timeline_status(tenant_id, timeline_id).commit_lsn - def record_safekeeper_info(self, tenant_id: TenantId, timeline_id: TimelineId, body): - res = self.post( - f"http://localhost:{self.port}/v1/record_safekeeper_info/{tenant_id}/{timeline_id}", - json=body, - ) - res.raise_for_status() - - def checkpoint(self, tenant_id: TenantId, timeline_id: TimelineId): - res = self.post( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/checkpoint", - json={}, - ) - res.raise_for_status() - # only_local doesn't remove segments in the remote storage. 
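`timeline_list()` above pairs with the `TenantTimelineId` dataclass added to `common_types.py` in this change. A standalone, simplified sketch of the same parsing (ids are plain strings here instead of `TenantId`/`TimelineId`, and the payload is made up):

    from dataclasses import dataclass
    from typing import Any, Dict, List


    @dataclass
    class TenantTimelineId:
        tenant_id: str
        timeline_id: str

        @classmethod
        def from_json(cls, d: Dict[str, Any]) -> "TenantTimelineId":
            return cls(tenant_id=d["tenant_id"], timeline_id=d["timeline_id"])


    def parse_timeline_list(payload: List[Dict[str, Any]]) -> List[TenantTimelineId]:
        # Shape of the body returned by GET /v1/tenant/timeline on a safekeeper.
        return [TenantTimelineId.from_json(item) for item in payload]


    print(parse_timeline_list([{"tenant_id": "a" * 32, "timeline_id": "b" * 32}]))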
def timeline_delete( self, tenant_id: TenantId, timeline_id: TimelineId, only_local: bool = False @@ -212,19 +166,71 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): assert isinstance(res_json, dict) return res_json - def tenant_delete_force(self, tenant_id: TenantId) -> Dict[Any, Any]: - res = self.delete(f"http://localhost:{self.port}/v1/tenant/{tenant_id}") + def debug_dump(self, params: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + params = params or {} + res = self.get(f"http://localhost:{self.port}/v1/debug_dump", params=params) + res.raise_for_status() + res_json = json.loads(res.text) + assert isinstance(res_json, dict) + return res_json + + def pull_timeline(self, body: Dict[str, Any]) -> Dict[str, Any]: + res = self.post(f"http://localhost:{self.port}/v1/pull_timeline", json=body) res.raise_for_status() res_json = res.json() assert isinstance(res_json, dict) return res_json - def get_metrics_str(self) -> str: - """You probably want to use get_metrics() instead.""" - request_result = self.get(f"http://localhost:{self.port}/metrics") - request_result.raise_for_status() - return request_result.text + def copy_timeline(self, tenant_id: TenantId, timeline_id: TimelineId, body: Dict[str, Any]): + res = self.post( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/copy", + json=body, + ) + res.raise_for_status() - def get_metrics(self) -> SafekeeperMetrics: - res = self.get_metrics_str() - return SafekeeperMetrics(parse_metrics(res)) + def patch_control_file( + self, + tenant_id: TenantId, + timeline_id: TimelineId, + patch: Dict[str, Any], + ) -> Dict[str, Any]: + res = self.patch( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/control_file", + json={ + "updates": patch, + "apply_fields": list(patch.keys()), + }, + ) + res.raise_for_status() + res_json = res.json() + assert isinstance(res_json, dict) + return res_json + + def checkpoint(self, tenant_id: TenantId, timeline_id: TimelineId): + res = self.post( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/checkpoint", + json={}, + ) + res.raise_for_status() + + def timeline_digest( + self, tenant_id: TenantId, timeline_id: TimelineId, from_lsn: Lsn, until_lsn: Lsn + ) -> Dict[str, Any]: + res = self.get( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/digest", + params={ + "from_lsn": str(from_lsn), + "until_lsn": str(until_lsn), + }, + ) + res.raise_for_status() + res_json = res.json() + assert isinstance(res_json, dict) + return res_json + + def record_safekeeper_info(self, tenant_id: TenantId, timeline_id: TimelineId, body): + res = self.post( + f"http://localhost:{self.port}/v1/record_safekeeper_info/{tenant_id}/{timeline_id}", + json=body, + ) + res.raise_for_status() diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index cc93762175..065a78bf9b 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -10,7 +10,7 @@ from fixtures.neon_fixtures import ( tenant_get_shards, wait_for_last_flush_lsn, ) -from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_for_upload +from fixtures.pageserver.utils import wait_for_last_record_lsn # neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex # to ensure we don't do that: this enables running lots of Workloads in parallel safely. 
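workload.py picks up the same simplification applied across the tests in this change: instead of forcing a checkpoint and then polling `wait_for_upload`, the checkpoint call itself is asked to block until remote storage has caught up. A sketch of the resulting shape, assuming `ps_http` is the fixtures' `PageserverHttpClient`:

    def flush_and_upload(ps_http, tenant_shard_id, timeline_id):
        # wait_until_uploaded=True returns only once the uploads triggered by
        # this checkpoint have completed, replacing the separate
        # wait_for_upload() polling step used previously.
        ps_http.timeline_checkpoint(tenant_shard_id, timeline_id, wait_until_uploaded=True)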
@@ -174,8 +174,9 @@ class Workload: if upload: # Wait for written data to be uploaded to S3 (force a checkpoint to trigger upload) - ps_http.timeline_checkpoint(tenant_shard_id, self.timeline_id) - wait_for_upload(ps_http, tenant_shard_id, self.timeline_id, last_flush_lsn) + ps_http.timeline_checkpoint( + tenant_shard_id, self.timeline_id, wait_until_uploaded=True + ) log.info(f"Churn: waiting for remote LSN {last_flush_lsn}") else: log.info(f"Churn: not waiting for upload, disk LSN {last_flush_lsn}") diff --git a/test_runner/performance/README.md b/test_runner/performance/README.md index 7ad65821d4..70d75a6dcf 100644 --- a/test_runner/performance/README.md +++ b/test_runner/performance/README.md @@ -7,7 +7,7 @@ easier to see if you have compile errors without scrolling up. You may also need to run `./scripts/pysync`. Then run the tests -`DEFAULT_PG_VERSION=15 NEON_BIN=./target/release poetry run pytest test_runner/performance` +`DEFAULT_PG_VERSION=16 NEON_BIN=./target/release poetry run pytest test_runner/performance` Some handy pytest flags for local development: - `-x` tells pytest to stop on first error diff --git a/test_runner/performance/pageserver/README.md b/test_runner/performance/pageserver/README.md index fdd09cd946..56ffad9963 100644 --- a/test_runner/performance/pageserver/README.md +++ b/test_runner/performance/pageserver/README.md @@ -11,6 +11,6 @@ It supports mounting snapshots using overlayfs, which improves iteration time. Here's a full command line. ``` -RUST_BACKTRACE=1 NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS=1 DEFAULT_PG_VERSION=15 BUILD_TYPE=release \ +RUST_BACKTRACE=1 NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS=1 DEFAULT_PG_VERSION=16 BUILD_TYPE=release \ ./scripts/pytest test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py ```` diff --git a/test_runner/performance/pageserver/interactive/test_many_small_tenants.py b/test_runner/performance/pageserver/interactive/test_many_small_tenants.py index 33848b06d3..8d781c1609 100644 --- a/test_runner/performance/pageserver/interactive/test_many_small_tenants.py +++ b/test_runner/performance/pageserver/interactive/test_many_small_tenants.py @@ -14,7 +14,7 @@ from performance.pageserver.util import ensure_pageserver_ready_for_benchmarking """ Usage: -DEFAULT_PG_VERSION=15 BUILD_TYPE=debug NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS=1 INTERACTIVE=true \ +DEFAULT_PG_VERSION=16 BUILD_TYPE=debug NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS=1 INTERACTIVE=true \ ./scripts/pytest --timeout 0 test_runner/performance/pageserver/interactive/test_many_small_tenants.py """ diff --git a/test_runner/performance/pageserver/pagebench/test_ondemand_download_churn.py b/test_runner/performance/pageserver/pagebench/test_ondemand_download_churn.py index 0348b08f04..9ad6e7907c 100644 --- a/test_runner/performance/pageserver/pagebench/test_ondemand_download_churn.py +++ b/test_runner/performance/pageserver/pagebench/test_ondemand_download_churn.py @@ -5,8 +5,12 @@ from typing import Any, Dict, Tuple import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder, PgBin, wait_for_last_flush_lsn -from fixtures.pageserver.utils import wait_for_upload_queue_empty +from fixtures.neon_fixtures import ( + NeonEnv, + NeonEnvBuilder, + PgBin, + flush_ep_to_pageserver, +) from fixtures.remote_storage import s3_storage from fixtures.utils import humantime_to_ms @@ -62,9 +66,6 @@ def 
test_download_churn( run_benchmark(env, pg_bin, record, io_engine, concurrency_per_target, duration) - # see https://github.com/neondatabase/neon/issues/8712 - env.stop(immediate=True) - def setup_env(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin): remote_storage_kind = s3_storage() @@ -98,9 +99,9 @@ def setup_env(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin): f"INSERT INTO data SELECT lpad(i::text, {bytes_per_row}, '0') FROM generate_series(1, {int(nrows)}) as i", options="-c statement_timeout=0", ) - wait_for_last_flush_lsn(env, ep, tenant_id, timeline_id) - # TODO: this is a bit imprecise, there could be frozen layers being written out that we don't observe here - wait_for_upload_queue_empty(client, tenant_id, timeline_id) + flush_ep_to_pageserver(env, ep, tenant_id, timeline_id) + + client.timeline_checkpoint(tenant_id, timeline_id, compact=False, wait_until_uploaded=True) return env diff --git a/test_runner/performance/test_layer_map.py b/test_runner/performance/test_layer_map.py index 890b70b9fc..bc6d9de346 100644 --- a/test_runner/performance/test_layer_map.py +++ b/test_runner/performance/test_layer_map.py @@ -1,20 +1,21 @@ import time -from fixtures.neon_fixtures import NeonEnvBuilder +from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver -# -# Benchmark searching the layer map, when there are a lot of small layer files. -# def test_layer_map(neon_env_builder: NeonEnvBuilder, zenbenchmark): - env = neon_env_builder.init_start() + """Benchmark searching the layer map, when there are a lot of small layer files.""" + + env = neon_env_builder.init_configs() n_iters = 10 n_records = 100000 + env.start() + # We want to have a lot of lot of layer files to exercise the layer map. Disable # GC, and make checkpoint_distance very small, so that we get a lot of small layer # files. 
- tenant, _ = env.neon_cli.create_tenant( + tenant, timeline = env.neon_cli.create_tenant( conf={ "gc_period": "0s", "checkpoint_distance": "16384", @@ -24,8 +25,7 @@ def test_layer_map(neon_env_builder: NeonEnvBuilder, zenbenchmark): } ) - env.neon_cli.create_timeline("test_layer_map", tenant_id=tenant) - endpoint = env.endpoints.create_start("test_layer_map", tenant_id=tenant) + endpoint = env.endpoints.create_start("main", tenant_id=tenant) cur = endpoint.connect().cursor() cur.execute("create table t(x integer)") for _ in range(n_iters): @@ -33,9 +33,12 @@ def test_layer_map(neon_env_builder: NeonEnvBuilder, zenbenchmark): time.sleep(1) cur.execute("vacuum t") + with zenbenchmark.record_duration("test_query"): cur.execute("SELECT count(*) from t") assert cur.fetchone() == (n_iters * n_records,) - # see https://github.com/neondatabase/neon/issues/8712 - env.stop(immediate=True) + flush_ep_to_pageserver(env, endpoint, tenant, timeline) + env.pageserver.http_client().timeline_checkpoint( + tenant, timeline, compact=False, wait_until_uploaded=True + ) diff --git a/test_runner/performance/test_logical_replication.py b/test_runner/performance/test_logical_replication.py index c4e42a7834..077f73ac06 100644 --- a/test_runner/performance/test_logical_replication.py +++ b/test_runner/performance/test_logical_replication.py @@ -282,15 +282,16 @@ def test_snap_files( env = benchmark_project_pub.pgbench_env connstr = benchmark_project_pub.connstr - pg_bin.run_capture(["pgbench", "-i", "-s100"], env=env) with psycopg2.connect(connstr) as conn: conn.autocommit = True with conn.cursor() as cur: cur.execute("SELECT rolsuper FROM pg_roles WHERE rolname = 'neondb_owner'") - is_super = cur.fetchall()[0] + is_super = cur.fetchall()[0][0] assert is_super, "This benchmark won't work if we don't have superuser" + pg_bin.run_capture(["pgbench", "-i", "-s100"], env=env) + conn = psycopg2.connect(connstr) conn.autocommit = True cur = conn.cursor() diff --git a/test_runner/regress/test_auth.py b/test_runner/regress/test_auth.py index 7cb85e3dd1..780c0e1602 100644 --- a/test_runner/regress/test_auth.py +++ b/test_runner/regress/test_auth.py @@ -211,7 +211,7 @@ def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): def check_pageserver(expect_success: bool, **conn_kwargs): check_connection( env.pageserver, - f"pagestream {env.initial_tenant} {env.initial_timeline}", + f"pagestream_v2 {env.initial_tenant} {env.initial_timeline}", expect_success, **conn_kwargs, ) diff --git a/test_runner/regress/test_combocid.py b/test_runner/regress/test_combocid.py index 6d2567b7ee..41907b1f20 100644 --- a/test_runner/regress/test_combocid.py +++ b/test_runner/regress/test_combocid.py @@ -1,4 +1,4 @@ -from fixtures.neon_fixtures import NeonEnvBuilder +from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver def do_combocid_op(neon_env_builder: NeonEnvBuilder, op): @@ -34,7 +34,7 @@ def do_combocid_op(neon_env_builder: NeonEnvBuilder, op): # Clear the cache, so that we exercise reconstructing the pages # from WAL - cur.execute("SELECT clear_buffer_cache()") + endpoint.clear_shared_buffers() # Check that the cursor opened earlier still works. If the # combocids are not restored correctly, it won't. 
@@ -43,6 +43,10 @@ def do_combocid_op(neon_env_builder: NeonEnvBuilder, op): assert len(rows) == 500 cur.execute("rollback") + flush_ep_to_pageserver(env, endpoint, env.initial_tenant, env.initial_timeline) + env.pageserver.http_client().timeline_checkpoint( + env.initial_tenant, env.initial_timeline, compact=False, wait_until_uploaded=True + ) def test_combocid_delete(neon_env_builder: NeonEnvBuilder): @@ -92,7 +96,7 @@ def test_combocid_multi_insert(neon_env_builder: NeonEnvBuilder): cur.execute("delete from t") # Clear the cache, so that we exercise reconstructing the pages # from WAL - cur.execute("SELECT clear_buffer_cache()") + endpoint.clear_shared_buffers() # Check that the cursor opened earlier still works. If the # combocids are not restored correctly, it won't. @@ -102,6 +106,11 @@ def test_combocid_multi_insert(neon_env_builder: NeonEnvBuilder): cur.execute("rollback") + flush_ep_to_pageserver(env, endpoint, env.initial_tenant, env.initial_timeline) + env.pageserver.http_client().timeline_checkpoint( + env.initial_tenant, env.initial_timeline, compact=False, wait_until_uploaded=True + ) + def test_combocid(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() @@ -137,3 +146,8 @@ def test_combocid(neon_env_builder: NeonEnvBuilder): assert cur.rowcount == n_records cur.execute("rollback") + + flush_ep_to_pageserver(env, endpoint, env.initial_tenant, env.initial_timeline) + env.pageserver.http_client().timeline_checkpoint( + env.initial_tenant, env.initial_timeline, compact=False, wait_until_uploaded=True + ) diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index afa5f6873c..c361efe90a 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -9,14 +9,17 @@ from typing import List, Optional import pytest import toml -from fixtures.common_types import Lsn, TenantId, TimelineId +from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder, PgBin +from fixtures.neon_fixtures import ( + NeonEnv, + NeonEnvBuilder, + PgBin, + flush_ep_to_pageserver, +) from fixtures.pageserver.http import PageserverApiException from fixtures.pageserver.utils import ( timeline_delete_wait_completed, - wait_for_last_record_lsn, - wait_for_upload, ) from fixtures.pg_version import PgVersion from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage @@ -39,7 +42,7 @@ from fixtures.workload import Workload # # How to run `test_backward_compatibility` locally: # -# export DEFAULT_PG_VERSION=15 +# export DEFAULT_PG_VERSION=16 # export BUILD_TYPE=release # export CHECK_ONDISK_DATA_COMPATIBILITY=true # export COMPATIBILITY_SNAPSHOT_DIR=test_output/compatibility_snapshot_pgv${DEFAULT_PG_VERSION} @@ -61,7 +64,7 @@ from fixtures.workload import Workload # # How to run `test_forward_compatibility` locally: # -# export DEFAULT_PG_VERSION=15 +# export DEFAULT_PG_VERSION=16 # export BUILD_TYPE=release # export CHECK_ONDISK_DATA_COMPATIBILITY=true # export COMPATIBILITY_NEON_BIN=neon_previous/target/${BUILD_TYPE} @@ -122,11 +125,9 @@ def test_create_snapshot( timeline_id = dict(snapshot_config["branch_name_mappings"]["main"])[tenant_id] pageserver_http = env.pageserver.http_client() - lsn = Lsn(endpoint.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0]) - wait_for_last_record_lsn(pageserver_http, tenant_id, timeline_id, lsn) - pageserver_http.timeline_checkpoint(tenant_id, timeline_id) - 
wait_for_upload(pageserver_http, tenant_id, timeline_id, lsn) + flush_ep_to_pageserver(env, endpoint, tenant_id, timeline_id) + pageserver_http.timeline_checkpoint(tenant_id, timeline_id, wait_until_uploaded=True) env.endpoints.stop_all() for sk in env.safekeepers: @@ -300,7 +301,7 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r pg_version = env.pg_version # Stop endpoint while we recreate timeline - ep.stop() + flush_ep_to_pageserver(env, ep, tenant_id, timeline_id) try: pageserver_http.timeline_preserve_initdb_archive(tenant_id, timeline_id) @@ -348,6 +349,11 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r assert not dump_from_wal_differs, "dump from WAL differs" assert not initial_dump_differs, "initial dump differs" + flush_ep_to_pageserver(env, ep, tenant_id, timeline_id) + pageserver_http.timeline_checkpoint( + tenant_id, timeline_id, compact=False, wait_until_uploaded=True + ) + def dump_differs( first: Path, second: Path, output: Path, allowed_diffs: Optional[List[str]] = None diff --git a/test_runner/regress/test_config.py b/test_runner/regress/test_config.py index 4bb7df1e6a..2ef28eb94b 100644 --- a/test_runner/regress/test_config.py +++ b/test_runner/regress/test_config.py @@ -1,6 +1,7 @@ +import os from contextlib import closing -from fixtures.neon_fixtures import NeonEnv +from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder # @@ -28,3 +29,45 @@ def test_config(neon_simple_env: NeonEnv): # check that config change was applied assert cur.fetchone() == ("debug1",) + + +# +# Test that reordering of safekeepers does not restart walproposer +# +def test_safekeepers_reconfigure_reorder( + neon_env_builder: NeonEnvBuilder, +): + neon_env_builder.num_safekeepers = 3 + env = neon_env_builder.init_start() + env.neon_cli.create_branch("test_safekeepers_reconfigure_reorder") + + endpoint = env.endpoints.create_start("test_safekeepers_reconfigure_reorder") + + old_sks = "" + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("SHOW neon.safekeepers") + res = cur.fetchone() + assert res is not None, "neon.safekeepers GUC is set" + old_sks = res[0] + + # Reorder safekeepers + safekeepers = endpoint.active_safekeepers + safekeepers = safekeepers[1:] + safekeepers[:1] + + endpoint.reconfigure(safekeepers=safekeepers) + + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("SHOW neon.safekeepers") + res = cur.fetchone() + assert res is not None, "neon.safekeepers GUC is set" + new_sks = res[0] + + assert new_sks != old_sks, "GUC changes were applied" + + log_path = os.path.join(endpoint.endpoint_path(), "compute.log") + with open(log_path, "r") as log_file: + logs = log_file.read() + # Check that walproposer was not restarted + assert "restarting walproposer" not in logs diff --git a/test_runner/regress/test_import.py b/test_runner/regress/test_import.py index 4dae9176b8..4385cfca76 100644 --- a/test_runner/regress/test_import.py +++ b/test_runner/regress/test_import.py @@ -18,7 +18,6 @@ from fixtures.neon_fixtures import ( from fixtures.pageserver.utils import ( timeline_delete_wait_completed, wait_for_last_record_lsn, - wait_for_upload, ) from fixtures.remote_storage import RemoteStorageKind from fixtures.utils import assert_pageserver_backups_equal, subprocess_capture @@ -144,7 +143,7 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build # Wait for data to land in s3 wait_for_last_record_lsn(client, tenant, timeline, 
Lsn(end_lsn)) - wait_for_upload(client, tenant, timeline, Lsn(end_lsn)) + client.timeline_checkpoint(tenant, timeline, compact=False, wait_until_uploaded=True) # Check it worked endpoint = env.endpoints.create_start(branch_name, tenant_id=tenant) @@ -290,7 +289,7 @@ def _import( # Wait for data to land in s3 wait_for_last_record_lsn(client, tenant, timeline, lsn) - wait_for_upload(client, tenant, timeline, lsn) + client.timeline_checkpoint(tenant, timeline, compact=False, wait_until_uploaded=True) # Check it worked endpoint = env.endpoints.create_start(branch_name, tenant_id=tenant, lsn=lsn) diff --git a/test_runner/regress/test_layer_bloating.py b/test_runner/regress/test_layer_bloating.py index 77dc8a35b5..b8126395fd 100644 --- a/test_runner/regress/test_layer_bloating.py +++ b/test_runner/regress/test_layer_bloating.py @@ -1,27 +1,31 @@ import os -import time import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import ( - NeonEnv, + NeonEnvBuilder, logical_replication_sync, wait_for_last_flush_lsn, ) from fixtures.pg_version import PgVersion -def test_layer_bloating(neon_simple_env: NeonEnv, vanilla_pg): - env = neon_simple_env - - if env.pg_version != PgVersion.V16: +def test_layer_bloating(neon_env_builder: NeonEnvBuilder, vanilla_pg): + if neon_env_builder.pg_version != PgVersion.V16: pytest.skip("pg_log_standby_snapshot() function is available only in PG16") - timeline = env.neon_cli.create_branch("test_logical_replication", "empty") - endpoint = env.endpoints.create_start( - "test_logical_replication", config_lines=["log_statement=all"] + env = neon_env_builder.init_start( + initial_tenant_conf={ + "gc_period": "0s", + "compaction_period": "0s", + "compaction_threshold": 99999, + "image_creation_threshold": 99999, + } ) + timeline = env.initial_timeline + endpoint = env.endpoints.create_start("main", config_lines=["log_statement=all"]) + pg_conn = endpoint.connect() cur = pg_conn.cursor() @@ -54,7 +58,7 @@ def test_layer_bloating(neon_simple_env: NeonEnv, vanilla_pg): # Wait logical replication to sync logical_replication_sync(vanilla_pg, endpoint) wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, timeline) - time.sleep(10) + env.pageserver.http_client().timeline_checkpoint(env.initial_tenant, timeline, compact=False) # Check layer file sizes timeline_path = f"{env.pageserver.workdir}/tenants/{env.initial_tenant}/timelines/{timeline}/" @@ -63,3 +67,5 @@ def test_layer_bloating(neon_simple_env: NeonEnv, vanilla_pg): if filename.startswith("00000"): log.info(f"layer {filename} size is {os.path.getsize(timeline_path + filename)}") assert os.path.getsize(timeline_path + filename) < 512_000_000 + + env.stop(immediate=True) diff --git a/test_runner/regress/test_lfc_resize.py b/test_runner/regress/test_lfc_resize.py index 2a3442448a..1b2c7f808f 100644 --- a/test_runner/regress/test_lfc_resize.py +++ b/test_runner/regress/test_lfc_resize.py @@ -1,3 +1,7 @@ +import os +import random +import re +import subprocess import threading import time @@ -17,17 +21,17 @@ def test_lfc_resize(neon_simple_env: NeonEnv, pg_bin: PgBin): "test_lfc_resize", config_lines=[ "neon.file_cache_path='file.cache'", - "neon.max_file_cache_size=1GB", - "neon.file_cache_size_limit=1GB", + "neon.max_file_cache_size=512MB", + "neon.file_cache_size_limit=512MB", ], ) n_resize = 10 - scale = 10 + scale = 100 def run_pgbench(connstr: str): log.info(f"Start a pgbench workload on pg {connstr}") pg_bin.run_capture(["pgbench", "-i", f"-s{scale}", connstr]) - pg_bin.run_capture(["pgbench", 
"-c4", f"-T{n_resize}", "-Mprepared", connstr]) + pg_bin.run_capture(["pgbench", "-c10", f"-T{n_resize}", "-Mprepared", "-S", connstr]) thread = threading.Thread(target=run_pgbench, args=(endpoint.connstr(),), daemon=True) thread.start() @@ -35,9 +39,21 @@ def test_lfc_resize(neon_simple_env: NeonEnv, pg_bin: PgBin): conn = endpoint.connect() cur = conn.cursor() - for i in range(n_resize): - cur.execute(f"alter system set neon.file_cache_size_limit='{i*10}MB'") + for _ in range(n_resize): + size = random.randint(1, 512) + cur.execute(f"alter system set neon.file_cache_size_limit='{size}MB'") cur.execute("select pg_reload_conf()") time.sleep(1) + cur.execute("alter system set neon.file_cache_size_limit='100MB'") + cur.execute("select pg_reload_conf()") + thread.join() + + lfc_file_path = f"{endpoint.pg_data_dir_path()}/file.cache" + lfc_file_size = os.path.getsize(lfc_file_path) + res = subprocess.run(["ls", "-sk", lfc_file_path], check=True, text=True, capture_output=True) + lfc_file_blocks = re.findall("([0-9A-F]+)", res.stdout)[0] + log.info(f"Size of LFC file {lfc_file_size}, blocks {lfc_file_blocks}") + assert lfc_file_size <= 512 * 1024 * 1024 + assert int(lfc_file_blocks) <= 128 * 1024 diff --git a/test_runner/regress/test_logical_replication.py b/test_runner/regress/test_logical_replication.py index 0d18aa43b7..f83a833dda 100644 --- a/test_runner/regress/test_logical_replication.py +++ b/test_runner/regress/test_logical_replication.py @@ -22,7 +22,7 @@ def random_string(n: int): @pytest.mark.parametrize( - "pageserver_aux_file_policy", [AuxFileStore.V1, AuxFileStore.V2, AuxFileStore.CrossValidation] + "pageserver_aux_file_policy", [AuxFileStore.V2, AuxFileStore.CrossValidation] ) def test_aux_file_v2_flag(neon_simple_env: NeonEnv, pageserver_aux_file_policy: AuxFileStore): env = neon_simple_env @@ -31,9 +31,7 @@ def test_aux_file_v2_flag(neon_simple_env: NeonEnv, pageserver_aux_file_policy: assert pageserver_aux_file_policy == tenant_config["switch_aux_file_policy"] -@pytest.mark.parametrize( - "pageserver_aux_file_policy", [AuxFileStore.V1, AuxFileStore.CrossValidation] -) +@pytest.mark.parametrize("pageserver_aux_file_policy", [AuxFileStore.CrossValidation]) def test_logical_replication(neon_simple_env: NeonEnv, vanilla_pg): env = neon_simple_env @@ -175,9 +173,7 @@ COMMIT; # Test that neon.logical_replication_max_snap_files works -@pytest.mark.parametrize( - "pageserver_aux_file_policy", [AuxFileStore.V1, AuxFileStore.CrossValidation] -) +@pytest.mark.parametrize("pageserver_aux_file_policy", [AuxFileStore.CrossValidation]) def test_obsolete_slot_drop(neon_simple_env: NeonEnv, vanilla_pg): def slot_removed(ep): assert ( @@ -355,9 +351,7 @@ FROM generate_series(1, 16384) AS seq; -- Inserts enough rows to exceed 16MB of # # Most pages start with a contrecord, so we don't do anything special # to ensure that. -@pytest.mark.parametrize( - "pageserver_aux_file_policy", [AuxFileStore.V1, AuxFileStore.CrossValidation] -) +@pytest.mark.parametrize("pageserver_aux_file_policy", [AuxFileStore.CrossValidation]) def test_restart_endpoint(neon_simple_env: NeonEnv, vanilla_pg): env = neon_simple_env @@ -402,9 +396,7 @@ def test_restart_endpoint(neon_simple_env: NeonEnv, vanilla_pg): # logical replication bug as such, but without logical replication, # records passed ot the WAL redo process are never large enough to hit # the bug. 
-@pytest.mark.parametrize( - "pageserver_aux_file_policy", [AuxFileStore.V1, AuxFileStore.CrossValidation] -) +@pytest.mark.parametrize("pageserver_aux_file_policy", [AuxFileStore.CrossValidation]) def test_large_records(neon_simple_env: NeonEnv, vanilla_pg): env = neon_simple_env @@ -476,9 +468,7 @@ def test_slots_and_branching(neon_simple_env: NeonEnv): ws_cur.execute("select pg_create_logical_replication_slot('my_slot', 'pgoutput')") -@pytest.mark.parametrize( - "pageserver_aux_file_policy", [AuxFileStore.V1, AuxFileStore.CrossValidation] -) +@pytest.mark.parametrize("pageserver_aux_file_policy", [AuxFileStore.CrossValidation]) def test_replication_shutdown(neon_simple_env: NeonEnv): # Ensure Postgres can exit without stuck when a replication job is active + neon extension installed env = neon_simple_env diff --git a/test_runner/regress/test_read_validation.py b/test_runner/regress/test_read_validation.py index d128c60a99..1ac881553f 100644 --- a/test_runner/regress/test_read_validation.py +++ b/test_runner/regress/test_read_validation.py @@ -19,11 +19,6 @@ def test_read_validation(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "test_read_validation", - # Use protocol version 2, because the code that constructs the V1 messages - # assumes that a primary always wants to read the latest version of a page, - # and therefore doesn't work with the test functions below to read an older - # page version. - config_lines=["neon.protocol_version=2"], ) with closing(endpoint.connect()) as con: @@ -142,11 +137,6 @@ def test_read_validation_neg(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "test_read_validation_neg", - # Use protocol version 2, because the code that constructs the V1 messages - # assumes that a primary always wants to read the latest version of a page, - # and therefore doesn't work with the test functions below to read an older - # page version. - config_lines=["neon.protocol_version=2"], ) with closing(endpoint.connect()) as con: diff --git a/test_runner/regress/test_s3_restore.py b/test_runner/regress/test_s3_restore.py index 9992647e56..c1a80a54bc 100644 --- a/test_runner/regress/test_s3_restore.py +++ b/test_runner/regress/test_s3_restore.py @@ -8,9 +8,9 @@ from fixtures.neon_fixtures import ( PgBin, ) from fixtures.pageserver.utils import ( - MANY_SMALL_LAYERS_TENANT_CONFIG, assert_prefix_empty, enable_remote_storage_versioning, + many_small_layers_tenant_config, wait_for_upload, ) from fixtures.remote_storage import RemoteStorageKind, s3_storage @@ -33,7 +33,7 @@ def test_tenant_s3_restore( # change it back after initdb, recovery doesn't work if the two # index_part.json uploads happen at same second or too close to each other. 
-    initial_tenant_conf = MANY_SMALL_LAYERS_TENANT_CONFIG
+    initial_tenant_conf = many_small_layers_tenant_config()
     del initial_tenant_conf["checkpoint_distance"]
     env = neon_env_builder.init_start(initial_tenant_conf)
@@ -50,7 +50,7 @@ def test_tenant_s3_restore(
     tenant_id = env.initial_tenant
 
     # now lets create the small layers
-    ps_http.set_tenant_config(tenant_id, MANY_SMALL_LAYERS_TENANT_CONFIG)
+    ps_http.set_tenant_config(tenant_id, many_small_layers_tenant_config())
 
     # Default tenant and the one we created
     assert ps_http.get_metric_value("pageserver_tenant_manager_slots", {"mode": "attached"}) == 1
diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py
index 1011a6fd22..bfd82242e9 100644
--- a/test_runner/regress/test_sharding.py
+++ b/test_runner/regress/test_sharding.py
@@ -394,6 +394,7 @@ def test_sharding_split_smoke(
 
     # Note which pageservers initially hold a shard after tenant creation
     pre_split_pageserver_ids = [loc["node_id"] for loc in env.storage_controller.locate(tenant_id)]
+    log.info(f"Pre-split pageservers: {pre_split_pageserver_ids}")
 
     # For pageservers holding a shard, validate their ingest statistics
     # reflect a proper splitting of the WAL.
@@ -555,9 +556,9 @@ def test_sharding_split_smoke(
     assert sum(total.values()) == split_shard_count * 2
     check_effective_tenant_config()
 
-    # More specific check: that we are fully balanced. This is deterministic because
-    # the order in which we consider shards for optimization is deterministic, and the
-    # order of preference of nodes is also deterministic (lower node IDs win).
+    # More specific check: that we are fully balanced. It is deterministic that we will get exactly
+    # one shard on each pageserver, because for these small shards the utilization metric is
+    # dominated by shard count.
     log.info(f"total: {total}")
     assert total == {
         1: 1,
@@ -577,8 +578,14 @@ def test_sharding_split_smoke(
         15: 1,
         16: 1,
     }
+
+    # The controller is not required to lay out the attached locations in any particular way, but
+    # all the pageservers that originally held an attached shard should still hold one, otherwise
+    # it would indicate that we had done some unnecessary migration.
     log.info(f"attached: {attached}")
-    assert attached == {1: 1, 2: 1, 3: 1, 5: 1, 6: 1, 7: 1, 9: 1, 11: 1}
+    for ps_id in pre_split_pageserver_ids:
+        log.info(f"Pre-split pageserver {ps_id} should still hold an attached location")
+        assert ps_id in attached
 
     # Ensure post-split pageserver locations survive a restart (i.e.
the child shards # correctly wrote config to disk, and the storage controller responds correctly diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 95c35e9641..03eb7628be 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -21,13 +21,13 @@ from fixtures.neon_fixtures import ( TokenScope, last_flush_lsn_upload, ) -from fixtures.pageserver.http import PageserverHttpClient +from fixtures.pageserver.http import PageserverApiException, PageserverHttpClient from fixtures.pageserver.utils import ( - MANY_SMALL_LAYERS_TENANT_CONFIG, assert_prefix_empty, assert_prefix_not_empty, enable_remote_storage_versioning, list_prefix, + many_small_layers_tenant_config, remote_storage_delete_key, timeline_delete_wait_completed, ) @@ -41,6 +41,7 @@ from mypy_boto3_s3.type_defs import ( ObjectTypeDef, ) from pytest_httpserver import HTTPServer +from urllib3 import Retry from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response @@ -654,7 +655,7 @@ def test_storage_controller_s3_time_travel_recovery( tenant_id, shard_count=2, shard_stripe_size=8192, - tenant_config=MANY_SMALL_LAYERS_TENANT_CONFIG, + tenant_config=many_small_layers_tenant_config(), ) # Check that the consistency check passes @@ -2144,6 +2145,8 @@ def test_storage_controller_leadership_transfer( port_distributor: PortDistributor, step_down_times_out: bool, ): + neon_env_builder.auth_enabled = True + neon_env_builder.num_pageservers = 3 neon_env_builder.storage_controller_config = { @@ -2264,3 +2267,66 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB # allow for small delay between actually having cancelled and being able reconfigure again wait_until(4, 0.5, reconfigure_node_again) + + +def test_storage_controller_timeline_crud_race(neon_env_builder: NeonEnvBuilder): + """ + The storage controller is meant to handle the case where a timeline CRUD operation races + with a generation-incrementing change to the tenant: this should trigger a retry so that + the operation lands on the highest-generation'd tenant location. + """ + neon_env_builder.num_pageservers = 2 + env = neon_env_builder.init_configs() + env.start() + tenant_id = TenantId.generate() + env.storage_controller.tenant_create(tenant_id) + + # Set up a failpoint so that a timeline creation will be very slow + failpoint = "timeline-creation-after-uninit" + for ps in env.pageservers: + ps.http_client().configure_failpoints((failpoint, "sleep(10000)")) + + # Start a timeline creation in the background + create_timeline_id = TimelineId.generate() + futs = [] + with concurrent.futures.ThreadPoolExecutor( + max_workers=2 + len(env.pageservers) + len(env.safekeepers) + ) as executor: + futs.append( + executor.submit( + env.storage_controller.pageserver_api( + retries=Retry( + status=0, + connect=0, # Disable retries: we want to see the 503 + ) + ).timeline_create, + PgVersion.NOT_SET, + tenant_id, + create_timeline_id, + ) + ) + + def has_hit_failpoint(): + assert any( + ps.log_contains(f"at failpoint {failpoint}") is not None for ps in env.pageservers + ) + + wait_until(10, 1, has_hit_failpoint) + + # Migrate the tenant while the timeline creation is in progress: this migration will complete once it + # can detach from the old pageserver, which will happen once the failpoint completes. 
+ env.storage_controller.tenant_shard_migrate( + TenantShardId(tenant_id, 0, 0), env.pageservers[1].id + ) + + with pytest.raises(PageserverApiException, match="Tenant attachment changed, please retry"): + futs[0].result(timeout=20) + + # Timeline creation should work when there isn't a concurrent migration, even though it's + # slow (our failpoint is still enabled) + env.storage_controller.pageserver_api( + retries=Retry( + status=0, + connect=0, # Disable retries: we want to see the 503 + ) + ).timeline_create(PgVersion.NOT_SET, tenant_id, create_timeline_id) diff --git a/test_runner/regress/test_storage_scrubber.py b/test_runner/regress/test_storage_scrubber.py index 2844d1b1d2..292a9a1010 100644 --- a/test_runner/regress/test_storage_scrubber.py +++ b/test_runner/regress/test_storage_scrubber.py @@ -152,6 +152,9 @@ def test_scrubber_physical_gc(neon_env_builder: NeonEnvBuilder, shard_count: Opt # This write includes remote upload, will generate an index in this generation workload.write_rows(1) + # We will use a min_age_secs=1 threshold for deletion, let it pass + time.sleep(2) + # With a high min_age, the scrubber should decline to delete anything gc_summary = env.storage_scrubber.pageserver_physical_gc(min_age_secs=3600) assert gc_summary["remote_storage_errors"] == 0 diff --git a/test_runner/regress/test_subscriber_restart.py b/test_runner/regress/test_subscriber_restart.py index 4581008022..91caad7220 100644 --- a/test_runner/regress/test_subscriber_restart.py +++ b/test_runner/regress/test_subscriber_restart.py @@ -37,9 +37,7 @@ def test_subscriber_restart(neon_simple_env: NeonEnv): scur.execute("CREATE TABLE t (pk integer primary key, sk integer)") # scur.execute("CREATE INDEX on t(sk)") # slowdown applying WAL at replica pub_conn = f"host=localhost port={pub.pg_port} dbname=postgres user=cloud_admin" - # synchronous_commit=on to test a hypothesis for why this test has been flaky. 
- # XXX: Add link to the issue - query = f"CREATE SUBSCRIPTION sub CONNECTION '{pub_conn}' PUBLICATION pub with (synchronous_commit=on)" + query = f"CREATE SUBSCRIPTION sub CONNECTION '{pub_conn}' PUBLICATION pub" scur.execute(query) time.sleep(2) # let initial table sync complete diff --git a/test_runner/regress/test_tenant_delete.py b/test_runner/regress/test_tenant_delete.py index dadf5ca672..448a28dc31 100644 --- a/test_runner/regress/test_tenant_delete.py +++ b/test_runner/regress/test_tenant_delete.py @@ -9,9 +9,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.http import PageserverApiException from fixtures.pageserver.utils import ( - MANY_SMALL_LAYERS_TENANT_CONFIG, assert_prefix_empty, assert_prefix_not_empty, + many_small_layers_tenant_config, wait_for_upload, ) from fixtures.remote_storage import RemoteStorageKind, s3_storage @@ -76,7 +76,7 @@ def test_tenant_delete_smoke( env.neon_cli.create_tenant( tenant_id=tenant_id, - conf=MANY_SMALL_LAYERS_TENANT_CONFIG, + conf=many_small_layers_tenant_config(), ) # Default tenant and the one we created @@ -215,7 +215,7 @@ def test_tenant_delete_races_timeline_creation(neon_env_builder: NeonEnvBuilder) # (and there is no way to reconstruct the used remote storage kind) remote_storage_kind = RemoteStorageKind.MOCK_S3 neon_env_builder.enable_pageserver_remote_storage(remote_storage_kind) - env = neon_env_builder.init_start(initial_tenant_conf=MANY_SMALL_LAYERS_TENANT_CONFIG) + env = neon_env_builder.init_start(initial_tenant_conf=many_small_layers_tenant_config()) ps_http = env.pageserver.http_client() tenant_id = env.initial_tenant @@ -330,7 +330,7 @@ def test_tenant_delete_scrubber(pg_bin: PgBin, neon_env_builder: NeonEnvBuilder) remote_storage_kind = RemoteStorageKind.MOCK_S3 neon_env_builder.enable_pageserver_remote_storage(remote_storage_kind) - env = neon_env_builder.init_start(initial_tenant_conf=MANY_SMALL_LAYERS_TENANT_CONFIG) + env = neon_env_builder.init_start(initial_tenant_conf=many_small_layers_tenant_config()) ps_http = env.pageserver.http_client() # create a tenant separate from the main tenant so that we have one remaining diff --git a/test_runner/regress/test_tenant_size.py b/test_runner/regress/test_tenant_size.py index b1ade77a14..f872116a1c 100644 --- a/test_runner/regress/test_tenant_size.py +++ b/test_runner/regress/test_tenant_size.py @@ -757,6 +757,9 @@ def test_lsn_lease_size(neon_env_builder: NeonEnvBuilder, test_output_dir: Path, assert_size_approx_equal_for_lease_test(lease_res, ro_branch_res) + # we are writing a lot, and flushing all of that to disk is not important for this test + env.stop(immediate=True) + def insert_with_action( env: NeonEnv, diff --git a/test_runner/regress/test_timeline_archive.py b/test_runner/regress/test_timeline_archive.py new file mode 100644 index 0000000000..b774c7c9fe --- /dev/null +++ b/test_runner/regress/test_timeline_archive.py @@ -0,0 +1,96 @@ +import pytest +from fixtures.common_types import TenantId, TimelineArchivalState, TimelineId +from fixtures.neon_fixtures import ( + NeonEnv, +) +from fixtures.pageserver.http import PageserverApiException + + +def test_timeline_archive(neon_simple_env: NeonEnv): + env = neon_simple_env + + env.pageserver.allowed_errors.extend( + [ + ".*Timeline .* was not found.*", + ".*timeline not found.*", + ".*Cannot archive timeline which has unarchived child timelines.*", + ".*Precondition failed: Requested tenant is missing.*", + ] + ) + + ps_http = env.pageserver.http_client() + + # first try to archive non existing timeline 
+ # for existing tenant: + invalid_timeline_id = TimelineId.generate() + with pytest.raises(PageserverApiException, match="timeline not found") as exc: + ps_http.timeline_archival_config( + tenant_id=env.initial_tenant, + timeline_id=invalid_timeline_id, + state=TimelineArchivalState.ARCHIVED, + ) + + assert exc.value.status_code == 404 + + # for non existing tenant: + invalid_tenant_id = TenantId.generate() + with pytest.raises( + PageserverApiException, + match=f"NotFound: tenant {invalid_tenant_id}", + ) as exc: + ps_http.timeline_archival_config( + tenant_id=invalid_tenant_id, + timeline_id=invalid_timeline_id, + state=TimelineArchivalState.ARCHIVED, + ) + + assert exc.value.status_code == 404 + + # construct pair of branches to validate that pageserver prohibits + # archival of ancestor timelines when they have non-archived child branches + parent_timeline_id = env.neon_cli.create_branch("test_ancestor_branch_archive_parent", "empty") + + leaf_timeline_id = env.neon_cli.create_branch( + "test_ancestor_branch_archive_branch1", "test_ancestor_branch_archive_parent" + ) + + timeline_path = env.pageserver.timeline_dir(env.initial_tenant, parent_timeline_id) + + with pytest.raises( + PageserverApiException, + match="Cannot archive timeline which has non-archived child timelines", + ) as exc: + assert timeline_path.exists() + + ps_http.timeline_archival_config( + tenant_id=env.initial_tenant, + timeline_id=parent_timeline_id, + state=TimelineArchivalState.ARCHIVED, + ) + + assert exc.value.status_code == 412 + + # Test timeline_detail + leaf_detail = ps_http.timeline_detail( + tenant_id=env.initial_tenant, + timeline_id=leaf_timeline_id, + ) + assert leaf_detail["is_archived"] is False + + # Test that archiving the leaf timeline and then the parent works + ps_http.timeline_archival_config( + tenant_id=env.initial_tenant, + timeline_id=leaf_timeline_id, + state=TimelineArchivalState.ARCHIVED, + ) + leaf_detail = ps_http.timeline_detail( + tenant_id=env.initial_tenant, + timeline_id=leaf_timeline_id, + ) + assert leaf_detail["is_archived"] is True + + ps_http.timeline_archival_config( + tenant_id=env.initial_tenant, + timeline_id=parent_timeline_id, + state=TimelineArchivalState.ARCHIVED, + ) diff --git a/test_runner/regress/test_timeline_delete.py b/test_runner/regress/test_timeline_delete.py index 6d96dda391..328131cd08 100644 --- a/test_runner/regress/test_timeline_delete.py +++ b/test_runner/regress/test_timeline_delete.py @@ -16,9 +16,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.http import PageserverApiException from fixtures.pageserver.utils import ( - MANY_SMALL_LAYERS_TENANT_CONFIG, assert_prefix_empty, assert_prefix_not_empty, + many_small_layers_tenant_config, poll_for_remote_storage_iterations, timeline_delete_wait_completed, wait_for_last_record_lsn, @@ -782,7 +782,7 @@ def test_timeline_delete_resumed_on_attach( remote_storage_kind = s3_storage() neon_env_builder.enable_pageserver_remote_storage(remote_storage_kind) - env = neon_env_builder.init_start(initial_tenant_conf=MANY_SMALL_LAYERS_TENANT_CONFIG) + env = neon_env_builder.init_start(initial_tenant_conf=many_small_layers_tenant_config()) tenant_id = env.initial_tenant diff --git a/test_runner/regress/test_timeline_detach_ancestor.py b/test_runner/regress/test_timeline_detach_ancestor.py index 82fc26126d..d152d0f41f 100644 --- a/test_runner/regress/test_timeline_detach_ancestor.py +++ b/test_runner/regress/test_timeline_detach_ancestor.py @@ -639,8 +639,12 @@ def 
test_timeline_ancestor_detach_errors(neon_env_builder: NeonEnvBuilder, shard for ps in pageservers.values(): ps.allowed_errors.extend(SHUTDOWN_ALLOWED_ERRORS) - ps.allowed_errors.append( - ".* WARN .* path=/v1/tenant/.*/timeline/.*/detach_ancestor request_id=.*: request was dropped before completing" + ps.allowed_errors.extend( + [ + ".* WARN .* path=/v1/tenant/.*/timeline/.*/detach_ancestor request_id=.*: request was dropped before completing", + # rare error logging, which is hard to reproduce without instrumenting responding with random sleep + '.* ERROR .* path=/v1/tenant/.*/timeline/.*/detach_ancestor request_id=.*: Cancelled request finished with an error: Conflict\\("no ancestors"\\)', + ] ) client = ( diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 5d3b263936..19df834b81 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -254,6 +254,10 @@ def test_many_timelines(neon_env_builder: NeonEnvBuilder): assert max(init_m[2].flush_lsns) <= min(final_m[2].flush_lsns) < middle_lsn assert max(init_m[2].commit_lsns) <= min(final_m[2].commit_lsns) < middle_lsn + # Test timeline_list endpoint. + http_cli = env.safekeepers[0].http_client() + assert len(http_cli.timeline_list()) == 3 + # Check that dead minority doesn't prevent the commits: execute insert n_inserts # times, with fault_probability chance of getting a wal acceptor down or up @@ -1296,6 +1300,8 @@ def test_lagging_sk(neon_env_builder: NeonEnvBuilder): # Check that WALs are the same. cmp_sk_wal([sk1, sk2, sk3], tenant_id, timeline_id) + env.stop(immediate=True) + # Smaller version of test_one_sk_down testing peer recovery in isolation: that # it works without compute at all. diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 3fd7a45f8a..b6910406e2 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 3fd7a45f8aae85c080df6329e3c85887b7f3a737 +Subproject commit b6910406e2d05a2c94baa2e530ec882733047759 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index 46b4b235f3..76063bff63 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit 46b4b235f38413ab5974bb22c022f9b829257674 +Subproject commit 76063bff638ccce7afa99fc9037ac51338b9823d diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 47a9122a5a..8efa089aa7 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 47a9122a5a150a3217fafd3f3d4fe8e020ea718a +Subproject commit 8efa089aa7786381543a4f9efc69b92d43eab8c0 diff --git a/vendor/revisions.json b/vendor/revisions.json index 6e3e489b5d..50cc99c2f1 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,14 +1,14 @@ { "v16": [ - "16.3", - "47a9122a5a150a3217fafd3f3d4fe8e020ea718a" + "16.4", + "8efa089aa7786381543a4f9efc69b92d43eab8c0" ], "v15": [ - "15.7", - "46b4b235f38413ab5974bb22c022f9b829257674" + "15.8", + "76063bff638ccce7afa99fc9037ac51338b9823d" ], "v14": [ - "14.12", - "3fd7a45f8aae85c080df6329e3c85887b7f3a737" + "14.13", + "b6910406e2d05a2c94baa2e530ec882733047759" ] } diff --git a/vm-image-spec.yaml b/vm-image-spec.yaml index 41d6e11725..c94f95f447 100644 --- a/vm-image-spec.yaml +++ b/vm-image-spec.yaml @@ -259,7 +259,7 @@ files: from (values ('5m'),('15m'),('1h')) as t (x); - - metric_name: current_lsn + - metric_name: compute_current_lsn type: gauge help: 'Current LSN of the database' key_labels: @@ -272,6 +272,19 @@ files: else (pg_current_wal_lsn() - '0/0')::FLOAT8 end as lsn; + 
- metric_name: compute_receive_lsn + type: gauge + help: 'Returns the last write-ahead log location that has been received and synced to disk by streaming replication' + key_labels: + values: [lsn] + query: | + SELECT + CASE + WHEN pg_catalog.pg_is_in_recovery() + THEN (pg_last_wal_receive_lsn() - '0/0')::FLOAT8 + ELSE 0 + END AS lsn; + - metric_name: replication_delay_bytes type: gauge help: 'Bytes between received and replayed LSN' @@ -312,6 +325,20 @@ files: query: | SELECT checkpoints_timed FROM pg_stat_bgwriter; + - metric_name: compute_logical_snapshot_files + type: gauge + help: 'Number of snapshot files in pg_logical/snapshot' + key_labels: + - timeline_id + values: [num_logical_snapshot_files] + query: | + SELECT + (SELECT setting FROM pg_settings WHERE name = 'neon.timeline_id') AS timeline_id, + -- Postgres creates temporary snapshot files of the form %X-%X.snap.%d.tmp. These + -- temporary snapshot files are renamed to the actual snapshot files after they are + -- completely built. We only WAL-log the completely built snapshot files. + (SELECT COUNT(*) FROM pg_ls_logicalsnapdir() WHERE name LIKE '%.snap') AS num_logical_snapshot_files; + # In all the below metrics, we cast LSNs to floats because Prometheus only supports floats. # It's probably fine because float64 can store integers from -2^53 to +2^53 exactly. @@ -327,6 +354,17 @@ files: from pg_replication_slots where slot_type = 'logical'; + - metric_name: compute_subscriptions_count + type: gauge + help: 'Number of logical replication subscriptions grouped by enabled/disabled' + key_labels: + - enabled + values: [subscriptions_count] + query: | + select subenabled::text as enabled, count(*) as subscriptions_count + from pg_subscription + group by subenabled; + - metric_name: retained_wal type: gauge help: 'Retained WAL in inactive replication slots' diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 2d9b372654..20693ad63d 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -80,8 +80,6 @@ time = { version = "0.3", features = ["macros", "serde-well-known"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] } tokio-rustls = { version = "0.24" } tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] } -toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } -toml_edit = { version = "0.19", features = ["serde"] } tonic = { version = "0.9", features = ["tls-roots"] } tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] } tracing = { version = "0.1", features = ["log"] } @@ -124,7 +122,6 @@ serde = { version = "1", features = ["alloc", "derive"] } syn-dff4ba8e3ae991db = { package = "syn", version = "1", features = ["extra-traits", "full", "visit"] } syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } time-macros = { version = "0.2", default-features = false, features = ["formatting", "parsing", "serde"] } -toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } zstd = { version = "0.13" } zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] } zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] }