From f7516df6c155162aa2d935adadf95524379e0a58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Wed, 7 Feb 2024 12:56:53 +0100 Subject: [PATCH 01/81] Pass timestamp as a datetime (#6656) This saves some repetition. I did this in #6533 for `tenant_time_travel_remote_storage` already. --- test_runner/fixtures/pageserver/http.py | 4 ++-- test_runner/regress/test_lsn_mapping.py | 16 ++++------------ 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index 92e5027a9f..adea9ca764 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -563,13 +563,13 @@ class PageserverHttpClient(requests.Session): self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, - timestamp, + timestamp: datetime, ): log.info( f"Requesting lsn by timestamp {timestamp}, tenant {tenant_id}, timeline {timeline_id}" ) res = self.get( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/get_lsn_by_timestamp?timestamp={timestamp}", + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/get_lsn_by_timestamp?timestamp={timestamp.isoformat()}Z", ) self.verbose_error(res) res_json = res.json() diff --git a/test_runner/regress/test_lsn_mapping.py b/test_runner/regress/test_lsn_mapping.py index 9788e8c0d7..50d7c74af0 100644 --- a/test_runner/regress/test_lsn_mapping.py +++ b/test_runner/regress/test_lsn_mapping.py @@ -64,18 +64,14 @@ def test_lsn_mapping(neon_env_builder: NeonEnvBuilder): # Check edge cases # Timestamp is in the future probe_timestamp = tbl[-1][1] + timedelta(hours=1) - result = client.timeline_get_lsn_by_timestamp( - tenant_id, timeline_id, f"{probe_timestamp.isoformat()}Z" - ) + result = client.timeline_get_lsn_by_timestamp(tenant_id, timeline_id, probe_timestamp) assert result["kind"] == "future" # make sure that we return a well advanced lsn here assert Lsn(result["lsn"]) > start_lsn # Timestamp is in the unreachable past probe_timestamp = tbl[0][1] - timedelta(hours=10) - result = client.timeline_get_lsn_by_timestamp( - tenant_id, timeline_id, f"{probe_timestamp.isoformat()}Z" - ) + result = client.timeline_get_lsn_by_timestamp(tenant_id, timeline_id, probe_timestamp) assert result["kind"] == "past" # make sure that we return the minimum lsn here at the start of the range assert Lsn(result["lsn"]) < start_lsn @@ -83,9 +79,7 @@ def test_lsn_mapping(neon_env_builder: NeonEnvBuilder): # Probe a bunch of timestamps in the valid range for i in range(1, len(tbl), 100): probe_timestamp = tbl[i][1] - result = client.timeline_get_lsn_by_timestamp( - tenant_id, timeline_id, f"{probe_timestamp.isoformat()}Z" - ) + result = client.timeline_get_lsn_by_timestamp(tenant_id, timeline_id, probe_timestamp) assert result["kind"] not in ["past", "nodata"] lsn = result["lsn"] # Call get_lsn_by_timestamp to get the LSN @@ -108,9 +102,7 @@ def test_lsn_mapping(neon_env_builder: NeonEnvBuilder): # Timestamp is in the unreachable past probe_timestamp = tbl[0][1] - timedelta(hours=10) - result = client.timeline_get_lsn_by_timestamp( - tenant_id, timeline_id_child, f"{probe_timestamp.isoformat()}Z" - ) + result = client.timeline_get_lsn_by_timestamp(tenant_id, timeline_id_child, probe_timestamp) assert result["kind"] == "past" # make sure that we return the minimum lsn here at the start of the range assert Lsn(result["lsn"]) >= last_flush_lsn From 3d4fe205ba260c6cd878bf8d0c19623d45920e4f Mon Sep 17 00:00:00 2001 From: John Spray Date: Wed, 7 Feb 2024 13:08:09 +0000 Subject: [PATCH 02/81] control_plane/attachment_service: database connection pool (#6622) ## Problem This is mainly to limit our concurrency, rather than to speed up requests (I was doing some sanity checks on performance of the service with thousands of shards) ## Summary of changes - Enable the `diesel:r2d2` feature, which provides an async connection pool - Acquire a connection before entering spawn_blocking for a database transaction (recall that diesel's interface is sync) - Set a connection pool size of 99 to fit within default postgres limit (100) - Also set the tokio blocking thread count to accomodate the same number of blocking tasks (the only thing we use spawn_blocking for is database calls). --- Cargo.lock | 23 +++++++++++ control_plane/attachment_service/Cargo.toml | 3 +- control_plane/attachment_service/src/main.rs | 15 ++++++- .../attachment_service/src/persistence.rs | 41 ++++++++++++++----- .../attachment_service/src/service.rs | 4 +- workspace_hack/Cargo.toml | 3 +- 6 files changed, 74 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b2b2777408..a25725f90d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -289,6 +289,7 @@ dependencies = [ "pageserver_api", "pageserver_client", "postgres_connection", + "r2d2", "reqwest", "serde", "serde_json", @@ -1651,6 +1652,7 @@ dependencies = [ "diesel_derives", "itoa", "pq-sys", + "r2d2", "serde_json", ] @@ -4166,6 +4168,17 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r2d2" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" +dependencies = [ + "log", + "parking_lot 0.12.1", + "scheduled-thread-pool", +] + [[package]] name = "rand" version = "0.7.3" @@ -4879,6 +4892,15 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot 0.12.1", +] + [[package]] name = "scopeguard" version = "1.1.0" @@ -6807,6 +6829,7 @@ dependencies = [ "clap_builder", "crossbeam-utils", "diesel", + "diesel_derives", "either", "fail", "futures-channel", diff --git a/control_plane/attachment_service/Cargo.toml b/control_plane/attachment_service/Cargo.toml index 3a65153c41..0b93211dbc 100644 --- a/control_plane/attachment_service/Cargo.toml +++ b/control_plane/attachment_service/Cargo.toml @@ -24,8 +24,9 @@ tokio.workspace = true tokio-util.workspace = true tracing.workspace = true -diesel = { version = "2.1.4", features = ["serde_json", "postgres"] } +diesel = { version = "2.1.4", features = ["serde_json", "postgres", "r2d2"] } diesel_migrations = { version = "2.1.0" } +r2d2 = { version = "0.8.10" } utils = { path = "../../libs/utils/" } metrics = { path = "../../libs/metrics/" } diff --git a/control_plane/attachment_service/src/main.rs b/control_plane/attachment_service/src/main.rs index bc8a8786c2..7229a2517b 100644 --- a/control_plane/attachment_service/src/main.rs +++ b/control_plane/attachment_service/src/main.rs @@ -170,6 +170,7 @@ impl Secrets { } } +/// Execute the diesel migrations that are built into this binary async fn migration_run(database_url: &str) -> anyhow::Result<()> { use diesel::PgConnection; use diesel_migrations::{HarnessWithOutput, MigrationHarness}; @@ -183,8 +184,18 @@ async fn migration_run(database_url: &str) -> anyhow::Result<()> { Ok(()) } -#[tokio::main] -async fn main() -> anyhow::Result<()> { +fn main() -> anyhow::Result<()> { + tokio::runtime::Builder::new_current_thread() + // We use spawn_blocking for database operations, so require approximately + // as many blocking threads as we will open database connections. + .max_blocking_threads(Persistence::MAX_CONNECTIONS as usize) + .enable_all() + .build() + .unwrap() + .block_on(async_main()) +} + +async fn async_main() -> anyhow::Result<()> { let launch_ts = Box::leak(Box::new(LaunchTimestamp::generate())); logging::init( diff --git a/control_plane/attachment_service/src/persistence.rs b/control_plane/attachment_service/src/persistence.rs index 574441c409..db487bcec6 100644 --- a/control_plane/attachment_service/src/persistence.rs +++ b/control_plane/attachment_service/src/persistence.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::str::FromStr; +use std::time::Duration; use camino::Utf8Path; use camino::Utf8PathBuf; @@ -44,7 +45,7 @@ use crate::PlacementPolicy; /// updated, and reads of nodes are always from memory, not the database. We only require that /// we can UPDATE a node's scheduling mode reasonably quickly to mark a bad node offline. pub struct Persistence { - database_url: String, + connection_pool: diesel::r2d2::Pool>, // In test environments, we support loading+saving a JSON file. This is temporary, for the benefit of // test_compatibility.py, so that we don't have to commit to making the database contents fully backward/forward @@ -64,6 +65,8 @@ pub(crate) enum DatabaseError { Query(#[from] diesel::result::Error), #[error(transparent)] Connection(#[from] diesel::result::ConnectionError), + #[error(transparent)] + ConnectionPool(#[from] r2d2::Error), #[error("Logical error: {0}")] Logical(String), } @@ -71,9 +74,31 @@ pub(crate) enum DatabaseError { pub(crate) type DatabaseResult = Result; impl Persistence { + // The default postgres connection limit is 100. We use up to 99, to leave one free for a human admin under + // normal circumstances. This assumes we have exclusive use of the database cluster to which we connect. + pub const MAX_CONNECTIONS: u32 = 99; + + // We don't want to keep a lot of connections alive: close them down promptly if they aren't being used. + const IDLE_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10); + const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(60); + pub fn new(database_url: String, json_path: Option) -> Self { + let manager = diesel::r2d2::ConnectionManager::::new(database_url); + + // We will use a connection pool: this is primarily to _limit_ our connection count, rather than to optimize time + // to execute queries (database queries are not generally on latency-sensitive paths). + let connection_pool = diesel::r2d2::Pool::builder() + .max_size(Self::MAX_CONNECTIONS) + .max_lifetime(Some(Self::MAX_CONNECTION_LIFETIME)) + .idle_timeout(Some(Self::IDLE_CONNECTION_TIMEOUT)) + // Always keep at least one connection ready to go + .min_idle(Some(1)) + .test_on_check_out(true) + .build(manager) + .expect("Could not build connection pool"); + Self { - database_url, + connection_pool, json_path, } } @@ -84,14 +109,10 @@ impl Persistence { F: Fn(&mut PgConnection) -> DatabaseResult + Send + 'static, R: Send + 'static, { - let database_url = self.database_url.clone(); - tokio::task::spawn_blocking(move || -> DatabaseResult { - // TODO: connection pooling, such as via diesel::r2d2 - let mut conn = PgConnection::establish(&database_url)?; - func(&mut conn) - }) - .await - .expect("Task panic") + let mut conn = self.connection_pool.get()?; + tokio::task::spawn_blocking(move || -> DatabaseResult { func(&mut conn) }) + .await + .expect("Task panic") } /// When a node is first registered, persist it before using it for anything diff --git a/control_plane/attachment_service/src/service.rs b/control_plane/attachment_service/src/service.rs index 6f0e3ebb74..febee1aa0d 100644 --- a/control_plane/attachment_service/src/service.rs +++ b/control_plane/attachment_service/src/service.rs @@ -103,7 +103,9 @@ impl From for ApiError { match err { DatabaseError::Query(e) => ApiError::InternalServerError(e.into()), // FIXME: ApiError doesn't have an Unavailable variant, but ShuttingDown maps to 503. - DatabaseError::Connection(_e) => ApiError::ShuttingDown, + DatabaseError::Connection(_) | DatabaseError::ConnectionPool(_) => { + ApiError::ShuttingDown + } DatabaseError::Logical(reason) => { ApiError::InternalServerError(anyhow::anyhow!(reason)) } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 74464dd4c8..70b238913d 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -29,7 +29,7 @@ chrono = { version = "0.4", default-features = false, features = ["clock", "serd clap = { version = "4", features = ["derive", "string"] } clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] } crossbeam-utils = { version = "0.8" } -diesel = { version = "2", features = ["postgres", "serde_json"] } +diesel = { version = "2", features = ["postgres", "r2d2", "serde_json"] } either = { version = "1" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } futures-channel = { version = "0.3", features = ["sink"] } @@ -90,6 +90,7 @@ anyhow = { version = "1", features = ["backtrace"] } bytes = { version = "1", features = ["serde"] } cc = { version = "1", default-features = false, features = ["parallel"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } +diesel_derives = { version = "2", features = ["32-column-tables", "postgres", "r2d2", "with-deprecated"] } either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } From 090a789408e4bd95656132248bdbcbdba0fd3c4a Mon Sep 17 00:00:00 2001 From: John Spray Date: Wed, 7 Feb 2024 13:24:10 +0000 Subject: [PATCH 03/81] storage controller: use PUT instead of POST (#6659) This was a typo, the server expects PUT. --- control_plane/attachment_service/src/compute_hook.rs | 2 +- test_runner/regress/test_sharding_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/control_plane/attachment_service/src/compute_hook.rs b/control_plane/attachment_service/src/compute_hook.rs index 4ca26431ca..0d3610aafa 100644 --- a/control_plane/attachment_service/src/compute_hook.rs +++ b/control_plane/attachment_service/src/compute_hook.rs @@ -170,7 +170,7 @@ impl ComputeHook { reconfigure_request: &ComputeHookNotifyRequest, cancel: &CancellationToken, ) -> Result<(), NotifyError> { - let req = client.request(Method::POST, url); + let req = client.request(Method::PUT, url); let req = if let Some(value) = &self.authorization_header { req.header(reqwest::header::AUTHORIZATION, value) } else { diff --git a/test_runner/regress/test_sharding_service.py b/test_runner/regress/test_sharding_service.py index ee57fcb2cf..fd811a9d02 100644 --- a/test_runner/regress/test_sharding_service.py +++ b/test_runner/regress/test_sharding_service.py @@ -310,7 +310,7 @@ def test_sharding_service_compute_hook( notifications.append(request.json) return Response(status=200) - httpserver.expect_request("/notify", method="POST").respond_with_handler(handler) + httpserver.expect_request("/notify", method="PUT").respond_with_handler(handler) # Start running env = neon_env_builder.init_start() From 75f1a01d4aba488012c9fd86b56b6dcf46726c92 Mon Sep 17 00:00:00 2001 From: Abhijeet Patil Date: Wed, 7 Feb 2024 16:14:10 +0000 Subject: [PATCH 04/81] Optimise e2e run (#6513) ## Problem We have finite amount of runners and intermediate results are often wanted before a PR is ready for merging. Currently all PRs get e2e tests run and this creates a lot of throwaway e2e results which may or may not get to start or complete before a new push. ## Summary of changes 1. Skip e2e test when PR is in draft mode 2. Run e2e when PR status changes from draft to ready for review (change this to having its trigger in below PR and update results of build and test) 3. Abstract e2e test in a Separate workflow and call it from the main workflow for the e2e test 5. Add a label, if that label is present run e2e test in draft (run-e2e-test-in-draft) 6. Auto add a label(approve to ci) so that all the external contributors PR , e2e run in draft 7. Document the new label changes and the above behaviour Draft PR : https://github.com/neondatabase/neon/actions/runs/7729128470 Ready To Review : https://github.com/neondatabase/neon/actions/runs/7733779916 Draft PR with label : https://github.com/neondatabase/neon/actions/runs/7725691012/job/21062432342 and https://github.com/neondatabase/neon/actions/runs/7733854028 ## Checklist before requesting a review - [x] I have performed a self-review of my code. - [ ] If it is a core feature, I have added thorough tests. - [ ] Do we need to implement analytics? if so did you add the relevant metrics to the dashboard? - [ ] If this PR requires public announcement, mark it with /release-notes label and add several sentences in this section. ## Checklist before merging - [ ] Do not forget to reformat commit message to not include the above checklist --------- Co-authored-by: Alexander Bayandin --- .github/workflows/approved-for-ci-run.yml | 1 + .github/workflows/build_and_test.yml | 48 +-------- .github/workflows/trigger-e2e-tests.yml | 118 ++++++++++++++++++++++ CONTRIBUTING.md | 3 + 4 files changed, 126 insertions(+), 44 deletions(-) create mode 100644 .github/workflows/trigger-e2e-tests.yml diff --git a/.github/workflows/approved-for-ci-run.yml b/.github/workflows/approved-for-ci-run.yml index 5b21011b83..ae2f173b47 100644 --- a/.github/workflows/approved-for-ci-run.yml +++ b/.github/workflows/approved-for-ci-run.yml @@ -93,6 +93,7 @@ jobs: --body-file "body.md" \ --head "${BRANCH}" \ --base "main" \ + --label "run-e2e-tests-in-draft" \ --draft fi diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index f12f020634..078916e1ea 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -22,7 +22,7 @@ env: AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }} # A concurrency group that we use for e2e-tests runs, matches `concurrency.group` above with `github.repository` as a prefix - E2E_CONCURRENCY_GROUP: ${{ github.repository }}-${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }} + E2E_CONCURRENCY_GROUP: ${{ github.repository }}-e2e-tests-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }} jobs: check-permissions: @@ -692,50 +692,10 @@ jobs: }) trigger-e2e-tests: + if: ${{ !github.event.pull_request.draft || contains( github.event.pull_request.labels.*.name, 'run-e2e-tests-in-draft') || github.ref_name == 'main' || github.ref_name == 'release' }} needs: [ check-permissions, promote-images, tag ] - runs-on: [ self-hosted, gen3, small ] - container: - image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned - options: --init - steps: - - name: Set PR's status to pending and request a remote CI test - run: | - # For pull requests, GH Actions set "github.sha" variable to point at a fake merge commit - # but we need to use a real sha of a latest commit in the PR's branch for the e2e job, - # to place a job run status update later. - COMMIT_SHA=${{ github.event.pull_request.head.sha }} - # For non-PR kinds of runs, the above will produce an empty variable, pick the original sha value for those - COMMIT_SHA=${COMMIT_SHA:-${{ github.sha }}} - - REMOTE_REPO="${{ github.repository_owner }}/cloud" - - curl -f -X POST \ - https://api.github.com/repos/${{ github.repository }}/statuses/$COMMIT_SHA \ - -H "Accept: application/vnd.github.v3+json" \ - --user "${{ secrets.CI_ACCESS_TOKEN }}" \ - --data \ - "{ - \"state\": \"pending\", - \"context\": \"neon-cloud-e2e\", - \"description\": \"[$REMOTE_REPO] Remote CI job is about to start\" - }" - - curl -f -X POST \ - https://api.github.com/repos/$REMOTE_REPO/actions/workflows/testing.yml/dispatches \ - -H "Accept: application/vnd.github.v3+json" \ - --user "${{ secrets.CI_ACCESS_TOKEN }}" \ - --data \ - "{ - \"ref\": \"main\", - \"inputs\": { - \"ci_job_name\": \"neon-cloud-e2e\", - \"commit_hash\": \"$COMMIT_SHA\", - \"remote_repo\": \"${{ github.repository }}\", - \"storage_image_tag\": \"${{ needs.tag.outputs.build-tag }}\", - \"compute_image_tag\": \"${{ needs.tag.outputs.build-tag }}\", - \"concurrency_group\": \"${{ env.E2E_CONCURRENCY_GROUP }}\" - } - }" + uses: ./.github/workflows/trigger-e2e-tests.yml + secrets: inherit neon-image: needs: [ check-permissions, build-buildtools-image, tag ] diff --git a/.github/workflows/trigger-e2e-tests.yml b/.github/workflows/trigger-e2e-tests.yml new file mode 100644 index 0000000000..2776033805 --- /dev/null +++ b/.github/workflows/trigger-e2e-tests.yml @@ -0,0 +1,118 @@ +name: Trigger E2E Tests + +on: + pull_request: + types: + - ready_for_review + workflow_call: + +defaults: + run: + shell: bash -euxo pipefail {0} + +env: + # A concurrency group that we use for e2e-tests runs, matches `concurrency.group` above with `github.repository` as a prefix + E2E_CONCURRENCY_GROUP: ${{ github.repository }}-e2e-tests-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }} + +jobs: + cancel-previous-e2e-tests: + if: github.event_name == 'pull_request' + runs-on: ubuntu-latest + + steps: + - name: Cancel previous e2e-tests runs for this PR + env: + GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }} + run: | + gh workflow --repo neondatabase/cloud \ + run cancel-previous-in-concurrency-group.yml \ + --field concurrency_group="${{ env.E2E_CONCURRENCY_GROUP }}" + + tag: + runs-on: [ ubuntu-latest ] + outputs: + build-tag: ${{ steps.build-tag.outputs.tag }} + + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Get build tag + env: + GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }} + CURRENT_BRANCH: ${{ github.head_ref || github.ref_name }} + CURRENT_SHA: ${{ github.event.pull_request.head.sha || github.sha }} + run: | + if [[ "$GITHUB_REF_NAME" == "main" ]]; then + echo "tag=$(git rev-list --count HEAD)" | tee -a $GITHUB_OUTPUT + elif [[ "$GITHUB_REF_NAME" == "release" ]]; then + echo "tag=release-$(git rev-list --count HEAD)" | tee -a $GITHUB_OUTPUT + else + echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'" + BUILD_AND_TEST_RUN_ID=$(gh run list -b $CURRENT_BRANCH -c $CURRENT_SHA -w 'Build and Test' -L 1 --json databaseId --jq '.[].databaseId') + echo "tag=$BUILD_AND_TEST_RUN_ID" | tee -a $GITHUB_OUTPUT + fi + id: build-tag + + trigger-e2e-tests: + needs: [ tag ] + runs-on: [ self-hosted, gen3, small ] + env: + TAG: ${{ needs.tag.outputs.build-tag }} + container: + image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned + options: --init + steps: + - name: check if ecr image are present + run: | + for REPO in neon compute-tools compute-node-v14 vm-compute-node-v14 compute-node-v15 vm-compute-node-v15 compute-node-v16 vm-compute-node-v16; do + OUTPUT=$(aws ecr describe-images --repository-name ${REPO} --region eu-central-1 --query "imageDetails[?imageTags[?contains(@, '${TAG}')]]" --output text) + if [ "$OUTPUT" == "" ]; then + echo "$REPO with image tag $TAG not found" >> $GITHUB_OUTPUT + exit 1 + fi + done + + - name: Set PR's status to pending and request a remote CI test + run: | + # For pull requests, GH Actions set "github.sha" variable to point at a fake merge commit + # but we need to use a real sha of a latest commit in the PR's branch for the e2e job, + # to place a job run status update later. + COMMIT_SHA=${{ github.event.pull_request.head.sha }} + # For non-PR kinds of runs, the above will produce an empty variable, pick the original sha value for those + COMMIT_SHA=${COMMIT_SHA:-${{ github.sha }}} + + REMOTE_REPO="${{ github.repository_owner }}/cloud" + + curl -f -X POST \ + https://api.github.com/repos/${{ github.repository }}/statuses/$COMMIT_SHA \ + -H "Accept: application/vnd.github.v3+json" \ + --user "${{ secrets.CI_ACCESS_TOKEN }}" \ + --data \ + "{ + \"state\": \"pending\", + \"context\": \"neon-cloud-e2e\", + \"description\": \"[$REMOTE_REPO] Remote CI job is about to start\" + }" + + curl -f -X POST \ + https://api.github.com/repos/$REMOTE_REPO/actions/workflows/testing.yml/dispatches \ + -H "Accept: application/vnd.github.v3+json" \ + --user "${{ secrets.CI_ACCESS_TOKEN }}" \ + --data \ + "{ + \"ref\": \"main\", + \"inputs\": { + \"ci_job_name\": \"neon-cloud-e2e\", + \"commit_hash\": \"$COMMIT_SHA\", + \"remote_repo\": \"${{ github.repository }}\", + \"storage_image_tag\": \"${TAG}\", + \"compute_image_tag\": \"${TAG}\", + \"concurrency_group\": \"${{ env.E2E_CONCURRENCY_GROUP }}\" + } + }" + \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7e177693fa..2e447fba47 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -54,6 +54,9 @@ _An instruction for maintainers_ - If and only if it looks **safe** (i.e. it doesn't contain any malicious code which could expose secrets or harm the CI), then: - Press the "Approve and run" button in GitHub UI - Add the `approved-for-ci-run` label to the PR + - Currently draft PR will skip e2e test (only for internal contributors). After turning the PR 'Ready to Review' CI will trigger e2e test + - Add `run-e2e-tests-in-draft` label to run e2e test in draft PR (override above behaviour) + - The `approved-for-ci-run` workflow will add `run-e2e-tests-in-draft` automatically to run e2e test for external contributors Repeat all steps after any change to the PR. - When the changes are ready to get merged — merge the original PR (not the internal one) From 7b49e5e5c334bc8d07232f385d08e370ba85fb5a Mon Sep 17 00:00:00 2001 From: Sasha Krassovsky Date: Wed, 7 Feb 2024 07:55:55 -0900 Subject: [PATCH 05/81] Remove compute migrations feature flag (#6653) --- compute_tools/src/compute.rs | 11 +++++------ libs/compute_api/src/spec.rs | 3 --- test_runner/fixtures/neon_fixtures.py | 5 +---- test_runner/regress/test_migrations.py | 2 +- test_runner/regress/test_neon_superuser.py | 4 ++-- 5 files changed, 9 insertions(+), 16 deletions(-) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 098e06cca9..0ca1a47fbf 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -773,12 +773,11 @@ impl ComputeNode { // 'Close' connection drop(client); - if self.has_feature(ComputeFeature::Migrations) { - thread::spawn(move || { - let mut client = Client::connect(connstr.as_str(), NoTls)?; - handle_migrations(&mut client) - }); - } + // Run migrations separately to not hold up cold starts + thread::spawn(move || { + let mut client = Client::connect(connstr.as_str(), NoTls)?; + handle_migrations(&mut client) + }); Ok(()) } diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 5361d14004..13ac18e0c5 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -90,9 +90,6 @@ pub enum ComputeFeature { /// track short-lived connections as user activity. ActivityMonitorExperimental, - /// Enable running migrations - Migrations, - /// This is a special feature flag that is used to represent unknown feature flags. /// Basically all unknown to enum flags are represented as this one. See unit test /// `parse_unknown_features()` for more details. diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index bf7c6ccc14..4491655aeb 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3131,10 +3131,7 @@ class Endpoint(PgProtocol): log.info(json.dumps(dict(data_dict, **kwargs))) json.dump(dict(data_dict, **kwargs), file, indent=4) - # Please note: if you didn't respec this endpoint to have the `migrations` - # feature, this function will probably fail because neon_migration.migration_id - # won't exist. This is temporary - soon we'll get rid of the feature flag and - # migrations will be enabled for everyone. + # Please note: Migrations only run if pg_skip_catalog_updates is false def wait_for_migrations(self): with self.cursor() as cur: diff --git a/test_runner/regress/test_migrations.py b/test_runner/regress/test_migrations.py index 30dd54a8c1..8954810451 100644 --- a/test_runner/regress/test_migrations.py +++ b/test_runner/regress/test_migrations.py @@ -10,7 +10,7 @@ def test_migrations(neon_simple_env: NeonEnv): endpoint = env.endpoints.create("test_migrations") log_path = endpoint.endpoint_path() / "compute.log" - endpoint.respec(skip_pg_catalog_updates=False, features=["migrations"]) + endpoint.respec(skip_pg_catalog_updates=False) endpoint.start() endpoint.wait_for_migrations() diff --git a/test_runner/regress/test_neon_superuser.py b/test_runner/regress/test_neon_superuser.py index eff2cadabf..34f1e64b34 100644 --- a/test_runner/regress/test_neon_superuser.py +++ b/test_runner/regress/test_neon_superuser.py @@ -12,10 +12,10 @@ def test_neon_superuser(neon_simple_env: NeonEnv, pg_version: PgVersion): env.neon_cli.create_branch("test_neon_superuser_subscriber") sub = env.endpoints.create("test_neon_superuser_subscriber") - pub.respec(skip_pg_catalog_updates=False, features=["migrations"]) + pub.respec(skip_pg_catalog_updates=False) pub.start() - sub.respec(skip_pg_catalog_updates=False, features=["migrations"]) + sub.respec(skip_pg_catalog_updates=False) sub.start() pub.wait_for_migrations() From 51f9385b1bd60f3152a580332ba4b19ec131f89a Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 7 Feb 2024 18:47:55 +0100 Subject: [PATCH 06/81] live-reconfigurable virtual_file::IoEngine (#6552) This PR adds an API to live-reconfigure the VirtualFile io engine. It also adds a flag to `pagebench get-page-latest-lsn`, which is where I found this functionality to be useful: it helps compare the io engines in a benchmark without re-compiling a release build, which took ~50s on the i3en.3xlarge where I was doing the benchmark. Switching the IO engine is completely safe at runtime. --- libs/pageserver_api/src/models.rs | 21 +++ pageserver/client/src/mgmt_api.rs | 12 ++ pageserver/ctl/src/layer_map_analyzer.rs | 2 +- pageserver/ctl/src/layers.rs | 4 +- pageserver/ctl/src/main.rs | 2 +- .../pagebench/src/cmd/getpage_latest_lsn.rs | 8 ++ pageserver/src/http/routes.rs | 10 ++ pageserver/src/virtual_file.rs | 5 +- pageserver/src/virtual_file/io_engine.rs | 130 +++++++++++------- pageserver/src/virtual_file/open_options.rs | 7 +- 10 files changed, 144 insertions(+), 57 deletions(-) diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 5a638df9cc..c08cacb822 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -649,6 +649,27 @@ pub struct WalRedoManagerStatus { pub pid: Option, } +pub mod virtual_file { + #[derive( + Copy, + Clone, + PartialEq, + Eq, + Hash, + strum_macros::EnumString, + strum_macros::Display, + serde_with::DeserializeFromStr, + serde_with::SerializeDisplay, + Debug, + )] + #[strum(serialize_all = "kebab-case")] + pub enum IoEngineKind { + StdFs, + #[cfg(target_os = "linux")] + TokioEpollUring, + } +} + // Wrapped in libpq CopyData #[derive(PartialEq, Eq, Debug)] pub enum PagestreamFeMessage { diff --git a/pageserver/client/src/mgmt_api.rs b/pageserver/client/src/mgmt_api.rs index 91b9afa026..8abe58e1a2 100644 --- a/pageserver/client/src/mgmt_api.rs +++ b/pageserver/client/src/mgmt_api.rs @@ -339,4 +339,16 @@ impl Client { .await .map_err(Error::ReceiveBody) } + + pub async fn put_io_engine( + &self, + engine: &pageserver_api::models::virtual_file::IoEngineKind, + ) -> Result<()> { + let uri = format!("{}/v1/io_engine", self.mgmt_api_endpoint); + self.request(Method::PUT, uri, engine) + .await? + .json() + .await + .map_err(Error::ReceiveBody) + } } diff --git a/pageserver/ctl/src/layer_map_analyzer.rs b/pageserver/ctl/src/layer_map_analyzer.rs index eb5c3f15cf..42c4e9ff48 100644 --- a/pageserver/ctl/src/layer_map_analyzer.rs +++ b/pageserver/ctl/src/layer_map_analyzer.rs @@ -142,7 +142,7 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> { let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); // Initialize virtual_file (file desriptor cache) and page cache which are needed to access layer persistent B-Tree. - pageserver::virtual_file::init(10, virtual_file::IoEngineKind::StdFs); + pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); pageserver::page_cache::init(100); let mut total_delta_layers = 0usize; diff --git a/pageserver/ctl/src/layers.rs b/pageserver/ctl/src/layers.rs index dbbcfedac0..27efa6d028 100644 --- a/pageserver/ctl/src/layers.rs +++ b/pageserver/ctl/src/layers.rs @@ -59,7 +59,7 @@ pub(crate) enum LayerCmd { async fn read_delta_file(path: impl AsRef, ctx: &RequestContext) -> Result<()> { let path = Utf8Path::from_path(path.as_ref()).expect("non-Unicode path"); - virtual_file::init(10, virtual_file::IoEngineKind::StdFs); + virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); page_cache::init(100); let file = FileBlockReader::new(VirtualFile::open(path).await?); let summary_blk = file.read_blk(0, ctx).await?; @@ -187,7 +187,7 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> { new_tenant_id, new_timeline_id, } => { - pageserver::virtual_file::init(10, virtual_file::IoEngineKind::StdFs); + pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); pageserver::page_cache::init(100); let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); diff --git a/pageserver/ctl/src/main.rs b/pageserver/ctl/src/main.rs index 3c90933fe9..e73d961e36 100644 --- a/pageserver/ctl/src/main.rs +++ b/pageserver/ctl/src/main.rs @@ -123,7 +123,7 @@ fn read_pg_control_file(control_file_path: &Utf8Path) -> anyhow::Result<()> { async fn print_layerfile(path: &Utf8Path) -> anyhow::Result<()> { // Basic initialization of things that don't change after startup - virtual_file::init(10, virtual_file::IoEngineKind::StdFs); + virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); page_cache::init(100); let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); dump_layerfile_from_path(path, true, &ctx).await diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index aa809d8d26..647f571e59 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -51,6 +51,10 @@ pub(crate) struct Args { /// It doesn't get invalidated if the keyspace changes under the hood, e.g., due to new ingested data or compaction. #[clap(long)] keyspace_cache: Option, + /// Before starting the benchmark, live-reconfigure the pageserver to use the given + /// [`pageserver_api::models::virtual_file::IoEngineKind`]. + #[clap(long)] + set_io_engine: Option, targets: Option>, } @@ -109,6 +113,10 @@ async fn main_impl( args.pageserver_jwt.as_deref(), )); + if let Some(engine_str) = &args.set_io_engine { + mgmt_api_client.put_io_engine(engine_str).await?; + } + // discover targets let timelines: Vec = crate::util::cli::targets::discover( &mgmt_api_client, diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 792089ebe7..ebcb27fa08 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -1908,6 +1908,15 @@ async fn post_tracing_event_handler( json_response(StatusCode::OK, ()) } +async fn put_io_engine_handler( + mut r: Request, + _cancel: CancellationToken, +) -> Result, ApiError> { + let kind: crate::virtual_file::IoEngineKind = json_request(&mut r).await?; + crate::virtual_file::io_engine::set(kind); + json_response(StatusCode::OK, ()) +} + /// Common functionality of all the HTTP API handlers. /// /// - Adds a tracing span to each request (by `request_span`) @@ -2165,5 +2174,6 @@ pub fn make_router( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/keyspace", |r| testing_api_handler("read out the keyspace", r, timeline_collect_keyspace), ) + .put("/v1/io_engine", |r| api_handler(r, put_io_engine_handler)) .any(handler_404)) } diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 066f06c88f..059a6596d3 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -28,9 +28,10 @@ use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use tokio::time::Instant; use utils::fs_ext; -mod io_engine; +pub use pageserver_api::models::virtual_file as api; +pub(crate) mod io_engine; mod open_options; -pub use io_engine::IoEngineKind; +pub(crate) use io_engine::IoEngineKind; pub(crate) use open_options::*; /// diff --git a/pageserver/src/virtual_file/io_engine.rs b/pageserver/src/virtual_file/io_engine.rs index f7b46fe653..892affa326 100644 --- a/pageserver/src/virtual_file/io_engine.rs +++ b/pageserver/src/virtual_file/io_engine.rs @@ -7,67 +7,100 @@ //! //! Then use [`get`] and [`super::OpenOptions`]. -#[derive( - Copy, - Clone, - PartialEq, - Eq, - Hash, - strum_macros::EnumString, - strum_macros::Display, - serde_with::DeserializeFromStr, - serde_with::SerializeDisplay, - Debug, -)] -#[strum(serialize_all = "kebab-case")] -pub enum IoEngineKind { +pub(crate) use super::api::IoEngineKind; +#[derive(Clone, Copy)] +#[repr(u8)] +pub(crate) enum IoEngine { + NotSet, StdFs, #[cfg(target_os = "linux")] TokioEpollUring, } -static IO_ENGINE: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); - -#[cfg(not(test))] -pub(super) fn init(engine: IoEngineKind) { - if IO_ENGINE.set(engine).is_err() { - panic!("called twice"); +impl From for IoEngine { + fn from(value: IoEngineKind) -> Self { + match value { + IoEngineKind::StdFs => IoEngine::StdFs, + #[cfg(target_os = "linux")] + IoEngineKind::TokioEpollUring => IoEngine::TokioEpollUring, + } } - crate::metrics::virtual_file_io_engine::KIND - .with_label_values(&[&format!("{engine}")]) - .set(1); } -pub(super) fn get() -> &'static IoEngineKind { - #[cfg(test)] - { - let env_var_name = "NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE"; - IO_ENGINE.get_or_init(|| match std::env::var(env_var_name) { - Ok(v) => match v.parse::() { - Ok(engine_kind) => engine_kind, - Err(e) => { - panic!("invalid VirtualFile io engine for env var {env_var_name}: {e:#}: {v:?}") - } - }, - Err(std::env::VarError::NotPresent) => { - crate::config::defaults::DEFAULT_VIRTUAL_FILE_IO_ENGINE - .parse() - .unwrap() - } - Err(std::env::VarError::NotUnicode(_)) => { - panic!("env var {env_var_name} is not unicode"); - } +impl TryFrom for IoEngine { + type Error = u8; + + fn try_from(value: u8) -> Result { + Ok(match value { + v if v == (IoEngine::NotSet as u8) => IoEngine::NotSet, + v if v == (IoEngine::StdFs as u8) => IoEngine::StdFs, + #[cfg(target_os = "linux")] + v if v == (IoEngine::TokioEpollUring as u8) => IoEngine::TokioEpollUring, + x => return Err(x), }) } - #[cfg(not(test))] - IO_ENGINE.get().unwrap() } -use std::os::unix::prelude::FileExt; +static IO_ENGINE: AtomicU8 = AtomicU8::new(IoEngine::NotSet as u8); + +pub(crate) fn set(engine_kind: IoEngineKind) { + let engine: IoEngine = engine_kind.into(); + IO_ENGINE.store(engine as u8, std::sync::atomic::Ordering::Relaxed); + #[cfg(not(test))] + { + let metric = &crate::metrics::virtual_file_io_engine::KIND; + metric.reset(); + metric + .with_label_values(&[&format!("{engine_kind}")]) + .set(1); + } +} + +#[cfg(not(test))] +pub(super) fn init(engine_kind: IoEngineKind) { + set(engine_kind); +} + +pub(super) fn get() -> IoEngine { + let cur = IoEngine::try_from(IO_ENGINE.load(Ordering::Relaxed)).unwrap(); + if cfg!(test) { + let env_var_name = "NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE"; + match cur { + IoEngine::NotSet => { + let kind = match std::env::var(env_var_name) { + Ok(v) => match v.parse::() { + Ok(engine_kind) => engine_kind, + Err(e) => { + panic!("invalid VirtualFile io engine for env var {env_var_name}: {e:#}: {v:?}") + } + }, + Err(std::env::VarError::NotPresent) => { + crate::config::defaults::DEFAULT_VIRTUAL_FILE_IO_ENGINE + .parse() + .unwrap() + } + Err(std::env::VarError::NotUnicode(_)) => { + panic!("env var {env_var_name} is not unicode"); + } + }; + self::set(kind); + self::get() + } + x => x, + } + } else { + cur + } +} + +use std::{ + os::unix::prelude::FileExt, + sync::atomic::{AtomicU8, Ordering}, +}; use super::FileGuard; -impl IoEngineKind { +impl IoEngine { pub(super) async fn read_at( &self, file_guard: FileGuard, @@ -78,7 +111,8 @@ impl IoEngineKind { B: tokio_epoll_uring::BoundedBufMut + Send, { match self { - IoEngineKind::StdFs => { + IoEngine::NotSet => panic!("not initialized"), + IoEngine::StdFs => { // SAFETY: `dst` only lives at most as long as this match arm, during which buf remains valid memory. let dst = unsafe { std::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total()) @@ -96,7 +130,7 @@ impl IoEngineKind { ((file_guard, buf), res) } #[cfg(target_os = "linux")] - IoEngineKind::TokioEpollUring => { + IoEngine::TokioEpollUring => { let system = tokio_epoll_uring::thread_local_system().await; let (resources, res) = system.read(file_guard, offset, buf).await; ( diff --git a/pageserver/src/virtual_file/open_options.rs b/pageserver/src/virtual_file/open_options.rs index 1e5ffe15cc..f75edb0bac 100644 --- a/pageserver/src/virtual_file/open_options.rs +++ b/pageserver/src/virtual_file/open_options.rs @@ -1,6 +1,6 @@ //! Enum-dispatch to the `OpenOptions` type of the respective [`super::IoEngineKind`]; -use super::IoEngineKind; +use super::io_engine::IoEngine; use std::{os::fd::OwnedFd, path::Path}; #[derive(Debug, Clone)] @@ -13,9 +13,10 @@ pub enum OpenOptions { impl Default for OpenOptions { fn default() -> Self { match super::io_engine::get() { - IoEngineKind::StdFs => Self::StdFs(std::fs::OpenOptions::new()), + IoEngine::NotSet => panic!("io engine not set"), + IoEngine::StdFs => Self::StdFs(std::fs::OpenOptions::new()), #[cfg(target_os = "linux")] - IoEngineKind::TokioEpollUring => { + IoEngine::TokioEpollUring => { Self::TokioEpollUring(tokio_epoll_uring::ops::open_at::OpenOptions::new()) } } From 2e9b1f7aaf61d5886f312628d4fb54a1526317f2 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Tue, 6 Feb 2024 14:34:20 -0600 Subject: [PATCH 07/81] Update Postgres 14 to 14.11 --- vendor/postgres-v14 | 2 +- vendor/revisions.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index be7a65fe67..018fb05201 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit be7a65fe67dc81d85bbcbebb13e00d94715f4b88 +Subproject commit 018fb052011081dc2733d3118d12e5c36df6eba1 diff --git a/vendor/revisions.json b/vendor/revisions.json index 80699839ba..c2f9244116 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,5 +1,5 @@ { "postgres-v16": "f7ea954989a2e7901f858779cff55259f203479a", "postgres-v15": "81e16cd537053f49e175d4a08ab7c8aec3d9b535", - "postgres-v14": "be7a65fe67dc81d85bbcbebb13e00d94715f4b88" + "postgres-v14": "018fb052011081dc2733d3118d12e5c36df6eba1" } From 5541244dc4736208e802dd60d6f9861392d9b743 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Tue, 6 Feb 2024 14:35:37 -0600 Subject: [PATCH 08/81] Update Postgres 15 to 15.6 --- vendor/postgres-v15 | 2 +- vendor/revisions.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index 81e16cd537..6ee78a3c29 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit 81e16cd537053f49e175d4a08ab7c8aec3d9b535 +Subproject commit 6ee78a3c29e33cafd85ba09568b6b5eb031d29b9 diff --git a/vendor/revisions.json b/vendor/revisions.json index c2f9244116..c7076231e5 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,5 +1,5 @@ { "postgres-v16": "f7ea954989a2e7901f858779cff55259f203479a", - "postgres-v15": "81e16cd537053f49e175d4a08ab7c8aec3d9b535", + "postgres-v15": "6ee78a3c29e33cafd85ba09568b6b5eb031d29b9", "postgres-v14": "018fb052011081dc2733d3118d12e5c36df6eba1" } From 128fae70548f06ebc8ac44c38576c993ae6cba52 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Tue, 6 Feb 2024 14:37:21 -0600 Subject: [PATCH 09/81] Update Postgres 16 to 16.2 --- libs/walproposer/src/walproposer.rs | 7 +++++-- vendor/postgres-v16 | 2 +- vendor/revisions.json | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index 7251545792..8ab8fb1a07 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -453,9 +453,12 @@ mod tests { event_mask: 0, }), expected_messages: vec![ - // Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160001, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 }) + // TODO: When updating Postgres versions, this test will cause + // problems. Postgres version in message needs updating. + // + // Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160002, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 }) vec![ - 103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 1, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 76, 143, 54, 6, 60, 108, 110, 147, 188, 32, 214, 90, 130, 15, 61, 158, 76, 143, 54, 6, 60, 108, 110, 147, 188, 32, 214, 90, 130, 15, 61, 1, 0, 0, 0, 0, 0, 0, 1, diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index f7ea954989..550cdd26d4 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit f7ea954989a2e7901f858779cff55259f203479a +Subproject commit 550cdd26d445afdd26b15aa93c8c2f3dc52f8361 diff --git a/vendor/revisions.json b/vendor/revisions.json index c7076231e5..91ebb8cb34 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,5 +1,5 @@ { - "postgres-v16": "f7ea954989a2e7901f858779cff55259f203479a", + "postgres-v16": "550cdd26d445afdd26b15aa93c8c2f3dc52f8361", "postgres-v15": "6ee78a3c29e33cafd85ba09568b6b5eb031d29b9", "postgres-v14": "018fb052011081dc2733d3118d12e5c36df6eba1" } From 3bd2a4fd56803b0aabb87e9076872ceff0147a77 Mon Sep 17 00:00:00 2001 From: John Spray Date: Wed, 7 Feb 2024 19:14:18 +0000 Subject: [PATCH 10/81] control_plane: avoid feedback loop with /location_config if compute hook fails. (#6668) ## Problem The existing behavior isn't exactly incorrect, but is operationally risky: if the control plane compute hook breaks, then all the control plane operations trying to call /location_config will end up retrying forever, which could put more load on the system. ## Summary of changes - Treat 404s as fatal errors to do fewer retries: a 404 either indicates we have the wrong URL, or some control plane bug is failing to recognize our tenant ID as existing. - Do not return an error on reconcilation errors in a non-creating /location_config response: this allows the control plane to finish its Operation (and we will eventually retry the compute notification later) --- control_plane/attachment_service/src/compute_hook.rs | 2 +- control_plane/attachment_service/src/service.rs | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/control_plane/attachment_service/src/compute_hook.rs b/control_plane/attachment_service/src/compute_hook.rs index 0d3610aafa..5bd1b6bf09 100644 --- a/control_plane/attachment_service/src/compute_hook.rs +++ b/control_plane/attachment_service/src/compute_hook.rs @@ -240,7 +240,7 @@ impl ComputeHook { let client = reqwest::Client::new(); backoff::retry( || self.do_notify_iteration(&client, url, &reconfigure_request, cancel), - |e| matches!(e, NotifyError::Fatal(_)), + |e| matches!(e, NotifyError::Fatal(_) | NotifyError::Unexpected(_)), 3, 10, "Send compute notification", diff --git a/control_plane/attachment_service/src/service.rs b/control_plane/attachment_service/src/service.rs index febee1aa0d..1db1906df8 100644 --- a/control_plane/attachment_service/src/service.rs +++ b/control_plane/attachment_service/src/service.rs @@ -989,7 +989,15 @@ impl Service { .collect(); } else { // This was an update, wait for reconciliation - self.await_waiters(waiters).await?; + if let Err(e) = self.await_waiters(waiters).await { + // Do not treat a reconcile error as fatal: we have already applied any requested + // Intent changes, and the reconcile can fail for external reasons like unavailable + // compute notification API. In these cases, it is important that we do not + // cause the cloud control plane to retry forever on this API. + tracing::warn!( + "Failed to reconcile after /location_config: {e}, returning success anyway" + ); + } } Ok(result) From c561ad4e2e900409141e8c6c9963bab90288fd12 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 7 Feb 2024 20:39:52 +0100 Subject: [PATCH 11/81] feat: expose locked memory in pageserver `/metrics` (#6669) context: https://github.com/neondatabase/neon/issues/6667 --- Cargo.lock | 3 ++ Cargo.toml | 1 + libs/metrics/Cargo.toml | 3 ++ libs/metrics/src/lib.rs | 2 + libs/metrics/src/more_process_metrics.rs | 54 ++++++++++++++++++++++++ pageserver/src/bin/pageserver.rs | 2 + 6 files changed, 65 insertions(+) create mode 100644 libs/metrics/src/more_process_metrics.rs diff --git a/Cargo.lock b/Cargo.lock index a25725f90d..bf1ecfa89d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2869,6 +2869,7 @@ dependencies = [ "chrono", "libc", "once_cell", + "procfs", "prometheus", "rand 0.8.5", "rand_distr", @@ -3986,6 +3987,8 @@ checksum = "b1de8dacb0873f77e6aefc6d71e044761fcc68060290f5b1089fcdf84626bb69" dependencies = [ "bitflags 1.3.2", "byteorder", + "chrono", + "flate2", "hex", "lazy_static", "rustix 0.36.16", diff --git a/Cargo.toml b/Cargo.toml index 271edee742..6a2c3fa563 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,6 +113,7 @@ parquet = { version = "49.0.0", default-features = false, features = ["zstd"] } parquet_derive = "49.0.0" pbkdf2 = { version = "0.12.1", features = ["simple", "std"] } pin-project-lite = "0.2" +procfs = "0.14" prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency prost = "0.11" rand = "0.8" diff --git a/libs/metrics/Cargo.toml b/libs/metrics/Cargo.toml index a547d492df..f6a49a0166 100644 --- a/libs/metrics/Cargo.toml +++ b/libs/metrics/Cargo.toml @@ -13,6 +13,9 @@ twox-hash.workspace = true workspace_hack.workspace = true +[target.'cfg(target_os = "linux")'.dependencies] +procfs.workspace = true + [dev-dependencies] rand = "0.8" rand_distr = "0.4.3" diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index cb9914e5de..b57fd9f33b 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -31,6 +31,8 @@ pub use wrappers::{CountedReader, CountedWriter}; mod hll; pub mod metric_vec_duration; pub use hll::{HyperLogLog, HyperLogLogVec}; +#[cfg(target_os = "linux")] +pub mod more_process_metrics; pub type UIntGauge = GenericGauge; pub type UIntGaugeVec = GenericGaugeVec; diff --git a/libs/metrics/src/more_process_metrics.rs b/libs/metrics/src/more_process_metrics.rs new file mode 100644 index 0000000000..920724fdec --- /dev/null +++ b/libs/metrics/src/more_process_metrics.rs @@ -0,0 +1,54 @@ +//! process metrics that the [`::prometheus`] crate doesn't provide. + +// This module has heavy inspiration from the prometheus crate's `process_collector.rs`. + +use crate::UIntGauge; + +pub struct Collector { + descs: Vec, + vmlck: crate::UIntGauge, +} + +const NMETRICS: usize = 1; + +impl prometheus::core::Collector for Collector { + fn desc(&self) -> Vec<&prometheus::core::Desc> { + self.descs.iter().collect() + } + + fn collect(&self) -> Vec { + let Ok(myself) = procfs::process::Process::myself() else { + return vec![]; + }; + let mut mfs = Vec::with_capacity(NMETRICS); + if let Ok(status) = myself.status() { + if let Some(vmlck) = status.vmlck { + self.vmlck.set(vmlck); + mfs.extend(self.vmlck.collect()) + } + } + mfs + } +} + +impl Collector { + pub fn new() -> Self { + let mut descs = Vec::new(); + + let vmlck = + UIntGauge::new("libmetrics_process_status_vmlck", "/proc/self/status vmlck").unwrap(); + descs.extend( + prometheus::core::Collector::desc(&vmlck) + .into_iter() + .cloned(), + ); + + Self { descs, vmlck } + } +} + +impl Default for Collector { + fn default() -> Self { + Self::new() + } +} diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index eaddcb4607..7a93830c14 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -272,6 +272,8 @@ fn start_pageserver( ); set_build_info_metric(GIT_VERSION, BUILD_TAG); set_launch_timestamp_metric(launch_ts); + #[cfg(target_os = "linux")] + metrics::register_internal(Box::new(metrics::more_process_metrics::Collector::new())).unwrap(); pageserver::preinitialize_metrics(); // If any failpoints were set from FAILPOINTS environment variable, From 9a017778a9f89d5adfb6869a883ee2532dcaf13a Mon Sep 17 00:00:00 2001 From: Andreas Scherbaum Date: Thu, 8 Feb 2024 00:48:31 +0100 Subject: [PATCH 12/81] Update copyright notice, set it to current year (#6671) ## Problem Copyright notice is outdated ## Summary of changes Replace the initial year `2022` with `2022 - 2024`, after brief discussion with Stas about the format Co-authored-by: Andreas Scherbaum --- NOTICE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NOTICE b/NOTICE index c13dc2f0b3..52fc751c41 100644 --- a/NOTICE +++ b/NOTICE @@ -1,5 +1,5 @@ Neon -Copyright 2022 Neon Inc. +Copyright 2022 - 2024 Neon Inc. The PostgreSQL submodules in vendor/ are licensed under the PostgreSQL license. See vendor/postgres-vX/COPYRIGHT for details. From c52495774d5151db63059515a524621660236f75 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Thu, 8 Feb 2024 00:58:54 +0100 Subject: [PATCH 13/81] tokio-epoll-uring: expose its metrics in pageserver's `/metrics` (#6672) context: https://github.com/neondatabase/neon/issues/6667 --- Cargo.lock | 4 +- pageserver/src/bin/pageserver.rs | 4 ++ pageserver/src/metrics.rs | 66 ++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf1ecfa89d..30e233ecc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5739,7 +5739,7 @@ dependencies = [ [[package]] name = "tokio-epoll-uring" version = "0.1.0" -source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#0e1af4ccddf2f01805cfc9eaefa97ee13c04b52d" +source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc" dependencies = [ "futures", "nix 0.26.4", @@ -6264,7 +6264,7 @@ dependencies = [ [[package]] name = "uring-common" version = "0.1.0" -source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#0e1af4ccddf2f01805cfc9eaefa97ee13c04b52d" +source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc" dependencies = [ "io-uring", "libc", diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 7a93830c14..2f172bd384 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -274,6 +274,10 @@ fn start_pageserver( set_launch_timestamp_metric(launch_ts); #[cfg(target_os = "linux")] metrics::register_internal(Box::new(metrics::more_process_metrics::Collector::new())).unwrap(); + metrics::register_internal(Box::new( + pageserver::metrics::tokio_epoll_uring::Collector::new(), + )) + .unwrap(); pageserver::preinitialize_metrics(); // If any failpoints were set from FAILPOINTS environment variable, diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 489ec58e62..98c98ef6e7 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -2400,6 +2400,72 @@ impl>, O, E> Future for MeasuredRemoteOp { } } +pub mod tokio_epoll_uring { + use metrics::UIntGauge; + + pub struct Collector { + descs: Vec, + systems_created: UIntGauge, + systems_destroyed: UIntGauge, + } + + const NMETRICS: usize = 2; + + impl metrics::core::Collector for Collector { + fn desc(&self) -> Vec<&metrics::core::Desc> { + self.descs.iter().collect() + } + + fn collect(&self) -> Vec { + let mut mfs = Vec::with_capacity(NMETRICS); + let tokio_epoll_uring::metrics::Metrics { + systems_created, + systems_destroyed, + } = tokio_epoll_uring::metrics::global(); + self.systems_created.set(systems_created); + mfs.extend(self.systems_created.collect()); + self.systems_destroyed.set(systems_destroyed); + mfs.extend(self.systems_destroyed.collect()); + mfs + } + } + + impl Collector { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let mut descs = Vec::new(); + + let systems_created = UIntGauge::new( + "pageserver_tokio_epoll_uring_systems_created", + "counter of tokio-epoll-uring systems that were created", + ) + .unwrap(); + descs.extend( + metrics::core::Collector::desc(&systems_created) + .into_iter() + .cloned(), + ); + + let systems_destroyed = UIntGauge::new( + "pageserver_tokio_epoll_uring_systems_destroyed", + "counter of tokio-epoll-uring systems that were destroyed", + ) + .unwrap(); + descs.extend( + metrics::core::Collector::desc(&systems_destroyed) + .into_iter() + .cloned(), + ); + + Self { + descs, + systems_created, + systems_destroyed, + } + } + } +} + pub fn preinitialize_metrics() { // Python tests need these and on some we do alerting. // From c63e3e7e84c2dd9c9792619cc4fee15b07cfe7d7 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:57:05 +0100 Subject: [PATCH 14/81] Proxy: improve http-pool (#6577) ## Problem The password check logic for the sql-over-http is a bit non-intuitive. ## Summary of changes 1. Perform scram auth using the same logic as for websocket cleartext password. 2. Split establish connection logic and connection pool. 3. Parallelize param parsing logic with authentication + wake compute. 4. Limit the total number of clients --- Cargo.lock | 1 + proxy/Cargo.toml | 1 + proxy/src/auth/backend.rs | 12 + proxy/src/auth/flow.rs | 2 +- proxy/src/bin/proxy.rs | 5 + proxy/src/console/provider/neon.rs | 2 + proxy/src/context.rs | 4 + proxy/src/metrics.rs | 44 +- proxy/src/proxy/connect_compute.rs | 22 +- proxy/src/proxy/tests.rs | 3 + proxy/src/serverless.rs | 41 +- proxy/src/serverless/backend.rs | 157 +++++ proxy/src/serverless/conn_pool.rs | 797 +++++++++++++------------- proxy/src/serverless/json.rs | 28 +- proxy/src/serverless/sql_over_http.rs | 92 ++- test_runner/regress/test_proxy.py | 20 +- 16 files changed, 753 insertions(+), 478 deletions(-) create mode 100644 proxy/src/serverless/backend.rs diff --git a/Cargo.lock b/Cargo.lock index 30e233ecc1..c0c319cd89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4079,6 +4079,7 @@ dependencies = [ "clap", "consumption_metrics", "dashmap", + "env_logger", "futures", "git-version", "hashbrown 0.13.2", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 1247f08ee6..83cab381b3 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -19,6 +19,7 @@ chrono.workspace = true clap.workspace = true consumption_metrics.workspace = true dashmap.workspace = true +env_logger.workspace = true futures.workspace = true git-version.workspace = true hashbrown.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 236567163e..fa2782bee3 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -68,6 +68,7 @@ pub trait TestBackend: Send + Sync + 'static { fn get_allowed_ips_and_secret( &self, ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError>; + fn get_role_secret(&self) -> Result; } impl std::fmt::Display for BackendType<'_, ()> { @@ -358,6 +359,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { } impl BackendType<'_, ComputeUserInfo> { + pub async fn get_role_secret( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + match self { + Console(api, user_info) => api.get_role_secret(ctx, user_info).await, + Link(_) => Ok(Cached::new_uncached(None)), + } + } + pub async fn get_allowed_ips_and_secret( &self, ctx: &mut RequestMonitoring, diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 077178d107..c2783e236c 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -167,7 +167,7 @@ impl AuthFlow<'_, S, Scram<'_>> { } } -pub(super) fn validate_password_and_exchange( +pub(crate) fn validate_password_and_exchange( password: &[u8], secret: AuthSecret, ) -> super::Result> { diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 3bbb87808d..6974f1a274 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -165,6 +165,10 @@ struct SqlOverHttpArgs { #[clap(long, default_value_t = 20)] sql_over_http_pool_max_conns_per_endpoint: usize, + /// How many connections to pool for each endpoint. Excess connections are discarded + #[clap(long, default_value_t = 20000)] + sql_over_http_pool_max_total_conns: usize, + /// How long pooled connections should remain idle for before closing #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)] sql_over_http_idle_timeout: tokio::time::Duration, @@ -387,6 +391,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { pool_shards: args.sql_over_http.sql_over_http_pool_shards, idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, opt_in: args.sql_over_http.sql_over_http_pool_opt_in, + max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, }, }; let authentication_config = AuthenticationConfig { diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 0785419790..71b34cb676 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -188,6 +188,7 @@ impl super::Api for Api { ep, Arc::new(auth_info.allowed_ips), ); + ctx.set_project_id(project_id); } // When we just got a secret, we don't need to invalidate it. Ok(Cached::new_uncached(auth_info.secret)) @@ -221,6 +222,7 @@ impl super::Api for Api { self.caches .project_info .insert_allowed_ips(&project_id, ep, allowed_ips.clone()); + ctx.set_project_id(project_id); } Ok(( Cached::new_uncached(allowed_ips), diff --git a/proxy/src/context.rs b/proxy/src/context.rs index e2b0294cd3..fe204534b7 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -89,6 +89,10 @@ impl RequestMonitoring { self.project = Some(x.project_id); } + pub fn set_project_id(&mut self, project_id: ProjectId) { + self.project = Some(project_id); + } + pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) { crate::metrics::CONNECTING_ENDPOINTS .with_label_values(&[self.protocol]) diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index fa663d8ff6..e2d96a9c27 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -1,8 +1,10 @@ use ::metrics::{ exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec, - register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge_vec, Histogram, - HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGaugeVec, + register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, + register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec, + IntCounterVec, IntGauge, IntGaugeVec, }; +use metrics::{register_int_counter_pair, IntCounterPair}; use once_cell::sync::Lazy; use tokio::time; @@ -112,6 +114,44 @@ pub static ALLOWED_IPS_NUMBER: Lazy = Lazy::new(|| { .unwrap() }); +pub static HTTP_CONTENT_LENGTH: Lazy = Lazy::new(|| { + register_histogram!( + "proxy_http_conn_content_length_bytes", + "Time it took for proxy to establish a connection to the compute endpoint", + // largest bucket = 3^16 * 0.05ms = 2.15s + exponential_buckets(8.0, 2.0, 20).unwrap() + ) + .unwrap() +}); + +pub static GC_LATENCY: Lazy = Lazy::new(|| { + register_histogram!( + "proxy_http_pool_reclaimation_lag_seconds", + "Time it takes to reclaim unused connection pools", + // 1us -> 65ms + exponential_buckets(1e-6, 2.0, 16).unwrap(), + ) + .unwrap() +}); + +pub static ENDPOINT_POOLS: Lazy = Lazy::new(|| { + register_int_counter_pair!( + "proxy_http_pool_endpoints_registered_total", + "Number of endpoints we have registered pools for", + "proxy_http_pool_endpoints_unregistered_total", + "Number of endpoints we have unregistered pools for", + ) + .unwrap() +}); + +pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy = Lazy::new(|| { + register_int_gauge!( + "proxy_http_pool_opened_connections", + "Number of opened connections to a database.", + ) + .unwrap() +}); + #[derive(Clone)] pub struct LatencyTimer { // time since the stopwatch was started diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 58c59dba36..b9346aa743 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -34,21 +34,6 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg node_info.invalidate().config } -/// Try to connect to the compute node once. -#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)] -async fn connect_to_compute_once( - ctx: &mut RequestMonitoring, - node_info: &console::CachedNodeInfo, - timeout: time::Duration, -) -> Result { - let allow_self_signed_compute = node_info.allow_self_signed_compute; - - node_info - .config - .connect(ctx, allow_self_signed_compute, timeout) - .await -} - #[async_trait] pub trait ConnectMechanism { type Connection; @@ -75,13 +60,18 @@ impl ConnectMechanism for TcpMechanism<'_> { type ConnectError = compute::ConnectionError; type Error = compute::ConnectionError; + #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] async fn connect_once( &self, ctx: &mut RequestMonitoring, node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result { - connect_to_compute_once(ctx, node_info, timeout).await + let allow_self_signed_compute = node_info.allow_self_signed_compute; + node_info + .config + .connect(ctx, allow_self_signed_compute, timeout) + .await } fn update_connect_config(&self, config: &mut compute::ConnCfg) { diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 2000774224..656cabac75 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -478,6 +478,9 @@ impl TestBackend for TestConnectMechanism { { unimplemented!("not used in tests") } + fn get_role_secret(&self) -> Result { + unimplemented!("not used in tests") + } } fn helper_create_cached_node_info() -> CachedNodeInfo { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 7ff93b23b8..58aa925a6a 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -2,6 +2,7 @@ //! //! Handles both SQL over HTTP and SQL over Websockets. +mod backend; mod conn_pool; mod json; mod sql_over_http; @@ -18,11 +19,11 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; -use crate::config::TlsConfig; use crate::context::RequestMonitoring; use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::rate_limiter::EndpointRateLimiter; +use crate::serverless::backend::PoolingBackend; use crate::{cancellation::CancelMap, config::ProxyConfig}; use futures::StreamExt; use hyper::{ @@ -54,12 +55,13 @@ pub async fn task_main( info!("websocket server has shut down"); } - let conn_pool = conn_pool::GlobalConnPool::new(config); - - let conn_pool2 = Arc::clone(&conn_pool); - tokio::spawn(async move { - conn_pool2.gc_worker(StdRng::from_entropy()).await; - }); + let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config); + { + let conn_pool = Arc::clone(&conn_pool); + tokio::spawn(async move { + conn_pool.gc_worker(StdRng::from_entropy()).await; + }); + } // shutdown the connection pool tokio::spawn({ @@ -73,6 +75,11 @@ pub async fn task_main( } }); + let backend = Arc::new(PoolingBackend { + pool: Arc::clone(&conn_pool), + config, + }); + let tls_config = match config.tls_config.as_ref() { Some(config) => config, None => { @@ -106,7 +113,7 @@ pub async fn task_main( let client_addr = io.client_addr(); let remote_addr = io.inner.remote_addr(); let sni_name = tls.server_name().map(|s| s.to_string()); - let conn_pool = conn_pool.clone(); + let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); @@ -119,7 +126,7 @@ pub async fn task_main( Ok(MetricService::new(hyper::service::service_fn( move |req: Request| { let sni_name = sni_name.clone(); - let conn_pool = conn_pool.clone(); + let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); @@ -130,8 +137,7 @@ pub async fn task_main( request_handler( req, config, - tls_config, - conn_pool, + backend, ws_connections, cancel_map, session_id, @@ -200,8 +206,7 @@ where async fn request_handler( mut request: Request, config: &'static ProxyConfig, - tls: &'static TlsConfig, - conn_pool: Arc, + backend: Arc, ws_connections: TaskTracker, cancel_map: Arc, session_id: uuid::Uuid, @@ -248,15 +253,7 @@ async fn request_handler( } else if request.uri().path() == "/sql" && request.method() == Method::POST { let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); - sql_over_http::handle( - tls, - &config.http_config, - &mut ctx, - request, - sni_hostname, - conn_pool, - ) - .await + sql_over_http::handle(config, &mut ctx, request, sni_hostname, backend).await } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { Response::builder() .header("Allow", "OPTIONS, POST") diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs new file mode 100644 index 0000000000..466a74f0ea --- /dev/null +++ b/proxy/src/serverless/backend.rs @@ -0,0 +1,157 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Context; +use async_trait::async_trait; +use tracing::info; + +use crate::{ + auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, + compute, + config::ProxyConfig, + console::CachedNodeInfo, + context::RequestMonitoring, + proxy::connect_compute::ConnectMechanism, +}; + +use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME}; + +pub struct PoolingBackend { + pub pool: Arc>, + pub config: &'static ProxyConfig, +} + +impl PoolingBackend { + pub async fn authenticate( + &self, + ctx: &mut RequestMonitoring, + conn_info: &ConnInfo, + ) -> Result { + let user_info = conn_info.user_info.clone(); + let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); + let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; + if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { + return Err(AuthError::ip_address_not_allowed()); + } + let cached_secret = match maybe_secret { + Some(secret) => secret, + None => backend.get_role_secret(ctx).await?, + }; + + let secret = match cached_secret.value.clone() { + Some(secret) => secret, + None => { + // If we don't have an authentication secret, for the http flow we can just return an error. + info!("authentication info not found"); + return Err(AuthError::auth_failed(&*user_info.user)); + } + }; + let auth_outcome = + crate::auth::validate_password_and_exchange(conn_info.password.as_bytes(), secret)?; + match auth_outcome { + crate::sasl::Outcome::Success(key) => Ok(key), + crate::sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + Err(AuthError::auth_failed(&*conn_info.user_info.user)) + } + } + } + + // Wake up the destination if needed. Code here is a bit involved because + // we reuse the code from the usual proxy and we need to prepare few structures + // that this code expects. + #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] + pub async fn connect_to_compute( + &self, + ctx: &mut RequestMonitoring, + conn_info: ConnInfo, + keys: ComputeCredentialKeys, + force_new: bool, + ) -> anyhow::Result> { + let maybe_client = if !force_new { + info!("pool: looking for an existing connection"); + self.pool.get(ctx, &conn_info).await? + } else { + info!("pool: pool is disabled"); + None + }; + + if let Some(client) = maybe_client { + return Ok(client); + } + let conn_id = uuid::Uuid::new_v4(); + info!(%conn_id, "pool: opening a new connection '{conn_info}'"); + ctx.set_application(Some(APP_NAME)); + let backend = self + .config + .auth_backend + .as_ref() + .map(|_| conn_info.user_info.clone()); + + let mut node_info = backend + .wake_compute(ctx) + .await? + .context("missing cache entry from wake_compute")?; + + match keys { + #[cfg(any(test, feature = "testing"))] + ComputeCredentialKeys::Password(password) => node_info.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys), + }; + + ctx.set_project(node_info.aux.clone()); + + crate::proxy::connect_compute::connect_to_compute( + ctx, + &TokioMechanism { + conn_id, + conn_info, + pool: self.pool.clone(), + }, + node_info, + &backend, + ) + .await + } +} + +struct TokioMechanism { + pool: Arc>, + conn_info: ConnInfo, + conn_id: uuid::Uuid, +} + +#[async_trait] +impl ConnectMechanism for TokioMechanism { + type Connection = Client; + type ConnectError = tokio_postgres::Error; + type Error = anyhow::Error; + + async fn connect_once( + &self, + ctx: &mut RequestMonitoring, + node_info: &CachedNodeInfo, + timeout: Duration, + ) -> Result { + let mut config = (*node_info.config).clone(); + let config = config + .user(&self.conn_info.user_info.user) + .password(&*self.conn_info.password) + .dbname(&self.conn_info.dbname) + .connect_timeout(timeout); + + let (client, connection) = config.connect(tokio_postgres::NoTls).await?; + + tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id())); + Ok(poll_client( + self.pool.clone(), + ctx, + self.conn_info.clone(), + client, + connection, + self.conn_id, + node_info.aux.clone(), + )) + } + + fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} +} diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 312fa2b36f..a7b2c532d2 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -1,15 +1,7 @@ -use anyhow::Context; -use async_trait::async_trait; use dashmap::DashMap; use futures::{future::poll_fn, Future}; -use metrics::{register_int_counter_pair, IntCounterPair, IntCounterPairGuard}; -use once_cell::sync::Lazy; +use metrics::IntCounterPairGuard; use parking_lot::RwLock; -use pbkdf2::{ - password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString}, - Params, Pbkdf2, -}; -use prometheus::{exponential_buckets, register_histogram, Histogram}; use rand::Rng; use smol_str::SmolStr; use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; @@ -21,19 +13,17 @@ use std::{ ops::Deref, sync::atomic::{self, AtomicUsize}, }; -use tokio::time::{self, Instant}; -use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; +use tokio::time::Instant; +use tokio_postgres::tls::NoTlsStream; +use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; +use crate::console::messages::MetricsAuxInfo; +use crate::metrics::{ENDPOINT_POOLS, GC_LATENCY, NUM_OPEN_CLIENTS_IN_HTTP_POOL}; +use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{ - auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list}, - console::{self, messages::MetricsAuxInfo}, - context::RequestMonitoring, - metrics::NUM_DB_CONNECTIONS_GAUGE, - proxy::connect_compute::ConnectMechanism, - usage_metrics::{Ids, MetricCounter, USAGE_METRICS}, + auth::backend::ComputeUserInfo, context::RequestMonitoring, metrics::NUM_DB_CONNECTIONS_GAUGE, DbName, EndpointCacheKey, RoleName, }; -use crate::{compute, config}; use tracing::{debug, error, warn, Span}; use tracing::{info, info_span, Instrument}; @@ -72,39 +62,51 @@ impl fmt::Display for ConnInfo { } } -struct ConnPoolEntry { - conn: ClientInner, +struct ConnPoolEntry { + conn: ClientInner, _last_access: std::time::Instant, } // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. -pub struct EndpointConnPool { - pools: HashMap<(DbName, RoleName), DbUserConnPool>, +pub struct EndpointConnPool { + pools: HashMap<(DbName, RoleName), DbUserConnPool>, total_conns: usize, max_conns: usize, _guard: IntCounterPairGuard, + global_connections_count: Arc, + global_pool_size_max_conns: usize, } -impl EndpointConnPool { - fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option { +impl EndpointConnPool { + fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option> { let Self { - pools, total_conns, .. + pools, + total_conns, + global_connections_count, + .. } = self; - pools - .get_mut(&db_user) - .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns)) + pools.get_mut(&db_user).and_then(|pool_entries| { + pool_entries.get_conn_entry(total_conns, global_connections_count.clone()) + }) } fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool { let Self { - pools, total_conns, .. + pools, + total_conns, + global_connections_count, + .. } = self; if let Some(pool) = pools.get_mut(&db_user) { let old_len = pool.conns.len(); pool.conns.retain(|conn| conn.conn.conn_id != conn_id); let new_len = pool.conns.len(); let removed = old_len - new_len; + if removed > 0 { + global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); + NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64); + } *total_conns -= removed; removed > 0 } else { @@ -112,13 +114,27 @@ impl EndpointConnPool { } } - fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> { + fn put( + pool: &RwLock, + conn_info: &ConnInfo, + client: ClientInner, + ) -> anyhow::Result<()> { let conn_id = client.conn_id; - if client.inner.is_closed() { + if client.is_closed() { info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed"); return Ok(()); } + let global_max_conn = pool.read().global_pool_size_max_conns; + if pool + .read() + .global_connections_count + .load(atomic::Ordering::Relaxed) + >= global_max_conn + { + info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full"); + return Ok(()); + } // return connection to the pool let mut returned = false; @@ -127,18 +143,19 @@ impl EndpointConnPool { let mut pool = pool.write(); if pool.total_conns < pool.max_conns { - // we create this db-user entry in get, so it should not be None - if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) { - pool_entries.conns.push(ConnPoolEntry { - conn: client, - _last_access: std::time::Instant::now(), - }); + let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default(); + pool_entries.conns.push(ConnPoolEntry { + conn: client, + _last_access: std::time::Instant::now(), + }); - returned = true; - per_db_size = pool_entries.conns.len(); + returned = true; + per_db_size = pool_entries.conns.len(); - pool.total_conns += 1; - } + pool.total_conns += 1; + pool.global_connections_count + .fetch_add(1, atomic::Ordering::Relaxed); + NUM_OPEN_CLIENTS_IN_HTTP_POOL.inc(); } pool.total_conns @@ -155,49 +172,61 @@ impl EndpointConnPool { } } -/// 4096 is the number of rounds that SCRAM-SHA-256 recommends. -/// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway. -/// -/// Still takes 1.4ms to hash on my hardware. -/// We don't want to ruin the latency improvements of using the pool by making password verification take too long -const PARAMS: Params = Params { - rounds: 4096, - output_length: 32, -}; - -#[derive(Default)] -pub struct DbUserConnPool { - conns: Vec, - password_hash: Option, +impl Drop for EndpointConnPool { + fn drop(&mut self) { + if self.total_conns > 0 { + self.global_connections_count + .fetch_sub(self.total_conns, atomic::Ordering::Relaxed); + NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(self.total_conns as i64); + } + } } -impl DbUserConnPool { - fn clear_closed_clients(&mut self, conns: &mut usize) { +pub struct DbUserConnPool { + conns: Vec>, +} + +impl Default for DbUserConnPool { + fn default() -> Self { + Self { conns: Vec::new() } + } +} + +impl DbUserConnPool { + fn clear_closed_clients(&mut self, conns: &mut usize) -> usize { let old_len = self.conns.len(); - self.conns.retain(|conn| !conn.conn.inner.is_closed()); + self.conns.retain(|conn| !conn.conn.is_closed()); let new_len = self.conns.len(); let removed = old_len - new_len; *conns -= removed; + removed } - fn get_conn_entry(&mut self, conns: &mut usize) -> Option { - self.clear_closed_clients(conns); + fn get_conn_entry( + &mut self, + conns: &mut usize, + global_connections_count: Arc, + ) -> Option> { + let mut removed = self.clear_closed_clients(conns); let conn = self.conns.pop(); if conn.is_some() { *conns -= 1; + removed += 1; } + global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); + NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64); conn } } -pub struct GlobalConnPool { +pub struct GlobalConnPool { // endpoint -> per-endpoint connection pool // // That should be a fairly conteded map, so return reference to the per-endpoint // pool as early as possible and release the lock. - global_pool: DashMap>>, + global_pool: DashMap>>>, /// Number of endpoint-connection pools /// @@ -206,7 +235,10 @@ pub struct GlobalConnPool { /// It's only used for diagnostics. global_pool_size: AtomicUsize, - proxy_config: &'static crate::config::ProxyConfig, + /// Total number of connections in the pool + global_connections_count: Arc, + + config: &'static crate::config::HttpConfig, } #[derive(Debug, Clone, Copy)] @@ -224,45 +256,39 @@ pub struct GlobalConnPoolOptions { pub idle_timeout: Duration, pub opt_in: bool, + + // Total number of connections in the pool. + pub max_total_conns: usize, } -pub static GC_LATENCY: Lazy = Lazy::new(|| { - register_histogram!( - "proxy_http_pool_reclaimation_lag_seconds", - "Time it takes to reclaim unused connection pools", - // 1us -> 65ms - exponential_buckets(1e-6, 2.0, 16).unwrap(), - ) - .unwrap() -}); - -pub static ENDPOINT_POOLS: Lazy = Lazy::new(|| { - register_int_counter_pair!( - "proxy_http_pool_endpoints_registered_total", - "Number of endpoints we have registered pools for", - "proxy_http_pool_endpoints_unregistered_total", - "Number of endpoints we have unregistered pools for", - ) - .unwrap() -}); - -impl GlobalConnPool { - pub fn new(config: &'static crate::config::ProxyConfig) -> Arc { - let shards = config.http_config.pool_options.pool_shards; +impl GlobalConnPool { + pub fn new(config: &'static crate::config::HttpConfig) -> Arc { + let shards = config.pool_options.pool_shards; Arc::new(Self { global_pool: DashMap::with_shard_amount(shards), global_pool_size: AtomicUsize::new(0), - proxy_config: config, + config, + global_connections_count: Arc::new(AtomicUsize::new(0)), }) } + #[cfg(test)] + pub fn get_global_connections_count(&self) -> usize { + self.global_connections_count + .load(atomic::Ordering::Relaxed) + } + + pub fn get_idle_timeout(&self) -> Duration { + self.config.pool_options.idle_timeout + } + pub fn shutdown(&self) { // drops all strong references to endpoint-pools self.global_pool.clear(); } pub async fn gc_worker(&self, mut rng: impl Rng) { - let epoch = self.proxy_config.http_config.pool_options.gc_epoch; + let epoch = self.config.pool_options.gc_epoch; let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32); loop { interval.tick().await; @@ -280,6 +306,7 @@ impl GlobalConnPool { let timer = GC_LATENCY.start_timer(); let current_len = shard.len(); + let mut clients_removed = 0; shard.retain(|endpoint, x| { // if the current endpoint pool is unique (no other strong or weak references) // then it is currently not in use by any connections. @@ -289,9 +316,9 @@ impl GlobalConnPool { } = pool.get_mut(); // ensure that closed clients are removed - pools - .iter_mut() - .for_each(|(_, db_pool)| db_pool.clear_closed_clients(total_conns)); + pools.iter_mut().for_each(|(_, db_pool)| { + clients_removed += db_pool.clear_closed_clients(total_conns); + }); // we only remove this pool if it has no active connections if *total_conns == 0 { @@ -302,10 +329,20 @@ impl GlobalConnPool { true }); + let new_len = shard.len(); drop(shard); timer.observe_duration(); + // Do logging outside of the lock. + if clients_removed > 0 { + let size = self + .global_connections_count + .fetch_sub(clients_removed, atomic::Ordering::Relaxed) + - clients_removed; + NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(clients_removed as i64); + info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"); + } let removed = current_len - new_len; if removed > 0 { @@ -320,61 +357,24 @@ impl GlobalConnPool { pub async fn get( self: &Arc, ctx: &mut RequestMonitoring, - conn_info: ConnInfo, - force_new: bool, - ) -> anyhow::Result { - let mut client: Option = None; + conn_info: &ConnInfo, + ) -> anyhow::Result>> { + let mut client: Option> = None; - let mut hash_valid = false; - let mut endpoint_pool = Weak::new(); - if !force_new { - let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); - endpoint_pool = Arc::downgrade(&pool); - let mut hash = None; - - // find a pool entry by (dbname, username) if exists - { - let pool = pool.read(); - if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) { - if !pool_entries.conns.is_empty() { - hash = pool_entries.password_hash.clone(); - } - } - } - - // a connection exists in the pool, verify the password hash - if let Some(hash) = hash { - let pw = conn_info.password.clone(); - let validate = tokio::task::spawn_blocking(move || { - Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash()) - }) - .await?; - - // if the hash is invalid, don't error - // we will continue with the regular connection flow - if validate.is_ok() { - hash_valid = true; - if let Some(entry) = pool.write().get_conn_entry(conn_info.db_and_user()) { - client = Some(entry.conn) - } - } - } + let endpoint_pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); + if let Some(entry) = endpoint_pool + .write() + .get_conn_entry(conn_info.db_and_user()) + { + client = Some(entry.conn) } + let endpoint_pool = Arc::downgrade(&endpoint_pool); // ok return cached connection if found and establish a new one otherwise - let new_client = if let Some(client) = client { - ctx.set_project(client.aux.clone()); - if client.inner.is_closed() { - let conn_id = uuid::Uuid::new_v4(); - info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one"); - connect_to_compute( - self.proxy_config, - ctx, - &conn_info, - conn_id, - endpoint_pool.clone(), - ) - .await + if let Some(client) = client { + if client.is_closed() { + info!("pool: cached connection '{conn_info}' is closed, opening a new one"); + return Ok(None); } else { info!("pool: reusing connection '{conn_info}'"); client.session.send(ctx.session_id)?; @@ -384,67 +384,16 @@ impl GlobalConnPool { ); ctx.latency_timer.pool_hit(); ctx.latency_timer.success(); - return Ok(Client::new(client, conn_info, endpoint_pool).await); + return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); } - } else { - let conn_id = uuid::Uuid::new_v4(); - info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - connect_to_compute( - self.proxy_config, - ctx, - &conn_info, - conn_id, - endpoint_pool.clone(), - ) - .await - }; - if let Ok(client) = &new_client { - tracing::Span::current().record( - "pid", - &tracing::field::display(client.inner.get_process_id()), - ); } - - match &new_client { - // clear the hash. it's no longer valid - // TODO: update tokio-postgres fork to allow access to this error kind directly - Err(err) - if hash_valid && err.to_string().contains("password authentication failed") => - { - let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); - let mut pool = pool.write(); - if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) { - entry.password_hash = None; - } - } - // new password is valid and we should insert/update it - Ok(_) if !force_new && !hash_valid => { - let pw = conn_info.password.clone(); - let new_hash = tokio::task::spawn_blocking(move || { - let salt = SaltString::generate(rand::rngs::OsRng); - Pbkdf2 - .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt) - .map(|s| s.serialize()) - }) - .await??; - - let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); - let mut pool = pool.write(); - pool.pools - .entry(conn_info.db_and_user()) - .or_default() - .password_hash = Some(new_hash); - } - _ => {} - } - let new_client = new_client?; - Ok(Client::new(new_client, conn_info, endpoint_pool).await) + Ok(None) } fn get_or_create_endpoint_pool( - &self, + self: &Arc, endpoint: &EndpointCacheKey, - ) -> Arc> { + ) -> Arc>> { // fast path if let Some(pool) = self.global_pool.get(endpoint) { return pool.clone(); @@ -454,12 +403,10 @@ impl GlobalConnPool { let new_pool = Arc::new(RwLock::new(EndpointConnPool { pools: HashMap::new(), total_conns: 0, - max_conns: self - .proxy_config - .http_config - .pool_options - .max_conns_per_endpoint, + max_conns: self.config.pool_options.max_conns_per_endpoint, _guard: ENDPOINT_POOLS.guard(), + global_connections_count: self.global_connections_count.clone(), + global_pool_size_max_conns: self.config.pool_options.max_total_conns, })); // find or create a pool for this endpoint @@ -488,196 +435,128 @@ impl GlobalConnPool { } } -struct TokioMechanism<'a> { - pool: Weak>, - conn_info: &'a ConnInfo, - conn_id: uuid::Uuid, - idle: Duration, -} - -#[async_trait] -impl ConnectMechanism for TokioMechanism<'_> { - type Connection = ClientInner; - type ConnectError = tokio_postgres::Error; - type Error = anyhow::Error; - - async fn connect_once( - &self, - ctx: &mut RequestMonitoring, - node_info: &console::CachedNodeInfo, - timeout: time::Duration, - ) -> Result { - connect_to_compute_once( - ctx, - node_info, - self.conn_info, - timeout, - self.conn_id, - self.pool.clone(), - self.idle, - ) - .await - } - - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} -} - -// Wake up the destination if needed. Code here is a bit involved because -// we reuse the code from the usual proxy and we need to prepare few structures -// that this code expects. -#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] -async fn connect_to_compute( - config: &config::ProxyConfig, +pub fn poll_client( + global_pool: Arc>, ctx: &mut RequestMonitoring, - conn_info: &ConnInfo, + conn_info: ConnInfo, + client: C, + mut connection: tokio_postgres::Connection, conn_id: uuid::Uuid, - pool: Weak>, -) -> anyhow::Result { - ctx.set_application(Some(APP_NAME)); - let backend = config - .auth_backend - .as_ref() - .map(|_| conn_info.user_info.clone()); - - if !config.disable_ip_check_for_http { - let (allowed_ips, _) = backend.get_allowed_ips_and_secret(ctx).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed().into()); - } - } - let node_info = backend - .wake_compute(ctx) - .await? - .context("missing cache entry from wake_compute")?; - - ctx.set_project(node_info.aux.clone()); - - crate::proxy::connect_compute::connect_to_compute( - ctx, - &TokioMechanism { - conn_id, - conn_info, - pool, - idle: config.http_config.pool_options.idle_timeout, - }, - node_info, - &backend, - ) - .await -} - -async fn connect_to_compute_once( - ctx: &mut RequestMonitoring, - node_info: &console::CachedNodeInfo, - conn_info: &ConnInfo, - timeout: time::Duration, - conn_id: uuid::Uuid, - pool: Weak>, - idle: Duration, -) -> Result { - let mut config = (*node_info.config).clone(); - let mut session = ctx.session_id; - - let (client, mut connection) = config - .user(&conn_info.user_info.user) - .password(&*conn_info.password) - .dbname(&conn_info.dbname) - .connect_timeout(timeout) - .connect(tokio_postgres::NoTls) - .await?; - + aux: MetricsAuxInfo, +) -> Client { let conn_gauge = NUM_DB_CONNECTIONS_GAUGE .with_label_values(&[ctx.protocol]) .guard(); - - tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id())); - - let (tx, mut rx) = tokio::sync::watch::channel(session); + let mut session_id = ctx.session_id; + let (tx, mut rx) = tokio::sync::watch::channel(session_id); let span = info_span!(parent: None, "connection", %conn_id); span.in_scope(|| { - info!(%conn_info, %session, "new connection"); + info!(%conn_info, %session_id, "new connection"); }); + let pool = + Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key())); + let pool_clone = pool.clone(); let db_user = conn_info.db_and_user(); + let idle = global_pool.get_idle_timeout(); tokio::spawn( - async move { - let _conn_gauge = conn_gauge; - let mut idle_timeout = pin!(tokio::time::sleep(idle)); - poll_fn(move |cx| { - if matches!(rx.has_changed(), Ok(true)) { - session = *rx.borrow_and_update(); - info!(%session, "changed session"); - idle_timeout.as_mut().reset(Instant::now() + idle); - } + async move { + let _conn_gauge = conn_gauge; + let mut idle_timeout = pin!(tokio::time::sleep(idle)); + poll_fn(move |cx| { + if matches!(rx.has_changed(), Ok(true)) { + session_id = *rx.borrow_and_update(); + info!(%session_id, "changed session"); + idle_timeout.as_mut().reset(Instant::now() + idle); + } - // 5 minute idle connection timeout - if idle_timeout.as_mut().poll(cx).is_ready() { - idle_timeout.as_mut().reset(Instant::now() + idle); - info!("connection idle"); - if let Some(pool) = pool.clone().upgrade() { - // remove client from pool - should close the connection if it's idle. - // does nothing if the client is currently checked-out and in-use - if pool.write().remove_client(db_user.clone(), conn_id) { - info!("idle connection removed"); - } - } - } - - loop { - let message = ready!(connection.poll_message(cx)); - - match message { - Some(Ok(AsyncMessage::Notice(notice))) => { - info!(%session, "notice: {}", notice); - } - Some(Ok(AsyncMessage::Notification(notif))) => { - warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received"); - } - Some(Ok(_)) => { - warn!(%session, "unknown message"); - } - Some(Err(e)) => { - error!(%session, "connection error: {}", e); - break - } - None => { - info!("connection closed"); - break - } - } - } - - // remove from connection pool + // 5 minute idle connection timeout + if idle_timeout.as_mut().poll(cx).is_ready() { + idle_timeout.as_mut().reset(Instant::now() + idle); + info!("connection idle"); if let Some(pool) = pool.clone().upgrade() { + // remove client from pool - should close the connection if it's idle. + // does nothing if the client is currently checked-out and in-use if pool.write().remove_client(db_user.clone(), conn_id) { - info!("closed connection removed"); + info!("idle connection removed"); } } + } - Poll::Ready(()) - }).await; + loop { + let message = ready!(connection.poll_message(cx)); - } - .instrument(span) - ); + match message { + Some(Ok(AsyncMessage::Notice(notice))) => { + info!(%session_id, "notice: {}", notice); + } + Some(Ok(AsyncMessage::Notification(notif))) => { + warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); + } + Some(Ok(_)) => { + warn!(%session_id, "unknown message"); + } + Some(Err(e)) => { + error!(%session_id, "connection error: {}", e); + break + } + None => { + info!("connection closed"); + break + } + } + } - Ok(ClientInner { + // remove from connection pool + if let Some(pool) = pool.clone().upgrade() { + if pool.write().remove_client(db_user.clone(), conn_id) { + info!("closed connection removed"); + } + } + + Poll::Ready(()) + }).await; + + } + .instrument(span)); + let inner = ClientInner { inner: client, session: tx, - aux: node_info.aux.clone(), + aux, conn_id, - }) + }; + Client::new(inner, conn_info, pool_clone) } -struct ClientInner { - inner: tokio_postgres::Client, +struct ClientInner { + inner: C, session: tokio::sync::watch::Sender, aux: MetricsAuxInfo, conn_id: uuid::Uuid, } -impl Client { +pub trait ClientInnerExt: Sync + Send + 'static { + fn is_closed(&self) -> bool; + fn get_process_id(&self) -> i32; +} + +impl ClientInnerExt for tokio_postgres::Client { + fn is_closed(&self) -> bool { + self.is_closed() + } + fn get_process_id(&self) -> i32 { + self.get_process_id() + } +} + +impl ClientInner { + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } +} + +impl Client { pub fn metrics(&self) -> Arc { let aux = &self.inner.as_ref().unwrap().aux; USAGE_METRICS.register(Ids { @@ -687,51 +566,46 @@ impl Client { } } -pub struct Client { - conn_id: uuid::Uuid, +pub struct Client { span: Span, - inner: Option, + inner: Option>, conn_info: ConnInfo, - pool: Weak>, + pool: Weak>>, } -pub struct Discard<'a> { +pub struct Discard<'a, C: ClientInnerExt> { conn_id: uuid::Uuid, conn_info: &'a ConnInfo, - pool: &'a mut Weak>, + pool: &'a mut Weak>>, } -impl Client { - pub(self) async fn new( - inner: ClientInner, +impl Client { + pub(self) fn new( + inner: ClientInner, conn_info: ConnInfo, - pool: Weak>, + pool: Weak>>, ) -> Self { Self { - conn_id: inner.conn_id, inner: Some(inner), span: Span::current(), conn_info, pool, } } - pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { + pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) { let Self { inner, pool, - conn_id, conn_info, span: _, } = self; + let inner = inner.as_mut().expect("client inner should not be removed"); ( - &mut inner - .as_mut() - .expect("client inner should not be removed") - .inner, + &mut inner.inner, Discard { pool, conn_info, - conn_id: *conn_id, + conn_id: inner.conn_id, }, ) } @@ -744,7 +618,7 @@ impl Client { } } -impl Discard<'_> { +impl Discard<'_, C> { pub fn check_idle(&mut self, status: ReadyForQueryStatus) { let conn_info = &self.conn_info; if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { @@ -759,8 +633,8 @@ impl Discard<'_> { } } -impl Deref for Client { - type Target = tokio_postgres::Client; +impl Deref for Client { + type Target = C; fn deref(&self) -> &Self::Target { &self @@ -771,8 +645,8 @@ impl Deref for Client { } } -impl Drop for Client { - fn drop(&mut self) { +impl Client { + fn do_drop(&mut self) -> Option { let conn_info = self.conn_info.clone(); let client = self .inner @@ -781,10 +655,161 @@ impl Drop for Client { if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() { let current_span = self.span.clone(); // return connection to the pool - tokio::task::spawn_blocking(move || { + return Some(move || { let _span = current_span.enter(); let _ = EndpointConnPool::put(&conn_pool, &conn_info, client); }); } + None + } +} + +impl Drop for Client { + fn drop(&mut self) { + if let Some(drop) = self.do_drop() { + tokio::task::spawn_blocking(drop); + } + } +} + +#[cfg(test)] +mod tests { + use env_logger; + use std::{mem, sync::atomic::AtomicBool}; + + use super::*; + + struct MockClient(Arc); + impl MockClient { + fn new(is_closed: bool) -> Self { + MockClient(Arc::new(is_closed.into())) + } + } + impl ClientInnerExt for MockClient { + fn is_closed(&self) -> bool { + self.0.load(atomic::Ordering::Relaxed) + } + fn get_process_id(&self) -> i32 { + 0 + } + } + + fn create_inner() -> ClientInner { + create_inner_with(MockClient::new(false)) + } + + fn create_inner_with(client: MockClient) -> ClientInner { + ClientInner { + inner: client, + session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()), + aux: Default::default(), + conn_id: uuid::Uuid::new_v4(), + } + } + + #[tokio::test] + async fn test_pool() { + let _ = env_logger::try_init(); + let config = Box::leak(Box::new(crate::config::HttpConfig { + pool_options: GlobalConnPoolOptions { + max_conns_per_endpoint: 2, + gc_epoch: Duration::from_secs(1), + pool_shards: 2, + idle_timeout: Duration::from_secs(1), + opt_in: false, + max_total_conns: 3, + }, + request_timeout: Duration::from_secs(1), + })); + let pool = GlobalConnPool::new(config); + let conn_info = ConnInfo { + user_info: ComputeUserInfo { + user: "user".into(), + endpoint: "endpoint".into(), + options: Default::default(), + }, + dbname: "dbname".into(), + password: "password".into(), + }; + let ep_pool = + Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key())); + { + let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + assert_eq!(0, pool.get_global_connections_count()); + client.discard(); + // Discard should not add the connection from the pool. + assert_eq!(0, pool.get_global_connections_count()); + } + { + let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + assert_eq!(1, pool.get_global_connections_count()); + } + { + let mut closed_client = Client::new( + create_inner_with(MockClient::new(true)), + conn_info.clone(), + ep_pool.clone(), + ); + closed_client.do_drop().unwrap()(); + mem::forget(closed_client); // drop the client + // The closed client shouldn't be added to the pool. + assert_eq!(1, pool.get_global_connections_count()); + } + let is_closed: Arc = Arc::new(false.into()); + { + let mut client = Client::new( + create_inner_with(MockClient(is_closed.clone())), + conn_info.clone(), + ep_pool.clone(), + ); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + + // The client should be added to the pool. + assert_eq!(2, pool.get_global_connections_count()); + } + { + let mut client = Client::new(create_inner(), conn_info, ep_pool); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + + // The client shouldn't be added to the pool. Because the ep-pool is full. + assert_eq!(2, pool.get_global_connections_count()); + } + + let conn_info = ConnInfo { + user_info: ComputeUserInfo { + user: "user".into(), + endpoint: "endpoint-2".into(), + options: Default::default(), + }, + dbname: "dbname".into(), + password: "password".into(), + }; + let ep_pool = + Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key())); + { + let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + assert_eq!(3, pool.get_global_connections_count()); + } + { + let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + + // The client shouldn't be added to the pool. Because the global pool is full. + assert_eq!(3, pool.get_global_connections_count()); + } + + is_closed.store(true, atomic::Ordering::Relaxed); + // Do gc for all shards. + pool.gc(0); + pool.gc(1); + // Closed client should be removed from the pool. + assert_eq!(2, pool.get_global_connections_count()); } } diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 05835b23ce..a089d34040 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -9,23 +9,23 @@ use tokio_postgres::Row; // as parameters. // pub fn json_to_pg_text(json: Vec) -> Vec> { - json.iter() - .map(|value| { - match value { - // special care for nulls - Value::Null => None, + json.iter().map(json_value_to_pg_text).collect() +} - // convert to text with escaping - v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), +fn json_value_to_pg_text(value: &Value) -> Option { + match value { + // special care for nulls + Value::Null => None, - // avoid escaping here, as we pass this as a parameter - Value::String(s) => Some(s.to_string()), + // convert to text with escaping + v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), - // special care for arrays - Value::Array(_) => json_array_to_pg_array(value), - } - }) - .collect() + // avoid escaping here, as we pass this as a parameter + Value::String(s) => Some(s.to_string()), + + // special care for arrays + Value::Array(_) => json_array_to_pg_array(value), + } } // diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 96bf39c915..7092b65f03 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -13,6 +13,7 @@ use hyper::StatusCode; use hyper::{Body, HeaderMap, Request}; use serde_json::json; use serde_json::Value; +use tokio::join; use tokio_postgres::error::DbError; use tokio_postgres::error::ErrorPosition; use tokio_postgres::GenericClient; @@ -20,6 +21,7 @@ use tokio_postgres::IsolationLevel; use tokio_postgres::ReadyForQueryStatus; use tokio_postgres::Transaction; use tracing::error; +use tracing::info; use tracing::instrument; use url::Url; use utils::http::error::ApiError; @@ -27,22 +29,25 @@ use utils::http::json::json_response; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; -use crate::config::HttpConfig; +use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; +use crate::metrics::HTTP_CONTENT_LENGTH; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; use crate::RoleName; +use super::backend::PoolingBackend; use super::conn_pool::ConnInfo; -use super::conn_pool::GlobalConnPool; -use super::json::{json_to_pg_text, pg_text_row_to_json}; +use super::json::json_to_pg_text; +use super::json::pg_text_row_to_json; use super::SERVERLESS_DRIVER_SNI; #[derive(serde::Deserialize)] struct QueryData { query: String, - params: Vec, + #[serde(deserialize_with = "bytes_to_pg_text")] + params: Vec>, } #[derive(serde::Deserialize)] @@ -69,6 +74,15 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); +fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: serde::de::Deserializer<'de>, +{ + // TODO: consider avoiding the allocation here. + let json: Vec = serde::de::Deserialize::deserialize(deserializer)?; + Ok(json_to_pg_text(json)) +} + fn get_conn_info( ctx: &mut RequestMonitoring, headers: &HeaderMap, @@ -171,16 +185,15 @@ fn check_matches(sni_hostname: &str, hostname: &str) -> Result, sni_hostname: Option, - conn_pool: Arc, + backend: Arc, ) -> Result, ApiError> { let result = tokio::time::timeout( - config.request_timeout, - handle_inner(tls, config, ctx, request, sni_hostname, conn_pool), + config.http_config.request_timeout, + handle_inner(config, ctx, request, sni_hostname, backend), ) .await; let mut response = match result { @@ -265,7 +278,7 @@ pub async fn handle( Err(_) => { let message = format!( "HTTP-Connection timed out, execution time exeeded {} seconds", - config.request_timeout.as_secs() + config.http_config.request_timeout.as_secs() ); error!(message); json_response( @@ -283,22 +296,36 @@ pub async fn handle( #[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)] async fn handle_inner( - tls: &'static TlsConfig, - config: &'static HttpConfig, + config: &'static ProxyConfig, ctx: &mut RequestMonitoring, request: Request, sni_hostname: Option, - conn_pool: Arc, + backend: Arc, ) -> anyhow::Result> { let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE - .with_label_values(&["http"]) + .with_label_values(&[ctx.protocol]) .guard(); + info!( + protocol = ctx.protocol, + "handling interactive connection from client" + ); // // Determine the destination and connection params // let headers = request.headers(); - let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?; + // TLS config should be there. + let conn_info = get_conn_info( + ctx, + headers, + sni_hostname, + config.tls_config.as_ref().unwrap(), + )?; + info!( + user = conn_info.user_info.user.as_str(), + project = conn_info.user_info.endpoint.as_str(), + "credentials" + ); // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. @@ -307,8 +334,8 @@ async fn handle_inner( // Allow connection pooling only if explicitly requested // or if we have decided that http pool is no longer opt-in - let allow_pool = - !config.pool_options.opt_in || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); + let allow_pool = !config.http_config.pool_options.opt_in + || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); // isolation level, read only and deferrable @@ -333,6 +360,8 @@ async fn handle_inner( None => MAX_REQUEST_SIZE + 1, }; drop(paused); + info!(request_content_length, "request size in bytes"); + HTTP_CONTENT_LENGTH.observe(request_content_length as f64); // we don't have a streaming request support yet so this is to prevent OOM // from a malicious user sending an extremely large request body @@ -342,13 +371,28 @@ async fn handle_inner( )); } - // - // Read the query and query params from the request body - // - let body = hyper::body::to_bytes(request.into_body()).await?; - let payload: Payload = serde_json::from_slice(&body)?; + let fetch_and_process_request = async { + let body = hyper::body::to_bytes(request.into_body()) + .await + .map_err(anyhow::Error::from)?; + let payload: Payload = serde_json::from_slice(&body)?; + Ok::(payload) // Adjust error type accordingly + }; - let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?; + let authenticate_and_connect = async { + let keys = backend.authenticate(ctx, &conn_info).await?; + backend + .connect_to_compute(ctx, conn_info, keys, !allow_pool) + .await + }; + + // Run both operations in parallel + let (payload_result, auth_and_connect_result) = + join!(fetch_and_process_request, authenticate_and_connect,); + + // Handle the results + let payload = payload_result?; // Handle errors appropriately + let mut client = auth_and_connect_result?; // Handle errors appropriately let mut response = Response::builder() .status(StatusCode::OK) @@ -482,7 +526,7 @@ async fn query_to_json( raw_output: bool, array_mode: bool, ) -> anyhow::Result<(ReadyForQueryStatus, Value)> { - let query_params = json_to_pg_text(data.params); + let query_params = data.params; let row_stream = client.query_raw_txt(&data.query, query_params).await?; // Manually drain the stream into a vector to leave row_stream hanging diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 1d62f09840..b3b35e446d 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -393,11 +393,11 @@ def test_sql_over_http_batch(static_proxy: NeonProxy): def test_sql_over_http_pool(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth with password 'http' superuser") - def get_pid(status: int, pw: str) -> Any: + def get_pid(status: int, pw: str, user="http_auth") -> Any: return static_proxy.http_query( GET_CONNECTION_PID_QUERY, [], - user="http_auth", + user=user, password=pw, expected_code=status, ) @@ -418,20 +418,14 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): static_proxy.safe_psql("alter user http_auth with password 'http2'") - # after password change, should open a new connection to verify it - pid2 = get_pid(200, "http2")["rows"][0]["pid"] - assert pid1 != pid2 + # after password change, shouldn't open a new connection because it checks password in proxy. + rows = get_pid(200, "http2")["rows"] + assert rows == [{"pid": pid1}] time.sleep(0.02) - # query should be on an existing connection - pid = get_pid(200, "http2")["rows"][0]["pid"] - assert pid in [pid1, pid2] - - time.sleep(0.02) - - # old password should not work - res = get_pid(400, "http") + # incorrect user shouldn't reveal that the user doesn't exists + res = get_pid(400, "http", user="http_auth2") assert "password authentication failed for user" in res["message"] From 6c34d4cd147eb3704d8e54b434afee35b7d08704 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Thu, 8 Feb 2024 14:52:04 +0100 Subject: [PATCH 15/81] Proxy: set timeout on establishing connection (#6679) ## Problem There is no timeout on the handshake. ## Summary of changes Set the timeout on the establishing connection. --- proxy/src/bin/proxy.rs | 4 ++++ proxy/src/config.rs | 1 + proxy/src/proxy.rs | 9 +++++---- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 6974f1a274..8fbcb56758 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -88,6 +88,9 @@ struct ProxyCliArgs { /// path to directory with TLS certificates for client postgres connections #[clap(long)] certs_dir: Option, + /// timeout for the TLS handshake + #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] + handshake_timeout: tokio::time::Duration, /// http endpoint to receive periodic metric updates #[clap(long)] metric_collection_endpoint: Option, @@ -411,6 +414,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { require_client_ip: args.require_client_ip, disable_ip_check_for_http: args.disable_ip_check_for_http, endpoint_rps_limit, + handshake_timeout: args.handshake_timeout, // TODO: add this argument region: args.region.clone(), })); diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 2c46458a49..31c9228b35 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -22,6 +22,7 @@ pub struct ProxyConfig { pub disable_ip_check_for_http: bool, pub endpoint_rps_limit: Vec, pub region: String, + pub handshake_timeout: Duration, } #[derive(Debug)] diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index b68fb26e42..b3b221d3e2 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -194,10 +194,11 @@ pub async fn handle_client( let pause = ctx.latency_timer.pause(); let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map); - let (mut stream, params) = match do_handshake.await? { - Some(x) => x, - None => return Ok(()), // it's a cancellation request - }; + let (mut stream, params) = + match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { + Some(x) => x, + None => return Ok(()), // it's a cancellation request + }; drop(pause); let hostname = mode.hostname(stream.get_ref()); From 43eae17f0d2e84b0c88e34f3fff6bfe515008b89 Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Thu, 8 Feb 2024 17:31:15 +0200 Subject: [PATCH 16/81] Drop unused replication slots (#6655) ## Problem See #6626 If there is inactive replication slot then Postgres will not bw able to shrink WAL and delete unused snapshots. If she other active subscription is present, then snapshots created each 15 seconds will overflow AUX_DIR. Setting `max_slot_wal_keep_size` doesn't solve the problem, because even small WAL segment will be enough to overflow AUX_DIR if there is no other activity on the system. ## Summary of changes If there are active subscriptions and some logical replication slots are not used during `neon.logical_replication_max_time_lag` interval, then unused slot is dropped. ## Checklist before requesting a review - [ ] I have performed a self-review of my code. - [ ] If it is a core feature, I have added thorough tests. - [ ] Do we need to implement analytics? if so did you add the relevant metrics to the dashboard? - [ ] If this PR requires public announcement, mark it with /release-notes label and add several sentences in this section. ## Checklist before merging - [ ] Do not forget to reformat commit message to not include the above checklist Co-authored-by: Konstantin Knizhnik --- pgxn/neon/neon.c | 133 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index b930fdb3ca..799f88751c 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -11,16 +11,23 @@ #include "postgres.h" #include "fmgr.h" +#include "miscadmin.h" #include "access/xact.h" #include "access/xlog.h" #include "storage/buf_internals.h" #include "storage/bufmgr.h" #include "catalog/pg_type.h" +#include "postmaster/bgworker.h" +#include "postmaster/interrupt.h" +#include "replication/slot.h" #include "replication/walsender.h" +#include "storage/procsignal.h" +#include "tcop/tcopprot.h" #include "funcapi.h" #include "access/htup_details.h" #include "utils/pg_lsn.h" #include "utils/guc.h" +#include "utils/wait_event.h" #include "neon.h" #include "walproposer.h" @@ -30,6 +37,130 @@ PG_MODULE_MAGIC; void _PG_init(void); +static int logical_replication_max_time_lag = 3600; + +static void +InitLogicalReplicationMonitor(void) +{ + BackgroundWorker bgw; + + DefineCustomIntVariable( + "neon.logical_replication_max_time_lag", + "Threshold for dropping unused logical replication slots", + NULL, + &logical_replication_max_time_lag, + 3600, 0, INT_MAX, + PGC_SIGHUP, + GUC_UNIT_S, + NULL, NULL, NULL); + + memset(&bgw, 0, sizeof(bgw)); + bgw.bgw_flags = BGWORKER_SHMEM_ACCESS; + bgw.bgw_start_time = BgWorkerStart_RecoveryFinished; + snprintf(bgw.bgw_library_name, BGW_MAXLEN, "neon"); + snprintf(bgw.bgw_function_name, BGW_MAXLEN, "LogicalSlotsMonitorMain"); + snprintf(bgw.bgw_name, BGW_MAXLEN, "Logical replication monitor"); + snprintf(bgw.bgw_type, BGW_MAXLEN, "Logical replication monitor"); + bgw.bgw_restart_time = 5; + bgw.bgw_notify_pid = 0; + bgw.bgw_main_arg = (Datum) 0; + + RegisterBackgroundWorker(&bgw); +} + +typedef struct +{ + NameData name; + bool dropped; + XLogRecPtr confirmed_flush_lsn; + TimestampTz last_updated; +} SlotStatus; + +/* + * Unused logical replication slots pins WAL and prevents deletion of snapshots. + */ +PGDLLEXPORT void +LogicalSlotsMonitorMain(Datum main_arg) +{ + SlotStatus* slots; + TimestampTz now, last_checked; + + /* Establish signal handlers. */ + pqsignal(SIGUSR1, procsignal_sigusr1_handler); + pqsignal(SIGHUP, SignalHandlerForConfigReload); + pqsignal(SIGTERM, die); + + BackgroundWorkerUnblockSignals(); + + slots = (SlotStatus*)calloc(max_replication_slots, sizeof(SlotStatus)); + last_checked = GetCurrentTimestamp(); + + for (;;) + { + (void) WaitLatch(MyLatch, + WL_LATCH_SET | WL_EXIT_ON_PM_DEATH | WL_TIMEOUT, + logical_replication_max_time_lag*1000/2, + PG_WAIT_EXTENSION); + ResetLatch(MyLatch); + CHECK_FOR_INTERRUPTS(); + + now = GetCurrentTimestamp(); + + if (now - last_checked > logical_replication_max_time_lag*USECS_PER_SEC) + { + int n_active_slots = 0; + last_checked = now; + + LWLockAcquire(ReplicationSlotControlLock, LW_SHARED); + for (int i = 0; i < max_replication_slots; i++) + { + ReplicationSlot *s = &ReplicationSlotCtl->replication_slots[i]; + + /* Consider only logical repliction slots */ + if (!s->in_use || !SlotIsLogical(s)) + continue; + + if (s->active_pid != 0) + { + n_active_slots += 1; + continue; + } + + /* Check if there was some activity with the slot since last check */ + if (s->data.confirmed_flush != slots[i].confirmed_flush_lsn) + { + slots[i].confirmed_flush_lsn = s->data.confirmed_flush; + slots[i].last_updated = now; + } + else if (now - slots[i].last_updated > logical_replication_max_time_lag*USECS_PER_SEC) + { + slots[i].name = s->data.name; + slots[i].dropped = true; + } + } + LWLockRelease(ReplicationSlotControlLock); + + /* + * If there are no active subscriptions, then no new snapshots are generated + * and so no need to force slot deletion. + */ + if (n_active_slots != 0) + { + for (int i = 0; i < max_replication_slots; i++) + { + if (slots[i].dropped) + { + elog(LOG, "Drop logical replication slot because it was not update more than %ld seconds", + (now - slots[i].last_updated)/USECS_PER_SEC); + ReplicationSlotDrop(slots[i].name.data, true); + slots[i].dropped = false; + } + } + } + } + } +} + void _PG_init(void) { @@ -44,6 +175,8 @@ _PG_init(void) pg_init_libpagestore(); pg_init_walproposer(); + InitLogicalReplicationMonitor(); + InitControlPlaneConnector(); pg_init_extension_server(); From af91a28936eef0b1e5149dc71d92394a89410372 Mon Sep 17 00:00:00 2001 From: John Spray Date: Thu, 8 Feb 2024 15:35:13 +0000 Subject: [PATCH 17/81] pageserver: shard splitting (#6379) ## Problem One doesn't know at tenant creation time how large the tenant will grow. We need to be able to dynamically adjust the shard count at runtime. This is implemented as "splitting" of shards into smaller child shards, which cover a subset of the keyspace that the parent covered. Refer to RFC: https://github.com/neondatabase/neon/pull/6358 Part of epic: #6278 ## Summary of changes This PR implements the happy path (does not cleanly recover from a crash mid-split, although won't lose any data), without any optimizations (e.g. child shards re-download their own copies of layers that the parent shard already had on local disk) - Add `/v1/tenant/:tenant_shard_id/shard_split` API to pageserver: this copies the shard's index to the child shards' paths, instantiates child `Tenant` object, and tears down parent `Tenant` object. - Add `splitting` column to `tenant_shards` table. This is written into an existing migration because we haven't deployed yet, so don't need to cleanly upgrade. - Add `/control/v1/tenant/:tenant_id/shard_split` API to attachment_service, - Add `test_sharding_split_smoke` test. This covers the happy path: future PRs will add tests that exercise failure cases. --- Dockerfile | 5 + .../up.sql | 1 + control_plane/attachment_service/src/http.rs | 19 +- .../attachment_service/src/persistence.rs | 102 +++++- .../src/persistence/split_state.rs | 46 +++ .../attachment_service/src/schema.rs | 1 + .../attachment_service/src/service.rs | 333 +++++++++++++++++- .../attachment_service/src/tenant_state.rs | 10 + control_plane/src/attachment_service.rs | 21 +- control_plane/src/bin/neon_local.rs | 25 ++ libs/pageserver_api/src/models.rs | 10 + libs/pageserver_api/src/shard.rs | 128 +++++++ pageserver/client/src/mgmt_api.rs | 16 + pageserver/src/http/routes.rs | 27 +- pageserver/src/tenant.rs | 66 ++++ pageserver/src/tenant/mgr.rs | 169 ++++++++- .../tenant/remote_timeline_client/upload.rs | 2 +- test_runner/fixtures/neon_fixtures.py | 2 +- test_runner/regress/test_sharding.py | 129 ++++++- 19 files changed, 1088 insertions(+), 24 deletions(-) create mode 100644 control_plane/attachment_service/src/persistence/split_state.rs diff --git a/Dockerfile b/Dockerfile index bb926643dc..c37f94b981 100644 --- a/Dockerfile +++ b/Dockerfile @@ -100,6 +100,11 @@ RUN mkdir -p /data/.neon/ && chown -R neon:neon /data/.neon/ \ -c "listen_pg_addr='0.0.0.0:6400'" \ -c "listen_http_addr='0.0.0.0:9898'" +# When running a binary that links with libpq, default to using our most recent postgres version. Binaries +# that want a particular postgres version will select it explicitly: this is just a default. +ENV LD_LIBRARY_PATH /usr/local/v16/lib + + VOLUME ["/data"] USER neon EXPOSE 6400 diff --git a/control_plane/attachment_service/migrations/2024-01-07-211257_create_tenant_shards/up.sql b/control_plane/attachment_service/migrations/2024-01-07-211257_create_tenant_shards/up.sql index 585dbc79a0..2ffdae6287 100644 --- a/control_plane/attachment_service/migrations/2024-01-07-211257_create_tenant_shards/up.sql +++ b/control_plane/attachment_service/migrations/2024-01-07-211257_create_tenant_shards/up.sql @@ -7,6 +7,7 @@ CREATE TABLE tenant_shards ( generation INTEGER NOT NULL, generation_pageserver BIGINT NOT NULL, placement_policy VARCHAR NOT NULL, + splitting SMALLINT NOT NULL, -- config is JSON encoded, opaque to the database. config TEXT NOT NULL ); \ No newline at end of file diff --git a/control_plane/attachment_service/src/http.rs b/control_plane/attachment_service/src/http.rs index 049e66fddf..38eecaf7ef 100644 --- a/control_plane/attachment_service/src/http.rs +++ b/control_plane/attachment_service/src/http.rs @@ -3,7 +3,8 @@ use crate::service::{Service, STARTUP_RECONCILE_TIMEOUT}; use hyper::{Body, Request, Response}; use hyper::{StatusCode, Uri}; use pageserver_api::models::{ - TenantCreateRequest, TenantLocationConfigRequest, TimelineCreateRequest, + TenantCreateRequest, TenantLocationConfigRequest, TenantShardSplitRequest, + TimelineCreateRequest, }; use pageserver_api::shard::TenantShardId; use pageserver_client::mgmt_api; @@ -292,6 +293,19 @@ async fn handle_node_configure(mut req: Request) -> Result, json_response(StatusCode::OK, state.service.node_configure(config_req)?) } +async fn handle_tenant_shard_split( + service: Arc, + mut req: Request, +) -> Result, ApiError> { + let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; + let split_req = json_request::(&mut req).await?; + + json_response( + StatusCode::OK, + service.tenant_shard_split(tenant_id, split_req).await?, + ) +} + async fn handle_tenant_shard_migrate( service: Arc, mut req: Request, @@ -391,6 +405,9 @@ pub fn make_router( .put("/control/v1/tenant/:tenant_shard_id/migrate", |r| { tenant_service_handler(r, handle_tenant_shard_migrate) }) + .put("/control/v1/tenant/:tenant_id/shard_split", |r| { + tenant_service_handler(r, handle_tenant_shard_split) + }) // Tenant operations // The ^/v1/ endpoints act as a "Virtual Pageserver", enabling shard-naive clients to call into // this service to manage tenants that actually consist of many tenant shards, as if they are a single entity. diff --git a/control_plane/attachment_service/src/persistence.rs b/control_plane/attachment_service/src/persistence.rs index db487bcec6..cead540058 100644 --- a/control_plane/attachment_service/src/persistence.rs +++ b/control_plane/attachment_service/src/persistence.rs @@ -1,7 +1,9 @@ +pub(crate) mod split_state; use std::collections::HashMap; use std::str::FromStr; use std::time::Duration; +use self::split_state::SplitState; use camino::Utf8Path; use camino::Utf8PathBuf; use control_plane::attachment_service::{NodeAvailability, NodeSchedulingPolicy}; @@ -363,19 +365,101 @@ impl Persistence { Ok(()) } - // TODO: when we start shard splitting, we must durably mark the tenant so that - // on restart, we know that we must go through recovery (list shards that exist - // and pick up where we left off and/or revert to parent shards). + // When we start shard splitting, we must durably mark the tenant so that + // on restart, we know that we must go through recovery. + // + // We create the child shards here, so that they will be available for increment_generation calls + // if some pageserver holding a child shard needs to restart before the overall tenant split is complete. #[allow(dead_code)] - pub(crate) async fn begin_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> { - todo!(); + pub(crate) async fn begin_shard_split( + &self, + old_shard_count: ShardCount, + split_tenant_id: TenantId, + parent_to_children: Vec<(TenantShardId, Vec)>, + ) -> DatabaseResult<()> { + use crate::schema::tenant_shards::dsl::*; + self.with_conn(move |conn| -> DatabaseResult<()> { + conn.transaction(|conn| -> DatabaseResult<()> { + // Mark parent shards as splitting + let updated = diesel::update(tenant_shards) + .filter(tenant_id.eq(split_tenant_id.to_string())) + .filter(shard_count.eq(old_shard_count.0 as i32)) + .set((splitting.eq(1),)) + .execute(conn)?; + if ShardCount(updated.try_into().map_err(|_| DatabaseError::Logical(format!("Overflow existing shard count {} while splitting", updated)))?) != old_shard_count { + // Perhaps a deletion or another split raced with this attempt to split, mutating + // the parent shards that we intend to split. In this case the split request should fail. + return Err(DatabaseError::Logical( + format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {old_shard_count:?})") + )); + } + + // FIXME: spurious clone to sidestep closure move rules + let parent_to_children = parent_to_children.clone(); + + // Insert child shards + for (parent_shard_id, children) in parent_to_children { + let mut parent = crate::schema::tenant_shards::table + .filter(tenant_id.eq(parent_shard_id.tenant_id.to_string())) + .filter(shard_number.eq(parent_shard_id.shard_number.0 as i32)) + .filter(shard_count.eq(parent_shard_id.shard_count.0 as i32)) + .load::(conn)?; + let parent = if parent.len() != 1 { + return Err(DatabaseError::Logical(format!( + "Parent shard {parent_shard_id} not found" + ))); + } else { + parent.pop().unwrap() + }; + for mut shard in children { + // Carry the parent's generation into the child + shard.generation = parent.generation; + + debug_assert!(shard.splitting == SplitState::Splitting); + diesel::insert_into(tenant_shards) + .values(shard) + .execute(conn)?; + } + } + + Ok(()) + })?; + + Ok(()) + }) + .await } - // TODO: when we finish shard splitting, we must atomically clean up the old shards + // When we finish shard splitting, we must atomically clean up the old shards // and insert the new shards, and clear the splitting marker. #[allow(dead_code)] - pub(crate) async fn complete_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> { - todo!(); + pub(crate) async fn complete_shard_split( + &self, + split_tenant_id: TenantId, + old_shard_count: ShardCount, + ) -> DatabaseResult<()> { + use crate::schema::tenant_shards::dsl::*; + self.with_conn(move |conn| -> DatabaseResult<()> { + conn.transaction(|conn| -> QueryResult<()> { + // Drop parent shards + diesel::delete(tenant_shards) + .filter(tenant_id.eq(split_tenant_id.to_string())) + .filter(shard_count.eq(old_shard_count.0 as i32)) + .execute(conn)?; + + // Clear sharding flag + let updated = diesel::update(tenant_shards) + .filter(tenant_id.eq(split_tenant_id.to_string())) + .set((splitting.eq(0),)) + .execute(conn)?; + debug_assert!(updated > 0); + + Ok(()) + })?; + + Ok(()) + }) + .await } } @@ -403,6 +487,8 @@ pub(crate) struct TenantShardPersistence { #[serde(default)] pub(crate) placement_policy: String, #[serde(default)] + pub(crate) splitting: SplitState, + #[serde(default)] pub(crate) config: String, } diff --git a/control_plane/attachment_service/src/persistence/split_state.rs b/control_plane/attachment_service/src/persistence/split_state.rs new file mode 100644 index 0000000000..bce1a75843 --- /dev/null +++ b/control_plane/attachment_service/src/persistence/split_state.rs @@ -0,0 +1,46 @@ +use diesel::pg::{Pg, PgValue}; +use diesel::{ + deserialize::FromSql, deserialize::FromSqlRow, expression::AsExpression, serialize::ToSql, + sql_types::Int2, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, FromSqlRow, AsExpression)] +#[diesel(sql_type = SplitStateSQLRepr)] +#[derive(Deserialize, Serialize)] +pub enum SplitState { + Idle = 0, + Splitting = 1, +} + +impl Default for SplitState { + fn default() -> Self { + Self::Idle + } +} + +type SplitStateSQLRepr = Int2; + +impl ToSql for SplitState { + fn to_sql<'a>( + &'a self, + out: &'a mut diesel::serialize::Output, + ) -> diesel::serialize::Result { + let raw_value: i16 = *self as i16; + let mut new_out = out.reborrow(); + ToSql::::to_sql(&raw_value, &mut new_out) + } +} + +impl FromSql for SplitState { + fn from_sql(pg_value: PgValue) -> diesel::deserialize::Result { + match FromSql::::from_sql(pg_value).map(|v| match v { + 0 => Some(Self::Idle), + 1 => Some(Self::Splitting), + _ => None, + })? { + Some(v) => Ok(v), + None => Err(format!("Invalid SplitState value, was: {:?}", pg_value.as_bytes()).into()), + } + } +} diff --git a/control_plane/attachment_service/src/schema.rs b/control_plane/attachment_service/src/schema.rs index de80fc8f64..db5a957443 100644 --- a/control_plane/attachment_service/src/schema.rs +++ b/control_plane/attachment_service/src/schema.rs @@ -20,6 +20,7 @@ diesel::table! { generation -> Int4, generation_pageserver -> Int8, placement_policy -> Varchar, + splitting -> Int2, config -> Text, } } diff --git a/control_plane/attachment_service/src/service.rs b/control_plane/attachment_service/src/service.rs index 1db1906df8..0ec2b9dc4c 100644 --- a/control_plane/attachment_service/src/service.rs +++ b/control_plane/attachment_service/src/service.rs @@ -1,4 +1,5 @@ use std::{ + cmp::Ordering, collections::{BTreeMap, HashMap}, str::FromStr, sync::Arc, @@ -23,7 +24,7 @@ use pageserver_api::{ models::{ LocationConfig, LocationConfigMode, ShardParameters, TenantConfig, TenantCreateRequest, TenantLocationConfigRequest, TenantLocationConfigResponse, TenantShardLocation, - TimelineCreateRequest, TimelineInfo, + TenantShardSplitRequest, TenantShardSplitResponse, TimelineCreateRequest, TimelineInfo, }, shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize, TenantShardId}, }; @@ -40,7 +41,11 @@ use utils::{ use crate::{ compute_hook::{self, ComputeHook}, node::Node, - persistence::{DatabaseError, NodePersistence, Persistence, TenantShardPersistence}, + persistence::{ + split_state::SplitState, DatabaseError, NodePersistence, Persistence, + TenantShardPersistence, + }, + reconciler::attached_location_conf, scheduler::Scheduler, tenant_state::{ IntentState, ObservedState, ObservedStateLocation, ReconcileResult, ReconcileWaitError, @@ -476,6 +481,7 @@ impl Service { generation_pageserver: i64::MAX, placement_policy: serde_json::to_string(&PlacementPolicy::default()).unwrap(), config: serde_json::to_string(&TenantConfig::default()).unwrap(), + splitting: SplitState::default(), }; match self.persistence.insert_tenant_shards(vec![tsp]).await { @@ -718,6 +724,7 @@ impl Service { generation_pageserver: i64::MAX, placement_policy: serde_json::to_string(&placement_policy).unwrap(), config: serde_json::to_string(&create_req.config).unwrap(), + splitting: SplitState::default(), }) .collect(); self.persistence @@ -1100,6 +1107,7 @@ impl Service { self.ensure_attached_wait(tenant_id).await?; // TODO: refuse to do this if shard splitting is in progress + // (https://github.com/neondatabase/neon/issues/6676) let targets = { let locked = self.inner.read().unwrap(); let mut targets = Vec::new(); @@ -1180,6 +1188,7 @@ impl Service { self.ensure_attached_wait(tenant_id).await?; // TODO: refuse to do this if shard splitting is in progress + // (https://github.com/neondatabase/neon/issues/6676) let targets = { let locked = self.inner.read().unwrap(); let mut targets = Vec::new(); @@ -1352,6 +1361,326 @@ impl Service { }) } + pub(crate) async fn tenant_shard_split( + &self, + tenant_id: TenantId, + split_req: TenantShardSplitRequest, + ) -> Result { + let mut policy = None; + let mut shard_ident = None; + + // TODO: put a cancellation token on Service for clean shutdown + let cancel = CancellationToken::new(); + + // A parent shard which will be split + struct SplitTarget { + parent_id: TenantShardId, + node: Node, + child_ids: Vec, + } + + // Validate input, and calculate which shards we will create + let (old_shard_count, targets, compute_hook) = { + let locked = self.inner.read().unwrap(); + + let pageservers = locked.nodes.clone(); + + let mut targets = Vec::new(); + + // In case this is a retry, count how many already-split shards we found + let mut children_found = Vec::new(); + let mut old_shard_count = None; + + for (tenant_shard_id, shard) in + locked.tenants.range(TenantShardId::tenant_range(tenant_id)) + { + match shard.shard.count.0.cmp(&split_req.new_shard_count) { + Ordering::Equal => { + // Already split this + children_found.push(*tenant_shard_id); + continue; + } + Ordering::Greater => { + return Err(ApiError::BadRequest(anyhow::anyhow!( + "Requested count {} but already have shards at count {}", + split_req.new_shard_count, + shard.shard.count.0 + ))); + } + Ordering::Less => { + // Fall through: this shard has lower count than requested, + // is a candidate for splitting. + } + } + + match old_shard_count { + None => old_shard_count = Some(shard.shard.count), + Some(old_shard_count) => { + if old_shard_count != shard.shard.count { + // We may hit this case if a caller asked for two splits to + // different sizes, before the first one is complete. + // e.g. 1->2, 2->4, where the 4 call comes while we have a mixture + // of shard_count=1 and shard_count=2 shards in the map. + return Err(ApiError::Conflict( + "Cannot split, currently mid-split".to_string(), + )); + } + } + } + if policy.is_none() { + policy = Some(shard.policy.clone()); + } + if shard_ident.is_none() { + shard_ident = Some(shard.shard); + } + + if tenant_shard_id.shard_count == ShardCount(split_req.new_shard_count) { + tracing::info!( + "Tenant shard {} already has shard count {}", + tenant_shard_id, + split_req.new_shard_count + ); + continue; + } + + let node_id = + shard + .intent + .attached + .ok_or(ApiError::BadRequest(anyhow::anyhow!( + "Cannot split a tenant that is not attached" + )))?; + + let node = pageservers + .get(&node_id) + .expect("Pageservers may not be deleted while referenced"); + + // TODO: if any reconciliation is currently in progress for this shard, wait for it. + + targets.push(SplitTarget { + parent_id: *tenant_shard_id, + node: node.clone(), + child_ids: tenant_shard_id.split(ShardCount(split_req.new_shard_count)), + }); + } + + if targets.is_empty() { + if children_found.len() == split_req.new_shard_count as usize { + return Ok(TenantShardSplitResponse { + new_shards: children_found, + }); + } else { + // No shards found to split, and no existing children found: the + // tenant doesn't exist at all. + return Err(ApiError::NotFound( + anyhow::anyhow!("Tenant {} not found", tenant_id).into(), + )); + } + } + + (old_shard_count, targets, locked.compute_hook.clone()) + }; + + // unwrap safety: we would have returned above if we didn't find at least one shard to split + let old_shard_count = old_shard_count.unwrap(); + let shard_ident = shard_ident.unwrap(); + let policy = policy.unwrap(); + + // FIXME: we have dropped self.inner lock, and not yet written anything to the database: another + // request could occur here, deleting or mutating the tenant. begin_shard_split checks that the + // parent shards exist as expected, but it would be neater to do the above pre-checks within the + // same database transaction rather than pre-check in-memory and then maybe-fail the database write. + // (https://github.com/neondatabase/neon/issues/6676) + + // Before creating any new child shards in memory or on the pageservers, persist them: this + // enables us to ensure that we will always be able to clean up if something goes wrong. This also + // acts as the protection against two concurrent attempts to split: one of them will get a database + // error trying to insert the child shards. + let mut child_tsps = Vec::new(); + for target in &targets { + let mut this_child_tsps = Vec::new(); + for child in &target.child_ids { + let mut child_shard = shard_ident; + child_shard.number = child.shard_number; + child_shard.count = child.shard_count; + + this_child_tsps.push(TenantShardPersistence { + tenant_id: child.tenant_id.to_string(), + shard_number: child.shard_number.0 as i32, + shard_count: child.shard_count.0 as i32, + shard_stripe_size: shard_ident.stripe_size.0 as i32, + // Note: this generation is a placeholder, [`Persistence::begin_shard_split`] will + // populate the correct generation as part of its transaction, to protect us + // against racing with changes in the state of the parent. + generation: 0, + generation_pageserver: target.node.id.0 as i64, + placement_policy: serde_json::to_string(&policy).unwrap(), + // TODO: get the config out of the map + config: serde_json::to_string(&TenantConfig::default()).unwrap(), + splitting: SplitState::Splitting, + }); + } + + child_tsps.push((target.parent_id, this_child_tsps)); + } + + if let Err(e) = self + .persistence + .begin_shard_split(old_shard_count, tenant_id, child_tsps) + .await + { + match e { + DatabaseError::Query(diesel::result::Error::DatabaseError( + DatabaseErrorKind::UniqueViolation, + _, + )) => { + // Inserting a child shard violated a unique constraint: we raced with another call to + // this function + tracing::warn!("Conflicting attempt to split {tenant_id}: {e}"); + return Err(ApiError::Conflict("Tenant is already splitting".into())); + } + _ => return Err(ApiError::InternalServerError(e.into())), + } + } + + // FIXME: we have now committed the shard split state to the database, so any subsequent + // failure needs to roll it back. We will later wrap this function in logic to roll back + // the split if it fails. + // (https://github.com/neondatabase/neon/issues/6676) + + // TODO: issue split calls concurrently (this only matters once we're splitting + // N>1 shards into M shards -- initially we're usually splitting 1 shard into N). + + for target in &targets { + let SplitTarget { + parent_id, + node, + child_ids, + } = target; + let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref()); + let response = client + .tenant_shard_split( + *parent_id, + TenantShardSplitRequest { + new_shard_count: split_req.new_shard_count, + }, + ) + .await + .map_err(|e| ApiError::Conflict(format!("Failed to split {}: {}", parent_id, e)))?; + + tracing::info!( + "Split {} into {}", + parent_id, + response + .new_shards + .iter() + .map(|s| format!("{:?}", s)) + .collect::>() + .join(",") + ); + + if &response.new_shards != child_ids { + // This should never happen: the pageserver should agree with us on how shard splits work. + return Err(ApiError::InternalServerError(anyhow::anyhow!( + "Splitting shard {} resulted in unexpected IDs: {:?} (expected {:?})", + parent_id, + response.new_shards, + child_ids + ))); + } + } + + // TODO: if the pageserver restarted concurrently with our split API call, + // the actual generation of the child shard might differ from the generation + // we expect it to have. In order for our in-database generation to end up + // correct, we should carry the child generation back in the response and apply it here + // in complete_shard_split (and apply the correct generation in memory) + // (or, we can carry generation in the request and reject the request if + // it doesn't match, but that requires more retry logic on this side) + + self.persistence + .complete_shard_split(tenant_id, old_shard_count) + .await?; + + // Replace all the shards we just split with their children + let mut response = TenantShardSplitResponse { + new_shards: Vec::new(), + }; + let mut child_locations = Vec::new(); + { + let mut locked = self.inner.write().unwrap(); + for target in targets { + let SplitTarget { + parent_id, + node: _node, + child_ids, + } = target; + let (pageserver, generation, config) = { + let old_state = locked + .tenants + .remove(&parent_id) + .expect("It was present, we just split it"); + ( + old_state.intent.attached.unwrap(), + old_state.generation, + old_state.config.clone(), + ) + }; + + locked.tenants.remove(&parent_id); + + for child in child_ids { + let mut child_shard = shard_ident; + child_shard.number = child.shard_number; + child_shard.count = child.shard_count; + + let mut child_observed: HashMap = HashMap::new(); + child_observed.insert( + pageserver, + ObservedStateLocation { + conf: Some(attached_location_conf(generation, &child_shard, &config)), + }, + ); + + let mut child_state = TenantState::new(child, child_shard, policy.clone()); + child_state.intent = IntentState::single(Some(pageserver)); + child_state.observed = ObservedState { + locations: child_observed, + }; + child_state.generation = generation; + child_state.config = config.clone(); + + child_locations.push((child, pageserver)); + + locked.tenants.insert(child, child_state); + response.new_shards.push(child); + } + } + } + + // Send compute notifications for all the new shards + let mut failed_notifications = Vec::new(); + for (child_id, child_ps) in child_locations { + if let Err(e) = compute_hook.notify(child_id, child_ps, &cancel).await { + tracing::warn!("Failed to update compute of {}->{} during split, proceeding anyway to complete split ({e})", + child_id, child_ps); + failed_notifications.push(child_id); + } + } + + // If we failed any compute notifications, make a note to retry later. + if !failed_notifications.is_empty() { + let mut locked = self.inner.write().unwrap(); + for failed in failed_notifications { + if let Some(shard) = locked.tenants.get_mut(&failed) { + shard.pending_compute_notification = true; + } + } + } + + Ok(response) + } + pub(crate) async fn tenant_shard_migrate( &self, tenant_shard_id: TenantShardId, diff --git a/control_plane/attachment_service/src/tenant_state.rs b/control_plane/attachment_service/src/tenant_state.rs index a358e1ff7b..c0ab076a55 100644 --- a/control_plane/attachment_service/src/tenant_state.rs +++ b/control_plane/attachment_service/src/tenant_state.rs @@ -193,6 +193,13 @@ impl IntentState { result } + pub(crate) fn single(node_id: Option) -> Self { + Self { + attached: node_id, + secondary: vec![], + } + } + /// When a node goes offline, we update intents to avoid using it /// as their attached pageserver. /// @@ -286,6 +293,9 @@ impl TenantState { // self.intent refers to pageservers that are offline, and pick other // pageservers if so. + // TODO: respect the splitting bit on tenants: if they are currently splitting then we may not + // change their attach location. + // Build the set of pageservers already in use by this tenant, to avoid scheduling // more work on the same pageservers we're already using. let mut used_pageservers = self.intent.all_pageservers(); diff --git a/control_plane/src/attachment_service.rs b/control_plane/src/attachment_service.rs index a3f832036c..c3e071aa71 100644 --- a/control_plane/src/attachment_service.rs +++ b/control_plane/src/attachment_service.rs @@ -8,7 +8,10 @@ use diesel::{ use diesel_migrations::{HarnessWithOutput, MigrationHarness}; use hyper::Method; use pageserver_api::{ - models::{ShardParameters, TenantCreateRequest, TimelineCreateRequest, TimelineInfo}, + models::{ + ShardParameters, TenantCreateRequest, TenantShardSplitRequest, TenantShardSplitResponse, + TimelineCreateRequest, TimelineInfo, + }, shard::TenantShardId, }; use pageserver_client::mgmt_api::ResponseErrorMessageExt; @@ -648,7 +651,7 @@ impl AttachmentService { ) -> anyhow::Result { self.dispatch( Method::PUT, - format!("tenant/{tenant_shard_id}/migrate"), + format!("control/v1/tenant/{tenant_shard_id}/migrate"), Some(TenantShardMigrateRequest { tenant_shard_id, node_id, @@ -657,6 +660,20 @@ impl AttachmentService { .await } + #[instrument(skip(self), fields(%tenant_id, %new_shard_count))] + pub async fn tenant_split( + &self, + tenant_id: TenantId, + new_shard_count: u8, + ) -> anyhow::Result { + self.dispatch( + Method::PUT, + format!("control/v1/tenant/{tenant_id}/shard_split"), + Some(TenantShardSplitRequest { new_shard_count }), + ) + .await + } + #[instrument(skip_all, fields(node_id=%req.node_id))] pub async fn node_register(&self, req: NodeRegisterRequest) -> anyhow::Result<()> { self.dispatch::<_, ()>(Method::POST, "control/v1/node".to_string(), Some(req)) diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index e56007dd20..b9af467fdf 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -575,6 +575,26 @@ async fn handle_tenant( println!("{tenant_table}"); println!("{shard_table}"); } + Some(("shard-split", matches)) => { + let tenant_id = get_tenant_id(matches, env)?; + let shard_count: u8 = matches.get_one::("shard-count").cloned().unwrap_or(0); + + let attachment_service = AttachmentService::from_env(env); + let result = attachment_service + .tenant_split(tenant_id, shard_count) + .await?; + println!( + "Split tenant {} into shards {}", + tenant_id, + result + .new_shards + .iter() + .map(|s| format!("{:?}", s)) + .collect::>() + .join(",") + ); + } + Some((sub_name, _)) => bail!("Unexpected tenant subcommand '{}'", sub_name), None => bail!("no tenant subcommand provided"), } @@ -1524,6 +1544,11 @@ fn cli() -> Command { .subcommand(Command::new("status") .about("Human readable summary of the tenant's shards and attachment locations") .arg(tenant_id_arg.clone())) + .subcommand(Command::new("shard-split") + .about("Increase the number of shards in the tenant") + .arg(tenant_id_arg.clone()) + .arg(Arg::new("shard-count").value_parser(value_parser!(u8)).long("shard-count").action(ArgAction::Set).help("Number of shards in the new tenant (default 1)")) + ) ) .subcommand( Command::new("pageserver") diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index c08cacb822..46324efd43 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -192,6 +192,16 @@ pub struct TimelineCreateRequest { pub pg_version: Option, } +#[derive(Serialize, Deserialize)] +pub struct TenantShardSplitRequest { + pub new_shard_count: u8, +} + +#[derive(Serialize, Deserialize)] +pub struct TenantShardSplitResponse { + pub new_shards: Vec, +} + /// Parameters that apply to all shards in a tenant. Used during tenant creation. #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] diff --git a/libs/pageserver_api/src/shard.rs b/libs/pageserver_api/src/shard.rs index e27aad8156..322b6c642e 100644 --- a/libs/pageserver_api/src/shard.rs +++ b/libs/pageserver_api/src/shard.rs @@ -88,12 +88,36 @@ impl TenantShardId { pub fn is_unsharded(&self) -> bool { self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0) } + + /// Convenience for dropping the tenant_id and just getting the ShardIndex: this + /// is useful when logging from code that is already in a span that includes tenant ID, to + /// keep messages reasonably terse. pub fn to_index(&self) -> ShardIndex { ShardIndex { shard_number: self.shard_number, shard_count: self.shard_count, } } + + /// Calculate the children of this TenantShardId when splitting the overall tenant into + /// the given number of shards. + pub fn split(&self, new_shard_count: ShardCount) -> Vec { + let effective_old_shard_count = std::cmp::max(self.shard_count.0, 1); + let mut child_shards = Vec::new(); + for shard_number in 0..ShardNumber(new_shard_count.0).0 { + // Key mapping is based on a round robin mapping of key hash modulo shard count, + // so our child shards are the ones which the same keys would map to. + if shard_number % effective_old_shard_count == self.shard_number.0 { + child_shards.push(TenantShardId { + tenant_id: self.tenant_id, + shard_number: ShardNumber(shard_number), + shard_count: new_shard_count, + }) + } + } + + child_shards + } } /// Formatting helper @@ -793,4 +817,108 @@ mod tests { let shard = key_to_shard_number(ShardCount(10), DEFAULT_STRIPE_SIZE, &key); assert_eq!(shard, ShardNumber(8)); } + + #[test] + fn shard_id_split() { + let tenant_id = TenantId::generate(); + let parent = TenantShardId::unsharded(tenant_id); + + // Unsharded into 2 + assert_eq!( + parent.split(ShardCount(2)), + vec![ + TenantShardId { + tenant_id, + shard_count: ShardCount(2), + shard_number: ShardNumber(0) + }, + TenantShardId { + tenant_id, + shard_count: ShardCount(2), + shard_number: ShardNumber(1) + } + ] + ); + + // Unsharded into 4 + assert_eq!( + parent.split(ShardCount(4)), + vec![ + TenantShardId { + tenant_id, + shard_count: ShardCount(4), + shard_number: ShardNumber(0) + }, + TenantShardId { + tenant_id, + shard_count: ShardCount(4), + shard_number: ShardNumber(1) + }, + TenantShardId { + tenant_id, + shard_count: ShardCount(4), + shard_number: ShardNumber(2) + }, + TenantShardId { + tenant_id, + shard_count: ShardCount(4), + shard_number: ShardNumber(3) + } + ] + ); + + // count=1 into 2 (check this works the same as unsharded.) + let parent = TenantShardId { + tenant_id, + shard_count: ShardCount(1), + shard_number: ShardNumber(0), + }; + assert_eq!( + parent.split(ShardCount(2)), + vec![ + TenantShardId { + tenant_id, + shard_count: ShardCount(2), + shard_number: ShardNumber(0) + }, + TenantShardId { + tenant_id, + shard_count: ShardCount(2), + shard_number: ShardNumber(1) + } + ] + ); + + // count=2 into count=8 + let parent = TenantShardId { + tenant_id, + shard_count: ShardCount(2), + shard_number: ShardNumber(1), + }; + assert_eq!( + parent.split(ShardCount(8)), + vec![ + TenantShardId { + tenant_id, + shard_count: ShardCount(8), + shard_number: ShardNumber(1) + }, + TenantShardId { + tenant_id, + shard_count: ShardCount(8), + shard_number: ShardNumber(3) + }, + TenantShardId { + tenant_id, + shard_count: ShardCount(8), + shard_number: ShardNumber(5) + }, + TenantShardId { + tenant_id, + shard_count: ShardCount(8), + shard_number: ShardNumber(7) + }, + ] + ); + } } diff --git a/pageserver/client/src/mgmt_api.rs b/pageserver/client/src/mgmt_api.rs index 8abe58e1a2..200369df90 100644 --- a/pageserver/client/src/mgmt_api.rs +++ b/pageserver/client/src/mgmt_api.rs @@ -310,6 +310,22 @@ impl Client { .map_err(Error::ReceiveBody) } + pub async fn tenant_shard_split( + &self, + tenant_shard_id: TenantShardId, + req: TenantShardSplitRequest, + ) -> Result { + let uri = format!( + "{}/v1/tenant/{}/shard_split", + self.mgmt_api_endpoint, tenant_shard_id + ); + self.request(Method::PUT, &uri, req) + .await? + .json() + .await + .map_err(Error::ReceiveBody) + } + pub async fn timeline_list( &self, tenant_shard_id: &TenantShardId, diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index ebcb27fa08..af9a3c7301 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -19,11 +19,14 @@ use pageserver_api::models::ShardParameters; use pageserver_api::models::TenantDetails; use pageserver_api::models::TenantLocationConfigResponse; use pageserver_api::models::TenantShardLocation; +use pageserver_api::models::TenantShardSplitRequest; +use pageserver_api::models::TenantShardSplitResponse; use pageserver_api::models::TenantState; use pageserver_api::models::{ DownloadRemoteLayersTaskSpawnRequest, LocationConfigMode, TenantAttachRequest, TenantLoadRequest, TenantLocationConfigRequest, }; +use pageserver_api::shard::ShardCount; use pageserver_api::shard::TenantShardId; use remote_storage::GenericRemoteStorage; use remote_storage::TimeTravelError; @@ -875,7 +878,7 @@ async fn tenant_reset_handler( let state = get_state(&request); state .tenant_manager - .reset_tenant(tenant_shard_id, drop_cache.unwrap_or(false), ctx) + .reset_tenant(tenant_shard_id, drop_cache.unwrap_or(false), &ctx) .await .map_err(ApiError::InternalServerError)?; @@ -1104,6 +1107,25 @@ async fn tenant_size_handler( ) } +async fn tenant_shard_split_handler( + mut request: Request, + _cancel: CancellationToken, +) -> Result, ApiError> { + let req: TenantShardSplitRequest = json_request(&mut request).await?; + + let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?; + let state = get_state(&request); + let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Warn); + + let new_shards = state + .tenant_manager + .shard_split(tenant_shard_id, ShardCount(req.new_shard_count), &ctx) + .await + .map_err(ApiError::InternalServerError)?; + + json_response(StatusCode::OK, TenantShardSplitResponse { new_shards }) +} + async fn layer_map_info_handler( request: Request, _cancel: CancellationToken, @@ -2063,6 +2085,9 @@ pub fn make_router( .put("/v1/tenant/config", |r| { api_handler(r, update_tenant_config_handler) }) + .put("/v1/tenant/:tenant_shard_id/shard_split", |r| { + api_handler(r, tenant_shard_split_handler) + }) .get("/v1/tenant/:tenant_shard_id/config", |r| { api_handler(r, get_tenant_config_handler) }) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index f704f8c0dd..f086f46213 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -53,6 +53,7 @@ use self::metadata::TimelineMetadata; use self::mgr::GetActiveTenantError; use self::mgr::GetTenantError; use self::mgr::TenantsMap; +use self::remote_timeline_client::upload::upload_index_part; use self::remote_timeline_client::RemoteTimelineClient; use self::timeline::uninit::TimelineExclusionError; use self::timeline::uninit::TimelineUninitMark; @@ -2397,6 +2398,67 @@ impl Tenant { pub(crate) fn get_generation(&self) -> Generation { self.generation } + + /// This function partially shuts down the tenant (it shuts down the Timelines) and is fallible, + /// and can leave the tenant in a bad state if it fails. The caller is responsible for + /// resetting this tenant to a valid state if we fail. + pub(crate) async fn split_prepare( + &self, + child_shards: &Vec, + ) -> anyhow::Result<()> { + let timelines = self.timelines.lock().unwrap().clone(); + for timeline in timelines.values() { + let Some(tl_client) = &timeline.remote_client else { + anyhow::bail!("Remote storage is mandatory"); + }; + + let Some(remote_storage) = &self.remote_storage else { + anyhow::bail!("Remote storage is mandatory"); + }; + + // We do not block timeline creation/deletion during splits inside the pageserver: it is up to higher levels + // to ensure that they do not start a split if currently in the process of doing these. + + // Upload an index from the parent: this is partly to provide freshness for the + // child tenants that will copy it, and partly for general ease-of-debugging: there will + // always be a parent shard index in the same generation as we wrote the child shard index. + tl_client.schedule_index_upload_for_file_changes()?; + tl_client.wait_completion().await?; + + // Shut down the timeline's remote client: this means that the indices we write + // for child shards will not be invalidated by the parent shard deleting layers. + tl_client.shutdown().await?; + + // Download methods can still be used after shutdown, as they don't flow through the remote client's + // queue. In principal the RemoteTimelineClient could provide this without downloading it, but this + // operation is rare, so it's simpler to just download it (and robustly guarantees that the index + // we use here really is the remotely persistent one). + let result = tl_client + .download_index_file(self.cancel.clone()) + .instrument(info_span!("download_index_file", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), timeline_id=%timeline.timeline_id)) + .await?; + let index_part = match result { + MaybeDeletedIndexPart::Deleted(_) => { + anyhow::bail!("Timeline deletion happened concurrently with split") + } + MaybeDeletedIndexPart::IndexPart(p) => p, + }; + + for child_shard in child_shards { + upload_index_part( + remote_storage, + child_shard, + &timeline.timeline_id, + self.generation, + &index_part, + &self.cancel, + ) + .await?; + } + } + + Ok(()) + } } /// Given a Vec of timelines and their ancestors (timeline_id, ancestor_id), @@ -3732,6 +3794,10 @@ impl Tenant { Ok(()) } + + pub(crate) fn get_tenant_conf(&self) -> TenantConfOpt { + self.tenant_conf.read().unwrap().tenant_conf + } } fn remove_timeline_and_uninit_mark( diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 5ec910ca3e..9aee39bd35 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -2,6 +2,7 @@ //! page server. use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf}; +use itertools::Itertools; use pageserver_api::key::Key; use pageserver_api::models::ShardParameters; use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, TenantShardId}; @@ -22,7 +23,7 @@ use tokio_util::sync::CancellationToken; use tracing::*; use remote_storage::GenericRemoteStorage; -use utils::crashsafe; +use utils::{completion, crashsafe}; use crate::config::PageServerConf; use crate::context::{DownloadBehavior, RequestContext}; @@ -644,8 +645,6 @@ pub(crate) async fn shutdown_all_tenants() { } async fn shutdown_all_tenants0(tenants: &std::sync::RwLock) { - use utils::completion; - let mut join_set = JoinSet::new(); // Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants. @@ -1200,7 +1199,7 @@ impl TenantManager { &self, tenant_shard_id: TenantShardId, drop_cache: bool, - ctx: RequestContext, + ctx: &RequestContext, ) -> anyhow::Result<()> { let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; let Some(old_slot) = slot_guard.get_old_value() else { @@ -1253,7 +1252,7 @@ impl TenantManager { None, self.tenants, SpawnMode::Normal, - &ctx, + ctx, )?; slot_guard.upsert(TenantSlot::Attached(tenant))?; @@ -1375,6 +1374,164 @@ impl TenantManager { slot_guard.revert(); result } + + #[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), new_shard_count=%new_shard_count.0))] + pub(crate) async fn shard_split( + &self, + tenant_shard_id: TenantShardId, + new_shard_count: ShardCount, + ctx: &RequestContext, + ) -> anyhow::Result> { + let tenant = get_tenant(tenant_shard_id, true)?; + + // Plan: identify what the new child shards will be + let effective_old_shard_count = std::cmp::max(tenant_shard_id.shard_count.0, 1); + if new_shard_count <= ShardCount(effective_old_shard_count) { + anyhow::bail!("Requested shard count is not an increase"); + } + let expansion_factor = new_shard_count.0 / effective_old_shard_count; + if !expansion_factor.is_power_of_two() { + anyhow::bail!("Requested split is not a power of two"); + } + + let parent_shard_identity = tenant.shard_identity; + let parent_tenant_conf = tenant.get_tenant_conf(); + let parent_generation = tenant.generation; + + let child_shards = tenant_shard_id.split(new_shard_count); + tracing::info!( + "Shard {} splits into: {}", + tenant_shard_id.to_index(), + child_shards + .iter() + .map(|id| format!("{}", id.to_index())) + .join(",") + ); + + // Phase 1: Write out child shards' remote index files, in the parent tenant's current generation + if let Err(e) = tenant.split_prepare(&child_shards).await { + // If [`Tenant::split_prepare`] fails, we must reload the tenant, because it might + // have been left in a partially-shut-down state. + tracing::warn!("Failed to prepare for split: {e}, reloading Tenant before returning"); + self.reset_tenant(tenant_shard_id, false, ctx).await?; + return Err(e); + } + + self.resources.deletion_queue_client.flush_advisory(); + + // Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant + drop(tenant); + let mut parent_slot_guard = + tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; + let parent = match parent_slot_guard.get_old_value() { + Some(TenantSlot::Attached(t)) => t, + Some(TenantSlot::Secondary(_)) => anyhow::bail!("Tenant location in secondary mode"), + Some(TenantSlot::InProgress(_)) => { + // tenant_map_acquire_slot never returns InProgress, if a slot was InProgress + // it would return an error. + unreachable!() + } + None => { + // We don't actually need the parent shard to still be attached to do our work, but it's + // a weird enough situation that the caller probably didn't want us to continue working + // if they had detached the tenant they requested the split on. + anyhow::bail!("Detached parent shard in the middle of split!") + } + }; + + // TODO: hardlink layers from the parent into the child shard directories so that they don't immediately re-download + // TODO: erase the dentries from the parent + + // Take a snapshot of where the parent's WAL ingest had got to: we will wait for + // child shards to reach this point. + let mut target_lsns = HashMap::new(); + for timeline in parent.timelines.lock().unwrap().clone().values() { + target_lsns.insert(timeline.timeline_id, timeline.get_last_record_lsn()); + } + + // TODO: we should have the parent shard stop its WAL ingest here, it's a waste of resources + // and could slow down the children trying to catch up. + + // Phase 3: Spawn the child shards + for child_shard in &child_shards { + let mut child_shard_identity = parent_shard_identity; + child_shard_identity.count = child_shard.shard_count; + child_shard_identity.number = child_shard.shard_number; + + let child_location_conf = LocationConf { + mode: LocationMode::Attached(AttachedLocationConfig { + generation: parent_generation, + attach_mode: AttachmentMode::Single, + }), + shard: child_shard_identity, + tenant_conf: parent_tenant_conf, + }; + + self.upsert_location( + *child_shard, + child_location_conf, + None, + SpawnMode::Normal, + ctx, + ) + .await?; + } + + // Phase 4: wait for child chards WAL ingest to catch up to target LSN + for child_shard_id in &child_shards { + let child_shard = { + let locked = TENANTS.read().unwrap(); + let peek_slot = + tenant_map_peek_slot(&locked, child_shard_id, TenantSlotPeekMode::Read)?; + peek_slot.and_then(|s| s.get_attached()).cloned() + }; + if let Some(t) = child_shard { + let timelines = t.timelines.lock().unwrap().clone(); + for timeline in timelines.values() { + let Some(target_lsn) = target_lsns.get(&timeline.timeline_id) else { + continue; + }; + + tracing::info!( + "Waiting for child shard {}/{} to reach target lsn {}...", + child_shard_id, + timeline.timeline_id, + target_lsn + ); + if let Err(e) = timeline.wait_lsn(*target_lsn, ctx).await { + // Failure here might mean shutdown, in any case this part is an optimization + // and we shouldn't hold up the split operation. + tracing::warn!( + "Failed to wait for timeline {} to reach lsn {target_lsn}: {e}", + timeline.timeline_id + ); + } else { + tracing::info!( + "Child shard {}/{} reached target lsn {}", + child_shard_id, + timeline.timeline_id, + target_lsn + ); + } + } + } + } + + // Phase 5: Shut down the parent shard. + let (_guard, progress) = completion::channel(); + match parent.shutdown(progress, false).await { + Ok(()) => {} + Err(other) => { + other.wait().await; + } + } + parent_slot_guard.drop_old_value()?; + + // Phase 6: Release the InProgress on the parent shard + drop(parent_slot_guard); + + Ok(child_shards) + } } #[derive(Debug, thiserror::Error)] @@ -2209,8 +2366,6 @@ async fn remove_tenant_from_memory( where F: std::future::Future>, { - use utils::completion; - let mut slot_guard = tenant_map_acquire_slot_impl(&tenant_shard_id, tenants, TenantSlotAcquireMode::MustExist)?; diff --git a/pageserver/src/tenant/remote_timeline_client/upload.rs b/pageserver/src/tenant/remote_timeline_client/upload.rs index e8ba1d3d6e..c17e27b446 100644 --- a/pageserver/src/tenant/remote_timeline_client/upload.rs +++ b/pageserver/src/tenant/remote_timeline_client/upload.rs @@ -27,7 +27,7 @@ use super::index::LayerFileMetadata; use tracing::info; /// Serializes and uploads the given index part data to the remote storage. -pub(super) async fn upload_index_part<'a>( +pub(crate) async fn upload_index_part<'a>( storage: &'a GenericRemoteStorage, tenant_shard_id: &TenantShardId, timeline_id: &TimelineId, diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 4491655aeb..3d2549a8c3 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -4054,7 +4054,7 @@ def logical_replication_sync(subscriber: VanillaPostgres, publisher: Endpoint) - def tenant_get_shards( - env: NeonEnv, tenant_id: TenantId, pageserver_id: Optional[int] + env: NeonEnv, tenant_id: TenantId, pageserver_id: Optional[int] = None ) -> list[tuple[TenantShardId, NeonPageserver]]: """ Helper for when you want to talk to one or more pageservers, and the diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index c16bfc2ec6..805eaa34b0 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -1,6 +1,7 @@ from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, + tenant_get_shards, ) from fixtures.remote_storage import s3_storage from fixtures.types import TimelineId @@ -82,4 +83,130 @@ def test_sharding_smoke( ) assert timelines == {env.initial_timeline, timeline_b} - # TODO: test timeline deletion and tenant deletion (depends on change in attachment_service) + +def test_sharding_split_smoke( + neon_env_builder: NeonEnvBuilder, +): + """ + Test the basics of shard splitting: + - The API results in more shards than we started with + - The tenant's data remains readable + + """ + + # We will start with 4 shards and split into 8, then migrate all those + # 8 shards onto separate pageservers + shard_count = 4 + split_shard_count = 8 + neon_env_builder.num_pageservers = split_shard_count + + # 1MiB stripes: enable getting some meaningful data distribution without + # writing large quantities of data in this test. The stripe size is given + # in number of 8KiB pages. + stripe_size = 128 + + # Use S3-compatible remote storage so that we can scrub: this test validates + # that the scrubber doesn't barf when it sees a sharded tenant. + neon_env_builder.enable_pageserver_remote_storage(s3_storage()) + neon_env_builder.enable_scrub_on_exit() + + neon_env_builder.preserve_database_files = True + + env = neon_env_builder.init_start( + initial_tenant_shard_count=shard_count, initial_tenant_shard_stripe_size=stripe_size + ) + tenant_id = env.initial_tenant + timeline_id = env.initial_timeline + workload = Workload(env, tenant_id, timeline_id, branch_name="main") + workload.init() + + # Initial data + workload.write_rows(256) + + # Note which pageservers initially hold a shard after tenant creation + pre_split_pageserver_ids = [loc["node_id"] for loc in env.attachment_service.locate(tenant_id)] + + # For pageservers holding a shard, validate their ingest statistics + # reflect a proper splitting of the WAL. + for pageserver in env.pageservers: + if pageserver.id not in pre_split_pageserver_ids: + continue + + metrics = pageserver.http_client().get_metrics_values( + [ + "pageserver_wal_ingest_records_received_total", + "pageserver_wal_ingest_records_committed_total", + "pageserver_wal_ingest_records_filtered_total", + ] + ) + + log.info(f"Pageserver {pageserver.id} metrics: {metrics}") + + # Not everything received was committed + assert ( + metrics["pageserver_wal_ingest_records_received_total"] + > metrics["pageserver_wal_ingest_records_committed_total"] + ) + + # Something was committed + assert metrics["pageserver_wal_ingest_records_committed_total"] > 0 + + # Counts are self consistent + assert ( + metrics["pageserver_wal_ingest_records_received_total"] + == metrics["pageserver_wal_ingest_records_committed_total"] + + metrics["pageserver_wal_ingest_records_filtered_total"] + ) + + # TODO: validate that shards have different sizes + + workload.validate() + + assert len(pre_split_pageserver_ids) == 4 + + env.attachment_service.tenant_shard_split(tenant_id, shard_count=split_shard_count) + + post_split_pageserver_ids = [loc["node_id"] for loc in env.attachment_service.locate(tenant_id)] + # We should have split into 8 shards, on the same 4 pageservers we started on. + assert len(post_split_pageserver_ids) == split_shard_count + assert len(set(post_split_pageserver_ids)) == shard_count + assert set(post_split_pageserver_ids) == set(pre_split_pageserver_ids) + + workload.validate() + + workload.churn_rows(256) + + workload.validate() + + # Run GC on all new shards, to check they don't barf or delete anything that breaks reads + # (compaction was already run as part of churn_rows) + all_shards = tenant_get_shards(env, tenant_id) + for tenant_shard_id, pageserver in all_shards: + pageserver.http_client().timeline_gc(tenant_shard_id, timeline_id, None) + + # Restart all nodes, to check that the newly created shards are durable + for ps in env.pageservers: + ps.restart() + + workload.validate() + + migrate_to_pageserver_ids = list( + set(p.id for p in env.pageservers) - set(pre_split_pageserver_ids) + ) + assert len(migrate_to_pageserver_ids) == split_shard_count - shard_count + + # Migrate shards away from the node where the split happened + for ps_id in pre_split_pageserver_ids: + shards_here = [ + tenant_shard_id + for (tenant_shard_id, pageserver) in all_shards + if pageserver.id == ps_id + ] + assert len(shards_here) == 2 + migrate_shard = shards_here[0] + destination = migrate_to_pageserver_ids.pop() + + log.info(f"Migrating shard {migrate_shard} from {ps_id} to {destination}") + env.neon_cli.tenant_migrate(migrate_shard, destination, timeout_secs=10) + + workload.validate() From e8d2843df63ba05cd74baa8017736a903f9a322a Mon Sep 17 00:00:00 2001 From: John Spray Date: Thu, 8 Feb 2024 18:00:53 +0000 Subject: [PATCH 18/81] storage controller: improved handling of node availability on restart (#6658) - Automatically set a node's availability to Active if it is responsive in startup_reconcile - Impose a 5s timeout of HTTP request to list location conf, so that an unresponsive node can't hang it for minutes - Do several retries if the request fails with a retryable error, to be tolerant of concurrent pageserver & storage controller restarts - Add a readiness hook for use with k8s so that we can tell when the startup reconciliaton is done and the service is fully ready to do work. - Add /metrics to the list of un-authenticated endpoints (this is unrelated but we're touching the line in this PR already, and it fixes auth error spam in deployed container.) - A test for the above. Closes: #6670 --- control_plane/attachment_service/src/http.rs | 14 ++- .../attachment_service/src/service.rs | 107 +++++++++++++----- libs/utils/src/completion.rs | 5 + pageserver/client/src/mgmt_api.rs | 10 +- test_runner/fixtures/neon_fixtures.py | 9 ++ test_runner/regress/test_sharding_service.py | 32 ++++++ 6 files changed, 149 insertions(+), 28 deletions(-) diff --git a/control_plane/attachment_service/src/http.rs b/control_plane/attachment_service/src/http.rs index 38eecaf7ef..8501e4980f 100644 --- a/control_plane/attachment_service/src/http.rs +++ b/control_plane/attachment_service/src/http.rs @@ -42,7 +42,7 @@ pub struct HttpState { impl HttpState { pub fn new(service: Arc, auth: Option>) -> Self { - let allowlist_routes = ["/status"] + let allowlist_routes = ["/status", "/ready", "/metrics"] .iter() .map(|v| v.parse().unwrap()) .collect::>(); @@ -325,6 +325,17 @@ async fn handle_status(_req: Request) -> Result, ApiError> json_response(StatusCode::OK, ()) } +/// Readiness endpoint indicates when we're done doing startup I/O (e.g. reconciling +/// with remote pageserver nodes). This is intended for use as a kubernetes readiness probe. +async fn handle_ready(req: Request) -> Result, ApiError> { + let state = get_state(&req); + if state.service.startup_complete.is_ready() { + json_response(StatusCode::OK, ()) + } else { + json_response(StatusCode::SERVICE_UNAVAILABLE, ()) + } +} + impl From for ApiError { fn from(value: ReconcileError) -> Self { ApiError::Conflict(format!("Reconciliation error: {}", value)) @@ -380,6 +391,7 @@ pub fn make_router( .data(Arc::new(HttpState::new(service, auth))) // Non-prefixed generic endpoints (status, metrics) .get("/status", |r| request_span(r, handle_status)) + .get("/ready", |r| request_span(r, handle_ready)) // Upcalls for the pageserver: point the pageserver's `control_plane_api` config to this prefix .post("/upcall/v1/re-attach", |r| { request_span(r, handle_re_attach) diff --git a/control_plane/attachment_service/src/service.rs b/control_plane/attachment_service/src/service.rs index 0ec2b9dc4c..0331087e0d 100644 --- a/control_plane/attachment_service/src/service.rs +++ b/control_plane/attachment_service/src/service.rs @@ -1,6 +1,6 @@ use std::{ cmp::Ordering, - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, HashMap, HashSet}, str::FromStr, sync::Arc, time::{Duration, Instant}, @@ -31,6 +31,7 @@ use pageserver_api::{ use pageserver_client::mgmt_api; use tokio_util::sync::CancellationToken; use utils::{ + backoff, completion::Barrier, generation::Generation, http::error::ApiError, @@ -150,31 +151,71 @@ impl Service { // indeterminate, same as in [`ObservedStateLocation`]) let mut observed = HashMap::new(); - let nodes = { - let locked = self.inner.read().unwrap(); - locked.nodes.clone() - }; + let mut nodes_online = HashSet::new(); + + // TODO: give Service a cancellation token for clean shutdown + let cancel = CancellationToken::new(); // TODO: issue these requests concurrently - for node in nodes.values() { - let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref()); + { + let nodes = { + let locked = self.inner.read().unwrap(); + locked.nodes.clone() + }; + for node in nodes.values() { + let http_client = reqwest::ClientBuilder::new() + .timeout(Duration::from_secs(5)) + .build() + .expect("Failed to construct HTTP client"); + let client = mgmt_api::Client::from_client( + http_client, + node.base_url(), + self.config.jwt_token.as_deref(), + ); - tracing::info!("Scanning shards on node {}...", node.id); - match client.list_location_config().await { - Err(e) => { - tracing::warn!("Could not contact pageserver {} ({e})", node.id); - // TODO: be more tolerant, apply a generous 5-10 second timeout with retries, in case - // pageserver is being restarted at the same time as we are + fn is_fatal(e: &mgmt_api::Error) -> bool { + use mgmt_api::Error::*; + match e { + ReceiveBody(_) | ReceiveErrorBody(_) => false, + ApiError(StatusCode::SERVICE_UNAVAILABLE, _) + | ApiError(StatusCode::GATEWAY_TIMEOUT, _) + | ApiError(StatusCode::REQUEST_TIMEOUT, _) => false, + ApiError(_, _) => true, + } } - Ok(listing) => { - tracing::info!( - "Received {} shard statuses from pageserver {}, setting it to Active", - listing.tenant_shards.len(), - node.id - ); - for (tenant_shard_id, conf_opt) in listing.tenant_shards { - observed.insert(tenant_shard_id, (node.id, conf_opt)); + let list_response = backoff::retry( + || client.list_location_config(), + is_fatal, + 1, + 5, + "Location config listing", + &cancel, + ) + .await; + let Some(list_response) = list_response else { + tracing::info!("Shutdown during startup_reconcile"); + return; + }; + + tracing::info!("Scanning shards on node {}...", node.id); + match list_response { + Err(e) => { + tracing::warn!("Could not contact pageserver {} ({e})", node.id); + // TODO: be more tolerant, do some retries, in case + // pageserver is being restarted at the same time as we are + } + Ok(listing) => { + tracing::info!( + "Received {} shard statuses from pageserver {}, setting it to Active", + listing.tenant_shards.len(), + node.id + ); + nodes_online.insert(node.id); + + for (tenant_shard_id, conf_opt) in listing.tenant_shards { + observed.insert(tenant_shard_id, (node.id, conf_opt)); + } } } } @@ -185,8 +226,19 @@ impl Service { let mut compute_notifications = Vec::new(); // Populate intent and observed states for all tenants, based on reported state on pageservers - let shard_count = { + let (shard_count, nodes) = { let mut locked = self.inner.write().unwrap(); + + // Mark nodes online if they responded to us: nodes are offline by default after a restart. + let mut nodes = (*locked.nodes).clone(); + for (node_id, node) in nodes.iter_mut() { + if nodes_online.contains(node_id) { + node.availability = NodeAvailability::Active; + } + } + locked.nodes = Arc::new(nodes); + let nodes = locked.nodes.clone(); + for (tenant_shard_id, (node_id, observed_loc)) in observed { let Some(tenant_state) = locked.tenants.get_mut(&tenant_shard_id) else { cleanup.push((tenant_shard_id, node_id)); @@ -218,7 +270,7 @@ impl Service { } } - locked.tenants.len() + (locked.tenants.len(), nodes) }; // TODO: if any tenant's intent now differs from its loaded generation_pageserver, we should clear that @@ -279,9 +331,8 @@ impl Service { let stream = futures::stream::iter(compute_notifications.into_iter()) .map(|(tenant_shard_id, node_id)| { let compute_hook = compute_hook.clone(); + let cancel = cancel.clone(); async move { - // TODO: give Service a cancellation token for clean shutdown - let cancel = CancellationToken::new(); if let Err(e) = compute_hook.notify(tenant_shard_id, node_id, &cancel).await { tracing::error!( tenant_shard_id=%tenant_shard_id, @@ -387,7 +438,7 @@ impl Service { ))), config, persistence, - startup_complete, + startup_complete: startup_complete.clone(), }); let result_task_this = this.clone(); @@ -984,6 +1035,10 @@ impl Service { } }; + // TODO: if we timeout/fail on reconcile, we should still succeed this request, + // because otherwise a broken compute hook causes a feedback loop where + // location_config returns 500 and gets retried forever. + if let Some(create_req) = maybe_create { let create_resp = self.tenant_create(create_req).await?; result.shards = create_resp diff --git a/libs/utils/src/completion.rs b/libs/utils/src/completion.rs index ca6827c9b8..ea05cf54b1 100644 --- a/libs/utils/src/completion.rs +++ b/libs/utils/src/completion.rs @@ -27,6 +27,11 @@ impl Barrier { b.wait().await } } + + /// Return true if a call to wait() would complete immediately + pub fn is_ready(&self) -> bool { + futures::future::FutureExt::now_or_never(self.0.wait()).is_some() + } } impl PartialEq for Barrier { diff --git a/pageserver/client/src/mgmt_api.rs b/pageserver/client/src/mgmt_api.rs index 200369df90..baea747d3c 100644 --- a/pageserver/client/src/mgmt_api.rs +++ b/pageserver/client/src/mgmt_api.rs @@ -56,10 +56,18 @@ pub enum ForceAwaitLogicalSize { impl Client { pub fn new(mgmt_api_endpoint: String, jwt: Option<&str>) -> Self { + Self::from_client(reqwest::Client::new(), mgmt_api_endpoint, jwt) + } + + pub fn from_client( + client: reqwest::Client, + mgmt_api_endpoint: String, + jwt: Option<&str>, + ) -> Self { Self { mgmt_api_endpoint, authorization_header: jwt.map(|jwt| format!("Bearer {jwt}")), - client: reqwest::Client::new(), + client, } } diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 3d2549a8c3..0af8098cad 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1949,6 +1949,15 @@ class NeonAttachmentService: return headers + def ready(self) -> bool: + resp = self.request("GET", f"{self.env.attachment_service_api}/ready") + if resp.status_code == 503: + return False + elif resp.status_code == 200: + return True + else: + raise RuntimeError(f"Unexpected status {resp.status_code} from readiness endpoint") + def attach_hook_issue( self, tenant_shard_id: Union[TenantId, TenantShardId], pageserver_id: int ) -> int: diff --git a/test_runner/regress/test_sharding_service.py b/test_runner/regress/test_sharding_service.py index fd811a9d02..babb0d261c 100644 --- a/test_runner/regress/test_sharding_service.py +++ b/test_runner/regress/test_sharding_service.py @@ -128,6 +128,38 @@ def test_sharding_service_smoke( assert counts[env.pageservers[2].id] == tenant_shard_count // 2 +def test_node_status_after_restart( + neon_env_builder: NeonEnvBuilder, +): + neon_env_builder.num_pageservers = 2 + env = neon_env_builder.init_start() + + # Initially we have two online pageservers + nodes = env.attachment_service.node_list() + assert len(nodes) == 2 + + env.pageservers[1].stop() + + env.attachment_service.stop() + env.attachment_service.start() + + # Initially readiness check should fail because we're trying to connect to the offline node + assert env.attachment_service.ready() is False + + def is_ready(): + assert env.attachment_service.ready() is True + + wait_until(30, 1, is_ready) + + # We loaded nodes from database on restart + nodes = env.attachment_service.node_list() + assert len(nodes) == 2 + + # We should still be able to create a tenant, because the pageserver which is still online + # should have had its availabilty state set to Active. + env.attachment_service.tenant_create(TenantId.generate()) + + def test_sharding_service_passthrough( neon_env_builder: NeonEnvBuilder, ): From c0e0fc8151f2c00d45ebb8e39ef3c271c65a38f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Thu, 8 Feb 2024 19:57:02 +0100 Subject: [PATCH 19/81] Update Rust to 1.76.0 (#6683) [Release notes](https://github.com/rust-lang/rust/releases/tag/1.75.0). --- Dockerfile.buildtools | 2 +- compute_tools/src/pg_helpers.rs | 5 +++-- control_plane/src/background_process.rs | 1 - rust-toolchain.toml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Dockerfile.buildtools b/Dockerfile.buildtools index 220e995d64..3a452fec32 100644 --- a/Dockerfile.buildtools +++ b/Dockerfile.buildtools @@ -135,7 +135,7 @@ WORKDIR /home/nonroot # Rust # Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`) -ENV RUSTC_VERSION=1.75.0 +ENV RUSTC_VERSION=1.76.0 ENV RUSTUP_HOME="/home/nonroot/.rustup" ENV PATH="/home/nonroot/.cargo/bin:${PATH}" RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \ diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs index ce704385c6..5deb50d6b7 100644 --- a/compute_tools/src/pg_helpers.rs +++ b/compute_tools/src/pg_helpers.rs @@ -264,9 +264,10 @@ pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> { // case we miss some events for some reason. Not strictly necessary, but // better safe than sorry. let (tx, rx) = std::sync::mpsc::channel(); - let (mut watcher, rx): (Box, _) = match notify::recommended_watcher(move |res| { + let watcher_res = notify::recommended_watcher(move |res| { let _ = tx.send(res); - }) { + }); + let (mut watcher, rx): (Box, _) = match watcher_res { Ok(watcher) => (Box::new(watcher), rx), Err(e) => { match e.kind { diff --git a/control_plane/src/background_process.rs b/control_plane/src/background_process.rs index 364cc01c39..0e59b28230 100644 --- a/control_plane/src/background_process.rs +++ b/control_plane/src/background_process.rs @@ -72,7 +72,6 @@ where let log_path = datadir.join(format!("{process_name}.log")); let process_log_file = fs::OpenOptions::new() .create(true) - .write(true) .append(true) .open(&log_path) .with_context(|| { diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 9b5a965f7d..b0949c32b1 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.75.0" +channel = "1.76.0" profile = "default" # The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy. # https://rust-lang.github.io/rustup/concepts/profiles.html From 9a31311990d19eb607e087e0e12d4369bfab8b6c Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Thu, 8 Feb 2024 22:40:14 +0200 Subject: [PATCH 20/81] fix(heavier_once_cell): assertion failure can be hit (#6652) @problame noticed that the `tokio::sync::AcquireError` branch assertion can be hit like in the first commit. We haven't seen this yet in production, but I'd prefer not to see it there. There `take_and_deinit` is being used, but this race must be quite timing sensitive. --- libs/utils/src/sync/heavier_once_cell.rs | 241 +++++++++++++++++------ 1 file changed, 176 insertions(+), 65 deletions(-) diff --git a/libs/utils/src/sync/heavier_once_cell.rs b/libs/utils/src/sync/heavier_once_cell.rs index f733d107f1..81625b907e 100644 --- a/libs/utils/src/sync/heavier_once_cell.rs +++ b/libs/utils/src/sync/heavier_once_cell.rs @@ -69,37 +69,44 @@ impl OnceCell { F: FnOnce(InitPermit) -> Fut, Fut: std::future::Future>, { - let sem = { + loop { + let sem = { + let guard = self.inner.write().await; + if guard.value.is_some() { + return Ok(GuardMut(guard)); + } + guard.init_semaphore.clone() + }; + + { + let permit = { + // increment the count for the duration of queued + let _guard = CountWaitingInitializers::start(self); + sem.acquire().await + }; + + let Ok(permit) = permit else { + let guard = self.inner.write().await; + if !Arc::ptr_eq(&sem, &guard.init_semaphore) { + // there was a take_and_deinit in between + continue; + } + assert!( + guard.value.is_some(), + "semaphore got closed, must be initialized" + ); + return Ok(GuardMut(guard)); + }; + + permit.forget(); + } + + let permit = InitPermit(sem); + let (value, _permit) = factory(permit).await?; + let guard = self.inner.write().await; - if guard.value.is_some() { - return Ok(GuardMut(guard)); - } - guard.init_semaphore.clone() - }; - let permit = { - // increment the count for the duration of queued - let _guard = CountWaitingInitializers::start(self); - sem.acquire_owned().await - }; - - match permit { - Ok(permit) => { - let permit = InitPermit(permit); - let (value, _permit) = factory(permit).await?; - - let guard = self.inner.write().await; - - Ok(Self::set0(value, guard)) - } - Err(_closed) => { - let guard = self.inner.write().await; - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(GuardMut(guard)); - } + return Ok(Self::set0(value, guard)); } } @@ -112,37 +119,44 @@ impl OnceCell { F: FnOnce(InitPermit) -> Fut, Fut: std::future::Future>, { - let sem = { - let guard = self.inner.read().await; - if guard.value.is_some() { - return Ok(GuardRef(guard)); - } - guard.init_semaphore.clone() - }; - - let permit = { - // increment the count for the duration of queued - let _guard = CountWaitingInitializers::start(self); - sem.acquire_owned().await - }; - - match permit { - Ok(permit) => { - let permit = InitPermit(permit); - let (value, _permit) = factory(permit).await?; - - let guard = self.inner.write().await; - - Ok(Self::set0(value, guard).downgrade()) - } - Err(_closed) => { + loop { + let sem = { let guard = self.inner.read().await; - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(GuardRef(guard)); + if guard.value.is_some() { + return Ok(GuardRef(guard)); + } + guard.init_semaphore.clone() + }; + + { + let permit = { + // increment the count for the duration of queued + let _guard = CountWaitingInitializers::start(self); + sem.acquire().await + }; + + let Ok(permit) = permit else { + let guard = self.inner.read().await; + if !Arc::ptr_eq(&sem, &guard.init_semaphore) { + // there was a take_and_deinit in between + continue; + } + assert!( + guard.value.is_some(), + "semaphore got closed, must be initialized" + ); + return Ok(GuardRef(guard)); + }; + + permit.forget(); } + + let permit = InitPermit(sem); + let (value, _permit) = factory(permit).await?; + + let guard = self.inner.write().await; + + return Ok(Self::set0(value, guard).downgrade()); } } @@ -250,15 +264,12 @@ impl<'a, T> GuardMut<'a, T> { /// [`OnceCell::get_or_init`] will wait on it to complete. pub fn take_and_deinit(&mut self) -> (T, InitPermit) { let mut swapped = Inner::default(); - let permit = swapped - .init_semaphore - .clone() - .try_acquire_owned() - .expect("we just created this"); + let sem = swapped.init_semaphore.clone(); + sem.try_acquire().expect("we just created this").forget(); std::mem::swap(&mut *self.0, &mut swapped); swapped .value - .map(|v| (v, InitPermit(permit))) + .map(|v| (v, InitPermit(sem))) .expect("guard is not created unless value has been initialized") } @@ -282,13 +293,23 @@ impl std::ops::Deref for GuardRef<'_, T> { } /// Type held by OnceCell (de)initializing task. -pub struct InitPermit(tokio::sync::OwnedSemaphorePermit); +pub struct InitPermit(Arc); + +impl Drop for InitPermit { + fn drop(&mut self) { + debug_assert_eq!(self.0.available_permits(), 0); + self.0.add_permits(1); + } +} #[cfg(test)] mod tests { + use futures::Future; + use super::*; use std::{ convert::Infallible, + pin::{pin, Pin}, sync::atomic::{AtomicUsize, Ordering}, time::Duration, }; @@ -455,4 +476,94 @@ mod tests { .unwrap(); assert_eq!(*g, "now initialized"); } + + #[tokio::test(start_paused = true)] + async fn reproduce_init_take_deinit_race() { + init_take_deinit_scenario(|cell, factory| { + Box::pin(async { + cell.get_or_init(factory).await.unwrap(); + }) + }) + .await; + } + + #[tokio::test(start_paused = true)] + async fn reproduce_init_take_deinit_race_mut() { + init_take_deinit_scenario(|cell, factory| { + Box::pin(async { + cell.get_mut_or_init(factory).await.unwrap(); + }) + }) + .await; + } + + type BoxedInitFuture = Pin>>>; + type BoxedInitFunction = Box BoxedInitFuture>; + + /// Reproduce an assertion failure with both initialization methods. + /// + /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`. + /// Alternative would be a macro_rules! but that is the last resort. + async fn init_take_deinit_scenario(init_way: F) + where + F: for<'a> Fn( + &'a OnceCell<&'static str>, + BoxedInitFunction<&'static str, Infallible>, + ) -> Pin + 'a>>, + { + let cell = OnceCell::default(); + + // acquire the init_semaphore only permit to drive initializing tasks in order to waiting + // on the same semaphore. + let permit = cell + .inner + .read() + .await + .init_semaphore + .clone() + .try_acquire_owned() + .unwrap(); + + let mut t1 = pin!(init_way( + &cell, + Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })), + )); + + let mut t2 = pin!(init_way( + &cell, + Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })), + )); + + // drive t2 first to the init_semaphore + tokio::select! { + _ = &mut t2 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // followed by t1 in the init_semaphore + tokio::select! { + _ = &mut t1 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // now let t2 proceed and initialize + drop(permit); + t2.await; + + let (s, permit) = { cell.get_mut().await.unwrap().take_and_deinit() }; + assert_eq!("t2", s); + + // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from + // the new one. + tokio::select! { + _ = &mut t1 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // only now we get to initialize it + drop(permit); + t1.await; + + assert_eq!("t1", *cell.get().await.unwrap()); + } } From c09993396ea026758bfda83c477361d656a5b647 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Fri, 9 Feb 2024 00:37:57 +0200 Subject: [PATCH 21/81] fix: secondary tenant relative order eviction (#6491) Calculate the `relative_last_activity` using the total evicted and resident layers similar to what we originally planned. Cc: #5331 --- pageserver/src/disk_usage_eviction_task.rs | 73 +++++++++++++------ pageserver/src/tenant/secondary.rs | 2 +- pageserver/src/tenant/secondary/downloader.rs | 27 ++++--- 3 files changed, 67 insertions(+), 35 deletions(-) diff --git a/pageserver/src/disk_usage_eviction_task.rs b/pageserver/src/disk_usage_eviction_task.rs index 1f0525b045..d5f5a20683 100644 --- a/pageserver/src/disk_usage_eviction_task.rs +++ b/pageserver/src/disk_usage_eviction_task.rs @@ -623,6 +623,7 @@ impl std::fmt::Display for EvictionLayer { } } +#[derive(Default)] pub(crate) struct DiskUsageEvictionInfo { /// Timeline's largest layer (remote or resident) pub max_layer_size: Option, @@ -854,19 +855,27 @@ async fn collect_eviction_candidates( let total = tenant_candidates.len(); - for (i, mut candidate) in tenant_candidates.into_iter().enumerate() { - // as we iterate this reverse sorted list, the most recently accessed layer will always - // be 1.0; this is for us to evict it last. - candidate.relative_last_activity = eviction_order.relative_last_activity(total, i); + let tenant_candidates = + tenant_candidates + .into_iter() + .enumerate() + .map(|(i, mut candidate)| { + // as we iterate this reverse sorted list, the most recently accessed layer will always + // be 1.0; this is for us to evict it last. + candidate.relative_last_activity = + eviction_order.relative_last_activity(total, i); - let partition = if cumsum > min_resident_size as i128 { - MinResidentSizePartition::Above - } else { - MinResidentSizePartition::Below - }; - cumsum += i128::from(candidate.layer.get_file_size()); - candidates.push((partition, candidate)); - } + let partition = if cumsum > min_resident_size as i128 { + MinResidentSizePartition::Above + } else { + MinResidentSizePartition::Below + }; + cumsum += i128::from(candidate.layer.get_file_size()); + + (partition, candidate) + }); + + candidates.extend(tenant_candidates); } // Note: the same tenant ID might be hit twice, if it transitions from attached to @@ -882,21 +891,41 @@ async fn collect_eviction_candidates( ); for secondary_tenant in secondary_tenants { - let mut layer_info = secondary_tenant.get_layers_for_eviction(); + // for secondary tenants we use a sum of on_disk layers and already evicted layers. this is + // to prevent repeated disk usage based evictions from completely draining less often + // updating secondaries. + let (mut layer_info, total_layers) = secondary_tenant.get_layers_for_eviction(); + + debug_assert!( + total_layers >= layer_info.resident_layers.len(), + "total_layers ({total_layers}) must be at least the resident_layers.len() ({})", + layer_info.resident_layers.len() + ); layer_info .resident_layers .sort_unstable_by_key(|layer_info| std::cmp::Reverse(layer_info.last_activity_ts)); - candidates.extend(layer_info.resident_layers.into_iter().map(|candidate| { - ( - // Secondary locations' layers are always considered above the min resident size, - // i.e. secondary locations are permitted to be trimmed to zero layers if all - // the layers have sufficiently old access times. - MinResidentSizePartition::Above, - candidate, - ) - })); + let tenant_candidates = + layer_info + .resident_layers + .into_iter() + .enumerate() + .map(|(i, mut candidate)| { + candidate.relative_last_activity = + eviction_order.relative_last_activity(total_layers, i); + ( + // Secondary locations' layers are always considered above the min resident size, + // i.e. secondary locations are permitted to be trimmed to zero layers if all + // the layers have sufficiently old access times. + MinResidentSizePartition::Above, + candidate, + ) + }); + + candidates.extend(tenant_candidates); + + tokio::task::yield_now().await; } debug_assert!(MinResidentSizePartition::Above < MinResidentSizePartition::Below, diff --git a/pageserver/src/tenant/secondary.rs b/pageserver/src/tenant/secondary.rs index 4269e1dec1..926cd0302b 100644 --- a/pageserver/src/tenant/secondary.rs +++ b/pageserver/src/tenant/secondary.rs @@ -160,7 +160,7 @@ impl SecondaryTenant { &self.tenant_shard_id } - pub(crate) fn get_layers_for_eviction(self: &Arc) -> DiskUsageEvictionInfo { + pub(crate) fn get_layers_for_eviction(self: &Arc) -> (DiskUsageEvictionInfo, usize) { self.detail.lock().unwrap().get_layers_for_eviction(self) } diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 55af4f9f2b..9330edf946 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -146,14 +146,15 @@ impl SecondaryDetail { } } + /// Additionally returns the total number of layers, used for more stable relative access time + /// based eviction. pub(super) fn get_layers_for_eviction( &self, parent: &Arc, - ) -> DiskUsageEvictionInfo { - let mut result = DiskUsageEvictionInfo { - max_layer_size: None, - resident_layers: Vec::new(), - }; + ) -> (DiskUsageEvictionInfo, usize) { + let mut result = DiskUsageEvictionInfo::default(); + let mut total_layers = 0; + for (timeline_id, timeline_detail) in &self.timelines { result .resident_layers @@ -169,6 +170,10 @@ impl SecondaryDetail { relative_last_activity: finite_f32::FiniteF32::ZERO, } })); + + // total might be missing currently downloading layers, but as a lower than actual + // value it is good enough approximation. + total_layers += timeline_detail.on_disk_layers.len() + timeline_detail.evicted_at.len(); } result.max_layer_size = result .resident_layers @@ -183,7 +188,7 @@ impl SecondaryDetail { result.resident_layers.len() ); - result + (result, total_layers) } } @@ -312,9 +317,7 @@ impl JobGenerator Date: Fri, 9 Feb 2024 08:14:41 +0200 Subject: [PATCH 22/81] Increment generation which LFC is disabled by assigning 0 to neon.file_cache_size_limit (#6692) ## Problem test_lfc_resize sometimes filed with assertion failure when require lock in write operation: ``` if (lfc_ctl->generation == generation) { Assert(LFC_ENABLED()); ``` ## Summary of changes Increment generation when 0 is assigned to neon.file_cache_size_limit ## Checklist before requesting a review - [ ] I have performed a self-review of my code. - [ ] If it is a core feature, I have added thorough tests. - [ ] Do we need to implement analytics? if so did you add the relevant metrics to the dashboard? - [ ] If this PR requires public announcement, mark it with /release-notes label and add several sentences in this section. ## Checklist before merging - [ ] Do not forget to reformat commit message to not include the above checklist Co-authored-by: Konstantin Knizhnik --- pgxn/neon/file_cache.c | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 21db666caa..448b9263f3 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -314,6 +314,9 @@ lfc_change_limit_hook(int newval, void *extra) lfc_ctl->used -= 1; } lfc_ctl->limit = new_size; + if (new_size == 0) { + lfc_ctl->generation += 1; + } neon_log(DEBUG1, "set local file cache limit to %d", new_size); LWLockRelease(lfc_lock); From a18aa14754fc44f7b38970bc546e4340386c32c9 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Fri, 9 Feb 2024 11:01:07 +0200 Subject: [PATCH 23/81] test: shutdown endpoints before deletion (#6619) this avoids a page_service error in the log sometimes. keeping the endpoint running while deleting has no function for this test. --- test_runner/regress/test_timeline_delete.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test_runner/regress/test_timeline_delete.py b/test_runner/regress/test_timeline_delete.py index 352b82d525..5fda5aa569 100644 --- a/test_runner/regress/test_timeline_delete.py +++ b/test_runner/regress/test_timeline_delete.py @@ -651,9 +651,7 @@ def test_timeline_delete_works_for_remote_smoke( timeline_ids = [env.initial_timeline] for i in range(2): branch_timeline_id = env.neon_cli.create_branch(f"new{i}", "main") - pg = env.endpoints.create_start(f"new{i}") - - with pg.cursor() as cur: + with env.endpoints.create_start(f"new{i}") as pg, pg.cursor() as cur: cur.execute("CREATE TABLE f (i integer);") cur.execute("INSERT INTO f VALUES (generate_series(1,1000));") current_lsn = Lsn(query_scalar(cur, "SELECT pg_current_wal_flush_lsn()")) From 568f91420a9c677e77aeb736cb3f995a85f0b106 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Fri, 9 Feb 2024 11:34:15 +0200 Subject: [PATCH 24/81] tests: try to make restored-datadir comparison tests not flaky (#6666) This test occasionally fails with a difference in "pg_xact/0000" file between the local and restored datadirs. My hypothesis is that something changed in the database between the last explicit checkpoint and the shutdown. I suspect autovacuum, it could certainly create transactions. To fix, be more precise about the point in time that we compare. Shut down the endpoint first, then read the last LSN (i.e. the shutdown checkpoint's LSN), from the local disk with pg_controldata. And use exactly that LSN in the basebackup. Closes #559. I'm proposing this as an alternative to https://github.com/neondatabase/neon/pull/6662. --- test_runner/fixtures/neon_fixtures.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 0af8098cad..a6aff77ddf 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3964,24 +3964,27 @@ def list_files_to_compare(pgdata_dir: Path) -> List[str]: # pg is the existing and running compute node, that we want to compare with a basebackup def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint: Endpoint): + pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version) + # Get the timeline ID. We need it for the 'basebackup' command timeline_id = TimelineId(endpoint.safe_psql("SHOW neon.timeline_id")[0][0]) - # many tests already checkpoint, but do it just in case - with closing(endpoint.connect()) as conn: - with conn.cursor() as cur: - cur.execute("CHECKPOINT") - - # wait for pageserver to catch up - wait_for_last_flush_lsn(env, endpoint, endpoint.tenant_id, timeline_id) # stop postgres to ensure that files won't change endpoint.stop() + # Read the shutdown checkpoint's LSN + pg_controldata_path = os.path.join(pg_bin.pg_bin_path, "pg_controldata") + cmd = f"{pg_controldata_path} -D {endpoint.pgdata_dir}" + result = subprocess.run(cmd, capture_output=True, text=True, shell=True) + checkpoint_lsn = re.findall( + "Latest checkpoint location:\\s+([0-9A-F]+/[0-9A-F]+)", result.stdout + )[0] + log.debug(f"last checkpoint at {checkpoint_lsn}") + # Take a basebackup from pageserver restored_dir_path = env.repo_dir / f"{endpoint.endpoint_id}_restored_datadir" restored_dir_path.mkdir(exist_ok=True) - pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version) psql_path = os.path.join(pg_bin.pg_bin_path, "psql") pageserver_id = env.attachment_service.locate(endpoint.tenant_id)[0]["node_id"] @@ -3989,7 +3992,7 @@ def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint {psql_path} \ --no-psqlrc \ postgres://localhost:{env.get_pageserver(pageserver_id).service_port.pg} \ - -c 'basebackup {endpoint.tenant_id} {timeline_id}' \ + -c 'basebackup {endpoint.tenant_id} {timeline_id} {checkpoint_lsn}' \ | tar -x -C {restored_dir_path} """ From 951c9bf4cad6a651f9531f3c4e1e58d90c27910e Mon Sep 17 00:00:00 2001 From: John Spray Date: Fri, 9 Feb 2024 10:12:40 +0000 Subject: [PATCH 25/81] control_plane: fix shard splitting on unsharded tenant (#6689) ## Problem Previous test started with a new-style TenantShardId with a non-zero ShardCount. We also need to handle the case of a ShardCount() (aka `unsharded`) parent shard. **A followup PR will refactor ShardCount to make its inner value private and thereby make this kind of mistake harder** ## Summary of changes - Fix a place we were incorrectly treating a ShardCount as a number of shards rather than as thing that can be zero or the number of shards. - Add a test for this case. --- .../attachment_service/src/persistence.rs | 10 ++++-- test_runner/regress/test_sharding.py | 31 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/control_plane/attachment_service/src/persistence.rs b/control_plane/attachment_service/src/persistence.rs index cead540058..623d625767 100644 --- a/control_plane/attachment_service/src/persistence.rs +++ b/control_plane/attachment_service/src/persistence.rs @@ -381,16 +381,22 @@ impl Persistence { self.with_conn(move |conn| -> DatabaseResult<()> { conn.transaction(|conn| -> DatabaseResult<()> { // Mark parent shards as splitting + + let expect_parent_records = std::cmp::max(1, old_shard_count.0); + let updated = diesel::update(tenant_shards) .filter(tenant_id.eq(split_tenant_id.to_string())) .filter(shard_count.eq(old_shard_count.0 as i32)) .set((splitting.eq(1),)) .execute(conn)?; - if ShardCount(updated.try_into().map_err(|_| DatabaseError::Logical(format!("Overflow existing shard count {} while splitting", updated)))?) != old_shard_count { + if u8::try_from(updated) + .map_err(|_| DatabaseError::Logical( + format!("Overflow existing shard count {} while splitting", updated)) + )? != expect_parent_records { // Perhaps a deletion or another split raced with this attempt to split, mutating // the parent shards that we intend to split. In this case the split request should fail. return Err(DatabaseError::Logical( - format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {old_shard_count:?})") + format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {expect_parent_records})") )); } diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index 805eaa34b0..27d1cf2f34 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -4,7 +4,7 @@ from fixtures.neon_fixtures import ( tenant_get_shards, ) from fixtures.remote_storage import s3_storage -from fixtures.types import TimelineId +from fixtures.types import TenantShardId, TimelineId from fixtures.workload import Workload @@ -84,6 +84,35 @@ def test_sharding_smoke( assert timelines == {env.initial_timeline, timeline_b} +def test_sharding_split_unsharded( + neon_env_builder: NeonEnvBuilder, +): + """ + Test that shard splitting works on a tenant created as unsharded (i.e. with + ShardCount(0)). + """ + env = neon_env_builder.init_start() + tenant_id = env.initial_tenant + timeline_id = env.initial_timeline + + workload = Workload(env, tenant_id, timeline_id, branch_name="main") + workload.init() + workload.write_rows(256) + + # Check that we created with an unsharded TenantShardId: this is the default, + # but check it in case we change the default in future + assert env.attachment_service.inspect(TenantShardId(tenant_id, 0, 0)) is not None + + # Split one shard into two + env.attachment_service.tenant_shard_split(tenant_id, shard_count=2) + + # Check we got the shard IDs we expected + assert env.attachment_service.inspect(TenantShardId(tenant_id, 0, 2)) is not None + assert env.attachment_service.inspect(TenantShardId(tenant_id, 1, 2)) is not None + + workload.validate() + + def test_sharding_split_smoke( neon_env_builder: NeonEnvBuilder, ): From ea089dc97700732788f2d9f0ea44e10fb59c2f6f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 9 Feb 2024 10:29:20 +0000 Subject: [PATCH 26/81] proxy: add per query array mode flag (#6678) ## Problem Drizzle needs to be able to configure the array_mode flag per query. ## Summary of changes Adds an array_mode flag to the query data json that will otherwise default to the header flag. --- proxy/src/serverless/sql_over_http.rs | 163 ++++++++++++++------------ test_runner/regress/test_proxy.py | 33 ++++++ 2 files changed, 119 insertions(+), 77 deletions(-) diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 7092b65f03..25e8813625 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -44,10 +44,13 @@ use super::json::pg_text_row_to_json; use super::SERVERLESS_DRIVER_SNI; #[derive(serde::Deserialize)] +#[serde(rename_all = "camelCase")] struct QueryData { query: String, #[serde(deserialize_with = "bytes_to_pg_text")] params: Vec>, + #[serde(default)] + array_mode: Option, } #[derive(serde::Deserialize)] @@ -330,7 +333,7 @@ async fn handle_inner( // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE); - let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE); + let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE); // Allow connection pooling only if explicitly requested // or if we have decided that http pool is no longer opt-in @@ -402,83 +405,87 @@ async fn handle_inner( // Now execute the query and return the result // let mut size = 0; - let result = - match payload { - Payload::Single(stmt) => { - let (status, results) = - query_to_json(&*client, stmt, &mut 0, raw_output, array_mode) - .await - .map_err(|e| { - client.discard(); - e - })?; - client.check_idle(status); - results + let result = match payload { + Payload::Single(stmt) => { + let (status, results) = + query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode) + .await + .map_err(|e| { + client.discard(); + e + })?; + client.check_idle(status); + results + } + Payload::Batch(statements) => { + let (inner, mut discard) = client.inner(); + let mut builder = inner.build_transaction(); + if let Some(isolation_level) = txn_isolation_level { + builder = builder.isolation_level(isolation_level); } - Payload::Batch(statements) => { - let (inner, mut discard) = client.inner(); - let mut builder = inner.build_transaction(); - if let Some(isolation_level) = txn_isolation_level { - builder = builder.isolation_level(isolation_level); - } - if txn_read_only { - builder = builder.read_only(true); - } - if txn_deferrable { - builder = builder.deferrable(true); - } - - let transaction = builder.start().await.map_err(|e| { - // if we cannot start a transaction, we should return immediately - // and not return to the pool. connection is clearly broken - discard.discard(); - e - })?; - - let results = - match query_batch(&transaction, statements, &mut size, raw_output, array_mode) - .await - { - Ok(results) => { - let status = transaction.commit().await.map_err(|e| { - // if we cannot commit - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - e - })?; - discard.check_idle(status); - results - } - Err(err) => { - let status = transaction.rollback().await.map_err(|e| { - // if we cannot rollback - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - e - })?; - discard.check_idle(status); - return Err(err); - } - }; - - if txn_read_only { - response = response.header( - TXN_READ_ONLY.clone(), - HeaderValue::try_from(txn_read_only.to_string())?, - ); - } - if txn_deferrable { - response = response.header( - TXN_DEFERRABLE.clone(), - HeaderValue::try_from(txn_deferrable.to_string())?, - ); - } - if let Some(txn_isolation_level) = txn_isolation_level_raw { - response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); - } - json!({ "results": results }) + if txn_read_only { + builder = builder.read_only(true); } - }; + if txn_deferrable { + builder = builder.deferrable(true); + } + + let transaction = builder.start().await.map_err(|e| { + // if we cannot start a transaction, we should return immediately + // and not return to the pool. connection is clearly broken + discard.discard(); + e + })?; + + let results = match query_batch( + &transaction, + statements, + &mut size, + raw_output, + default_array_mode, + ) + .await + { + Ok(results) => { + let status = transaction.commit().await.map_err(|e| { + // if we cannot commit - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + e + })?; + discard.check_idle(status); + results + } + Err(err) => { + let status = transaction.rollback().await.map_err(|e| { + // if we cannot rollback - for now don't return connection to pool + // TODO: get a query status from the error + discard.discard(); + e + })?; + discard.check_idle(status); + return Err(err); + } + }; + + if txn_read_only { + response = response.header( + TXN_READ_ONLY.clone(), + HeaderValue::try_from(txn_read_only.to_string())?, + ); + } + if txn_deferrable { + response = response.header( + TXN_DEFERRABLE.clone(), + HeaderValue::try_from(txn_deferrable.to_string())?, + ); + } + if let Some(txn_isolation_level) = txn_isolation_level_raw { + response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); + } + json!({ "results": results }) + } + }; ctx.set_success(); ctx.log(); @@ -524,7 +531,7 @@ async fn query_to_json( data: QueryData, current_size: &mut usize, raw_output: bool, - array_mode: bool, + default_array_mode: bool, ) -> anyhow::Result<(ReadyForQueryStatus, Value)> { let query_params = data.params; let row_stream = client.query_raw_txt(&data.query, query_params).await?; @@ -578,6 +585,8 @@ async fn query_to_json( columns.push(client.get_type(c.type_oid()).await?); } + let array_mode = data.array_mode.unwrap_or(default_array_mode); + // convert rows to JSON let rows = rows .iter() diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index b3b35e446d..49a0450f0c 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -390,6 +390,39 @@ def test_sql_over_http_batch(static_proxy: NeonProxy): assert result[0]["rows"] == [{"answer": 42}] +def test_sql_over_http_batch_output_options(static_proxy: NeonProxy): + static_proxy.safe_psql("create role http with login password 'http' superuser") + + connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" + response = requests.post( + f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql", + data=json.dumps( + { + "queries": [ + {"query": "select $1 as answer", "params": [42], "arrayMode": True}, + {"query": "select $1 as answer", "params": [42], "arrayMode": False}, + ] + } + ), + headers={ + "Content-Type": "application/sql", + "Neon-Connection-String": connstr, + "Neon-Batch-Isolation-Level": "Serializable", + "Neon-Batch-Read-Only": "false", + "Neon-Batch-Deferrable": "false", + }, + verify=str(static_proxy.test_output_dir / "proxy.crt"), + ) + assert response.status_code == 200 + results = response.json()["results"] + + assert results[0]["rowAsArray"] + assert results[0]["rows"] == [["42"]] + + assert not results[1]["rowAsArray"] + assert results[1]["rows"] == [{"answer": "42"}] + + def test_sql_over_http_pool(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth with password 'http' superuser") From eec1e1a19223750e16401962c978fdeee2a305c8 Mon Sep 17 00:00:00 2001 From: Anastasia Lubennikova Date: Thu, 4 Jan 2024 12:34:15 +0000 Subject: [PATCH 27/81] Pre-install anon extension from compute_ctl if anon is in shared_preload_libraries. Users cannot install it themselves, because superuser is required. GRANT all priveleged needed to use it to db_owner We use the neon fork of the extension, because small change to sql file is needed to allow db_owner to use it. This feature is behind a feature flag AnonExtension, so it is not enabled by default. --- Dockerfile.compute-node | 5 +- compute_tools/src/compute.rs | 14 +++- compute_tools/src/spec.rs | 132 ++++++++++++++++++++++++++++++++++- libs/compute_api/src/spec.rs | 3 + 4 files changed, 149 insertions(+), 5 deletions(-) diff --git a/Dockerfile.compute-node b/Dockerfile.compute-node index d91c7cfd72..cc7a110008 100644 --- a/Dockerfile.compute-node +++ b/Dockerfile.compute-node @@ -639,8 +639,8 @@ FROM build-deps AS pg-anon-pg-build COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ ENV PATH "/usr/local/pgsql/bin/:$PATH" -RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/1.1.0/postgresql_anonymizer-1.1.0.tar.gz -O pg_anon.tar.gz && \ - echo "08b09d2ff9b962f96c60db7e6f8e79cf7253eb8772516998fc35ece08633d3ad pg_anon.tar.gz" | sha256sum --check && \ +RUN wget https://github.com/neondatabase/postgresql_anonymizer/archive/refs/tags/neon_1.1.1.tar.gz -O pg_anon.tar.gz && \ + echo "321ea8d5c1648880aafde850a2c576e4a9e7b9933a34ce272efc839328999fa9 pg_anon.tar.gz" | sha256sum --check && \ mkdir pg_anon-src && cd pg_anon-src && tar xvzf ../pg_anon.tar.gz --strip-components=1 -C . && \ find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt &&\ make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \ @@ -809,6 +809,7 @@ COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=pg-semver-pg-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql +COPY --from=pg-anon-pg-build /usr/local/pgsql/ /usr/local/pgsql/ COPY pgxn/ pgxn/ RUN make -j $(getconf _NPROCESSORS_ONLN) \ diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 0ca1a47fbf..993b5725a4 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -765,7 +765,12 @@ impl ComputeNode { handle_roles(spec, &mut client)?; handle_databases(spec, &mut client)?; handle_role_deletions(spec, connstr.as_str(), &mut client)?; - handle_grants(spec, &mut client, connstr.as_str())?; + handle_grants( + spec, + &mut client, + connstr.as_str(), + self.has_feature(ComputeFeature::AnonExtension), + )?; handle_extensions(spec, &mut client)?; handle_extension_neon(&mut client)?; create_availability_check_data(&mut client)?; @@ -839,7 +844,12 @@ impl ComputeNode { handle_roles(&spec, &mut client)?; handle_databases(&spec, &mut client)?; handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?; - handle_grants(&spec, &mut client, self.connstr.as_str())?; + handle_grants( + &spec, + &mut client, + self.connstr.as_str(), + self.has_feature(ComputeFeature::AnonExtension), + )?; handle_extensions(&spec, &mut client)?; handle_extension_neon(&mut client)?; // We can skip handle_migrations here because a new migration can only appear diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index 2b1bff75fe..3df5f10e23 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -581,7 +581,12 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> { /// Grant CREATE ON DATABASE to the database owner and do some other alters and grants /// to allow users creating trusted extensions and re-creating `public` schema, for example. #[instrument(skip_all)] -pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> Result<()> { +pub fn handle_grants( + spec: &ComputeSpec, + client: &mut Client, + connstr: &str, + enable_anon_extension: bool, +) -> Result<()> { info!("modifying database permissions"); let existing_dbs = get_existing_dbs(client)?; @@ -678,6 +683,11 @@ pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> inlinify(&grant_query) ); db_client.simple_query(&grant_query)?; + + // it is important to run this after all grants + if enable_anon_extension { + handle_extension_anon(spec, &db.owner, &mut db_client, false)?; + } } Ok(()) @@ -809,5 +819,125 @@ $$;"#, "Ran {} migrations", (migrations.len() - starting_migration_id) ); + + Ok(()) +} + +/// Connect to the database as superuser and pre-create anon extension +/// if it is present in shared_preload_libraries +#[instrument(skip_all)] +pub fn handle_extension_anon( + spec: &ComputeSpec, + db_owner: &str, + db_client: &mut Client, + grants_only: bool, +) -> Result<()> { + info!("handle extension anon"); + + if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") { + if libs.contains("anon") { + if !grants_only { + // check if extension is already initialized using anon.is_initialized() + let query = "SELECT anon.is_initialized()"; + match db_client.query(query, &[]) { + Ok(rows) => { + if !rows.is_empty() { + let is_initialized: bool = rows[0].get(0); + if is_initialized { + info!("anon extension is already initialized"); + return Ok(()); + } + } + } + Err(e) => { + warn!( + "anon extension is_installed check failed with expected error: {}", + e + ); + } + }; + + // Create anon extension if this compute needs it + // Users cannot create it themselves, because superuser is required. + let mut query = "CREATE EXTENSION IF NOT EXISTS anon CASCADE"; + info!("creating anon extension with query: {}", query); + match db_client.query(query, &[]) { + Ok(_) => {} + Err(e) => { + error!("anon extension creation failed with error: {}", e); + return Ok(()); + } + } + + // check that extension is installed + query = "SELECT extname FROM pg_extension WHERE extname = 'anon'"; + let rows = db_client.query(query, &[])?; + if rows.is_empty() { + error!("anon extension is not installed"); + return Ok(()); + } + + // Initialize anon extension + // This also requires superuser privileges, so users cannot do it themselves. + query = "SELECT anon.init()"; + match db_client.query(query, &[]) { + Ok(_) => {} + Err(e) => { + error!("anon.init() failed with error: {}", e); + return Ok(()); + } + } + } + + // check that extension is installed, if not bail early + let query = "SELECT extname FROM pg_extension WHERE extname = 'anon'"; + match db_client.query(query, &[]) { + Ok(rows) => { + if rows.is_empty() { + error!("anon extension is not installed"); + return Ok(()); + } + } + Err(e) => { + error!("anon extension check failed with error: {}", e); + return Ok(()); + } + }; + + let query = format!("GRANT ALL ON SCHEMA anon TO {}", db_owner); + info!("granting anon extension permissions with query: {}", query); + db_client.simple_query(&query)?; + + // Grant permissions to db_owner to use anon extension functions + let query = format!("GRANT ALL ON ALL FUNCTIONS IN SCHEMA anon TO {}", db_owner); + info!("granting anon extension permissions with query: {}", query); + db_client.simple_query(&query)?; + + // This is needed, because some functions are defined as SECURITY DEFINER. + // In Postgres SECURITY DEFINER functions are executed with the privileges + // of the owner. + // In anon extension this it is needed to access some GUCs, which are only accessible to + // superuser. But we've patched postgres to allow db_owner to access them as well. + // So we need to change owner of these functions to db_owner. + let query = format!(" + SELECT 'ALTER FUNCTION '||nsp.nspname||'.'||p.proname||'('||pg_get_function_identity_arguments(p.oid)||') OWNER TO {};' + from pg_proc p + join pg_namespace nsp ON p.pronamespace = nsp.oid + where nsp.nspname = 'anon';", db_owner); + + info!("change anon extension functions owner to db owner"); + db_client.simple_query(&query)?; + + // affects views as well + let query = format!("GRANT ALL ON ALL TABLES IN SCHEMA anon TO {}", db_owner); + info!("granting anon extension permissions with query: {}", query); + db_client.simple_query(&query)?; + + let query = format!("GRANT ALL ON ALL SEQUENCES IN SCHEMA anon TO {}", db_owner); + info!("granting anon extension permissions with query: {}", query); + db_client.simple_query(&query)?; + } + } + Ok(()) } diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 13ac18e0c5..2f412b61a3 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -90,6 +90,9 @@ pub enum ComputeFeature { /// track short-lived connections as user activity. ActivityMonitorExperimental, + /// Pre-install and initialize anon extension for every database in the cluster + AnonExtension, + /// This is a special feature flag that is used to represent unknown feature flags. /// Basically all unknown to enum flags are represented as this one. See unit test /// `parse_unknown_features()` for more details. From eb919cab88b8a28eb423b33eb07a858acbd61eab Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Fri, 9 Feb 2024 14:52:58 +0200 Subject: [PATCH 28/81] prepare to move timeouts and cancellation handling to remote_storage (#6696) This PR is preliminary cleanups and refactoring around `remote_storage` for next PR which will move the timeouts and cancellation into `remote_storage`. Summary: - smaller drive-by fixes - code simplification - refactor common parts like `DownloadError::is_permanent` - align error types with `RemoteStorage::list_*` to use more `download_retry` helper Cc: #6096 --- libs/remote_storage/src/lib.rs | 26 ++++++- libs/remote_storage/src/local_fs.rs | 50 ++++++++---- libs/remote_storage/src/s3_bucket.rs | 77 ++++++------------- libs/remote_storage/src/simulate_failures.rs | 28 ++++--- libs/remote_storage/src/support.rs | 33 ++++++++ pageserver/src/task_mgr.rs | 4 +- pageserver/src/tenant.rs | 4 +- .../src/tenant/remote_timeline_client.rs | 35 ++++----- .../tenant/remote_timeline_client/download.rs | 59 +++++--------- pageserver/src/tenant/secondary/downloader.rs | 2 +- 10 files changed, 175 insertions(+), 143 deletions(-) create mode 100644 libs/remote_storage/src/support.rs diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index e64b1de6f9..b6648931ac 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -13,6 +13,7 @@ mod azure_blob; mod local_fs; mod s3_bucket; mod simulate_failures; +mod support; use std::{ collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc, time::SystemTime, @@ -170,7 +171,10 @@ pub trait RemoteStorage: Send + Sync + 'static { /// whereas, /// list_prefixes("foo/bar/") = ["cat", "dog"] /// See `test_real_s3.rs` for more details. - async fn list_files(&self, prefix: Option<&RemotePath>) -> anyhow::Result> { + async fn list_files( + &self, + prefix: Option<&RemotePath>, + ) -> Result, DownloadError> { let result = self.list(prefix, ListingMode::NoDelimiter).await?.keys; Ok(result) } @@ -179,7 +183,7 @@ pub trait RemoteStorage: Send + Sync + 'static { &self, prefix: Option<&RemotePath>, _mode: ListingMode, - ) -> anyhow::Result; + ) -> Result; /// Streams the local file contents into remote into the remote storage entry. async fn upload( @@ -269,6 +273,19 @@ impl std::fmt::Display for DownloadError { impl std::error::Error for DownloadError {} +impl DownloadError { + /// Returns true if the error should not be retried with backoff + pub fn is_permanent(&self) -> bool { + use DownloadError::*; + match self { + BadInput(_) => true, + NotFound => true, + Cancelled => true, + Other(_) => false, + } + } +} + #[derive(Debug)] pub enum TimeTravelError { /// Validation or other error happened due to user input. @@ -336,7 +353,10 @@ impl GenericRemoteStorage> { // A function for listing all the files in a "directory" // Example: // list_files("foo/bar") = ["foo/bar/a.txt", "foo/bar/b.txt"] - pub async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result> { + pub async fn list_files( + &self, + folder: Option<&RemotePath>, + ) -> Result, DownloadError> { match self { Self::LocalFs(s) => s.list_files(folder).await, Self::AwsS3(s) => s.list_files(folder).await, diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index 36ec15e1b1..3ebea76181 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -18,9 +18,7 @@ use tokio_util::{io::ReaderStream, sync::CancellationToken}; use tracing::*; use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty}; -use crate::{ - Download, DownloadError, DownloadStream, Listing, ListingMode, RemotePath, TimeTravelError, -}; +use crate::{Download, DownloadError, Listing, ListingMode, RemotePath, TimeTravelError}; use super::{RemoteStorage, StorageMetadata}; @@ -365,27 +363,33 @@ impl RemoteStorage for LocalFs { format!("Failed to open source file {target_path:?} to use in the download") }) .map_err(DownloadError::Other)?; + + let len = source + .metadata() + .await + .context("query file length") + .map_err(DownloadError::Other)? + .len(); + source .seek(io::SeekFrom::Start(start_inclusive)) .await .context("Failed to seek to the range start in a local storage file") .map_err(DownloadError::Other)?; + let metadata = self .read_storage_metadata(&target_path) .await .map_err(DownloadError::Other)?; - let download_stream: DownloadStream = match end_exclusive { - Some(end_exclusive) => Box::pin(ReaderStream::new( - source.take(end_exclusive - start_inclusive), - )), - None => Box::pin(ReaderStream::new(source)), - }; + let source = source.take(end_exclusive.unwrap_or(len) - start_inclusive); + let source = ReaderStream::new(source); + Ok(Download { metadata, last_modified: None, etag: None, - download_stream, + download_stream: Box::pin(source), }) } else { Err(DownloadError::NotFound) @@ -514,10 +518,8 @@ mod fs_tests { use futures_util::Stream; use std::{collections::HashMap, io::Write}; - async fn read_and_assert_remote_file_contents( + async fn read_and_check_metadata( storage: &LocalFs, - #[allow(clippy::ptr_arg)] - // have to use &Utf8PathBuf due to `storage.local_path` parameter requirements remote_storage_path: &RemotePath, expected_metadata: Option<&StorageMetadata>, ) -> anyhow::Result { @@ -596,7 +598,7 @@ mod fs_tests { let upload_name = "upload_1"; let upload_target = upload_dummy_file(&storage, upload_name, None).await?; - let contents = read_and_assert_remote_file_contents(&storage, &upload_target, None).await?; + let contents = read_and_check_metadata(&storage, &upload_target, None).await?; assert_eq!( dummy_contents(upload_name), contents, @@ -618,7 +620,7 @@ mod fs_tests { let upload_target = upload_dummy_file(&storage, upload_name, None).await?; let full_range_download_contents = - read_and_assert_remote_file_contents(&storage, &upload_target, None).await?; + read_and_check_metadata(&storage, &upload_target, None).await?; assert_eq!( dummy_contents(upload_name), full_range_download_contents, @@ -660,6 +662,22 @@ mod fs_tests { "Second part bytes should be returned when requested" ); + let suffix_bytes = storage + .download_byte_range(&upload_target, 13, None) + .await? + .download_stream; + let suffix_bytes = aggregate(suffix_bytes).await?; + let suffix = std::str::from_utf8(&suffix_bytes)?; + assert_eq!(upload_name, suffix); + + let all_bytes = storage + .download_byte_range(&upload_target, 0, None) + .await? + .download_stream; + let all_bytes = aggregate(all_bytes).await?; + let all_bytes = std::str::from_utf8(&all_bytes)?; + assert_eq!(dummy_contents("upload_1"), all_bytes); + Ok(()) } @@ -736,7 +754,7 @@ mod fs_tests { upload_dummy_file(&storage, upload_name, Some(metadata.clone())).await?; let full_range_download_contents = - read_and_assert_remote_file_contents(&storage, &upload_target, Some(&metadata)).await?; + read_and_check_metadata(&storage, &upload_target, Some(&metadata)).await?; assert_eq!( dummy_contents(upload_name), full_range_download_contents, diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs index c9ad9ef225..2b33a6ffd1 100644 --- a/libs/remote_storage/src/s3_bucket.rs +++ b/libs/remote_storage/src/s3_bucket.rs @@ -45,8 +45,9 @@ use utils::backoff; use super::StorageMetadata; use crate::{ - ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath, RemoteStorage, - S3Config, TimeTravelError, MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR, + support::PermitCarrying, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, + RemotePath, RemoteStorage, S3Config, TimeTravelError, MAX_KEYS_PER_DELETE, + REMOTE_STORAGE_PREFIX_SEPARATOR, }; pub(super) mod metrics; @@ -63,7 +64,6 @@ pub struct S3Bucket { concurrency_limiter: ConcurrencyLimiter, } -#[derive(Default)] struct GetObjectRequest { bucket: String, key: String, @@ -232,24 +232,8 @@ impl S3Bucket { let started_at = ScopeGuard::into_inner(started_at); - match get_object { - Ok(object_output) => { - let metadata = object_output.metadata().cloned().map(StorageMetadata); - let etag = object_output.e_tag.clone(); - let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok()); - - let body = object_output.body; - let body = ByteStreamAsStream::from(body); - let body = PermitCarrying::new(permit, body); - let body = TimedDownload::new(started_at, body); - - Ok(Download { - metadata, - etag, - last_modified, - download_stream: Box::pin(body), - }) - } + let object_output = match get_object { + Ok(object_output) => object_output, Err(SdkError::ServiceError(e)) if matches!(e.err(), GetObjectError::NoSuchKey(_)) => { // Count this in the AttemptOutcome::Ok bucket, because 404 is not // an error: we expect to sometimes fetch an object and find it missing, @@ -259,7 +243,7 @@ impl S3Bucket { AttemptOutcome::Ok, started_at, ); - Err(DownloadError::NotFound) + return Err(DownloadError::NotFound); } Err(e) => { metrics::BUCKET_METRICS.req_seconds.observe_elapsed( @@ -268,11 +252,27 @@ impl S3Bucket { started_at, ); - Err(DownloadError::Other( + return Err(DownloadError::Other( anyhow::Error::new(e).context("download s3 object"), - )) + )); } - } + }; + + let metadata = object_output.metadata().cloned().map(StorageMetadata); + let etag = object_output.e_tag; + let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok()); + + let body = object_output.body; + let body = ByteStreamAsStream::from(body); + let body = PermitCarrying::new(permit, body); + let body = TimedDownload::new(started_at, body); + + Ok(Download { + metadata, + etag, + last_modified, + download_stream: Box::pin(body), + }) } async fn delete_oids( @@ -354,33 +354,6 @@ impl Stream for ByteStreamAsStream { // sense and Stream::size_hint does not really } -pin_project_lite::pin_project! { - /// An `AsyncRead` adapter which carries a permit for the lifetime of the value. - struct PermitCarrying { - permit: tokio::sync::OwnedSemaphorePermit, - #[pin] - inner: S, - } -} - -impl PermitCarrying { - fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self { - Self { permit, inner } - } -} - -impl>> Stream for PermitCarrying { - type Item = ::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_next(cx) - } - - fn size_hint(&self) -> (usize, Option) { - self.inner.size_hint() - } -} - pin_project_lite::pin_project! { /// Times and tracks the outcome of the request. struct TimedDownload { diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index 82d5a61fda..14bdb5ed4d 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -60,7 +60,7 @@ impl UnreliableWrapper { /// On the first attempts of this operation, return an error. After 'attempts_to_fail' /// attempts, let the operation go ahead, and clear the counter. /// - fn attempt(&self, op: RemoteOp) -> Result { + fn attempt(&self, op: RemoteOp) -> anyhow::Result { let mut attempts = self.attempts.lock().unwrap(); match attempts.entry(op) { @@ -78,13 +78,13 @@ impl UnreliableWrapper { } else { let error = anyhow::anyhow!("simulated failure of remote operation {:?}", e.key()); - Err(DownloadError::Other(error)) + Err(error) } } Entry::Vacant(e) => { let error = anyhow::anyhow!("simulated failure of remote operation {:?}", e.key()); e.insert(1); - Err(DownloadError::Other(error)) + Err(error) } } } @@ -105,12 +105,17 @@ impl RemoteStorage for UnreliableWrapper { &self, prefix: Option<&RemotePath>, ) -> Result, DownloadError> { - self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?; + self.attempt(RemoteOp::ListPrefixes(prefix.cloned())) + .map_err(DownloadError::Other)?; self.inner.list_prefixes(prefix).await } - async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result> { - self.attempt(RemoteOp::ListPrefixes(folder.cloned()))?; + async fn list_files( + &self, + folder: Option<&RemotePath>, + ) -> Result, DownloadError> { + self.attempt(RemoteOp::ListPrefixes(folder.cloned())) + .map_err(DownloadError::Other)?; self.inner.list_files(folder).await } @@ -119,7 +124,8 @@ impl RemoteStorage for UnreliableWrapper { prefix: Option<&RemotePath>, mode: ListingMode, ) -> Result { - self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?; + self.attempt(RemoteOp::ListPrefixes(prefix.cloned())) + .map_err(DownloadError::Other)?; self.inner.list(prefix, mode).await } @@ -137,7 +143,8 @@ impl RemoteStorage for UnreliableWrapper { } async fn download(&self, from: &RemotePath) -> Result { - self.attempt(RemoteOp::Download(from.clone()))?; + self.attempt(RemoteOp::Download(from.clone())) + .map_err(DownloadError::Other)?; self.inner.download(from).await } @@ -150,7 +157,8 @@ impl RemoteStorage for UnreliableWrapper { // Note: We treat any download_byte_range as an "attempt" of the same // operation. We don't pay attention to the ranges. That's good enough // for now. - self.attempt(RemoteOp::Download(from.clone()))?; + self.attempt(RemoteOp::Download(from.clone())) + .map_err(DownloadError::Other)?; self.inner .download_byte_range(from, start_inclusive, end_exclusive) .await @@ -193,7 +201,7 @@ impl RemoteStorage for UnreliableWrapper { cancel: &CancellationToken, ) -> Result<(), TimeTravelError> { self.attempt(RemoteOp::TimeTravelRecover(prefix.map(|p| p.to_owned()))) - .map_err(|e| TimeTravelError::Other(anyhow::Error::new(e)))?; + .map_err(TimeTravelError::Other)?; self.inner .time_travel_recover(prefix, timestamp, done_if_after, cancel) .await diff --git a/libs/remote_storage/src/support.rs b/libs/remote_storage/src/support.rs new file mode 100644 index 0000000000..4688a484a5 --- /dev/null +++ b/libs/remote_storage/src/support.rs @@ -0,0 +1,33 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::Stream; + +pin_project_lite::pin_project! { + /// An `AsyncRead` adapter which carries a permit for the lifetime of the value. + pub(crate) struct PermitCarrying { + permit: tokio::sync::OwnedSemaphorePermit, + #[pin] + inner: S, + } +} + +impl PermitCarrying { + pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self { + Self { permit, inner } + } +} + +impl Stream for PermitCarrying { + type Item = ::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index 5a06a97525..3cec5fa850 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -576,8 +576,8 @@ pub fn shutdown_token() -> CancellationToken { /// Has the current task been requested to shut down? pub fn is_shutdown_requested() -> bool { - if let Ok(cancel) = SHUTDOWN_TOKEN.try_with(|t| t.clone()) { - cancel.is_cancelled() + if let Ok(true_or_false) = SHUTDOWN_TOKEN.try_with(|t| t.is_cancelled()) { + true_or_false } else { if !cfg!(test) { warn!("is_shutdown_requested() called in an unexpected task or thread"); diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index f086f46213..4446c410b0 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -1377,7 +1377,7 @@ impl Tenant { async move { debug!("starting index part download"); - let index_part = client.download_index_file(cancel_clone).await; + let index_part = client.download_index_file(&cancel_clone).await; debug!("finished index part download"); @@ -2434,7 +2434,7 @@ impl Tenant { // operation is rare, so it's simpler to just download it (and robustly guarantees that the index // we use here really is the remotely persistent one). let result = tl_client - .download_index_file(self.cancel.clone()) + .download_index_file(&self.cancel) .instrument(info_span!("download_index_file", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), timeline_id=%timeline.timeline_id)) .await?; let index_part = match result { diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index 152c9a2b7d..0c7dd68c3f 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -217,6 +217,7 @@ use crate::metrics::{ }; use crate::task_mgr::shutdown_token; use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id; +use crate::tenant::remote_timeline_client::download::download_retry; use crate::tenant::storage_layer::AsLayerDesc; use crate::tenant::upload_queue::Delete; use crate::tenant::TIMELINES_SEGMENT_NAME; @@ -262,6 +263,11 @@ pub(crate) const INITDB_PRESERVED_PATH: &str = "initdb-preserved.tar.zst"; /// Default buffer size when interfacing with [`tokio::fs::File`]. pub(crate) const BUFFER_SIZE: usize = 32 * 1024; +/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not +/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that. +pub(crate) const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120); +pub(crate) const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120); + pub enum MaybeDeletedIndexPart { IndexPart(IndexPart), Deleted(IndexPart), @@ -325,11 +331,6 @@ pub struct RemoteTimelineClient { cancel: CancellationToken, } -/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not -/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that. -const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120); -const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120); - /// Wrapper for timeout_cancellable that flattens result and converts TimeoutCancellableError to anyhow. /// /// This is a convenience for the various upload functions. In future @@ -506,7 +507,7 @@ impl RemoteTimelineClient { /// Download index file pub async fn download_index_file( &self, - cancel: CancellationToken, + cancel: &CancellationToken, ) -> Result { let _unfinished_gauge_guard = self.metrics.call_begin( &RemoteOpFileKind::Index, @@ -1147,22 +1148,17 @@ impl RemoteTimelineClient { let cancel = shutdown_token(); - let remaining = backoff::retry( + let remaining = download_retry( || async { self.storage_impl .list_files(Some(&timeline_storage_path)) .await }, - |_e| false, - FAILED_DOWNLOAD_WARN_THRESHOLD, - FAILED_REMOTE_OP_RETRIES, - "list_prefixes", + "list remaining files", &cancel, ) .await - .ok_or_else(|| anyhow::anyhow!("Cancelled!")) - .and_then(|x| x) - .context("list prefixes")?; + .context("list files remaining files")?; // We will delete the current index_part object last, since it acts as a deletion // marker via its deleted_at attribute @@ -1351,6 +1347,7 @@ impl RemoteTimelineClient { /// queue. /// async fn perform_upload_task(self: &Arc, task: Arc) { + let cancel = shutdown_token(); // Loop to retry until it completes. loop { // If we're requested to shut down, close up shop and exit. @@ -1362,7 +1359,7 @@ impl RemoteTimelineClient { // the Future, but we're not 100% sure if the remote storage library // is cancellation safe, so we don't dare to do that. Hopefully, the // upload finishes or times out soon enough. - if task_mgr::is_shutdown_requested() { + if cancel.is_cancelled() { info!("upload task cancelled by shutdown request"); match self.stop() { Ok(()) => {} @@ -1473,7 +1470,7 @@ impl RemoteTimelineClient { retries, DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS, - &shutdown_token(), + &cancel, ) .await; } @@ -1990,7 +1987,7 @@ mod tests { // Download back the index.json, and check that the list of files is correct let initial_index_part = match client - .download_index_file(CancellationToken::new()) + .download_index_file(&CancellationToken::new()) .await .unwrap() { @@ -2084,7 +2081,7 @@ mod tests { // Download back the index.json, and check that the list of files is correct let index_part = match client - .download_index_file(CancellationToken::new()) + .download_index_file(&CancellationToken::new()) .await .unwrap() { @@ -2286,7 +2283,7 @@ mod tests { let client = test_state.build_client(get_generation); let download_r = client - .download_index_file(CancellationToken::new()) + .download_index_file(&CancellationToken::new()) .await .expect("download should always succeed"); assert!(matches!(download_r, MaybeDeletedIndexPart::IndexPart(_))); diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index 6c1125746b..33287fc8f4 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -216,16 +216,15 @@ pub async fn list_remote_timelines( anyhow::bail!("storage-sync-list-remote-timelines"); }); - let cancel_inner = cancel.clone(); let listing = download_retry_forever( || { download_cancellable( - &cancel_inner, + &cancel, storage.list(Some(&remote_path), ListingMode::WithDelimiter), ) }, &format!("list timelines for {tenant_shard_id}"), - cancel, + &cancel, ) .await?; @@ -258,19 +257,18 @@ async fn do_download_index_part( tenant_shard_id: &TenantShardId, timeline_id: &TimelineId, index_generation: Generation, - cancel: CancellationToken, + cancel: &CancellationToken, ) -> Result { use futures::stream::StreamExt; let remote_path = remote_index_path(tenant_shard_id, timeline_id, index_generation); - let cancel_inner = cancel.clone(); let index_part_bytes = download_retry_forever( || async { // Cancellation: if is safe to cancel this future because we're just downloading into // a memory buffer, not touching local disk. let index_part_download = - download_cancellable(&cancel_inner, storage.download(&remote_path)).await?; + download_cancellable(cancel, storage.download(&remote_path)).await?; let mut index_part_bytes = Vec::new(); let mut stream = std::pin::pin!(index_part_download.download_stream); @@ -288,7 +286,7 @@ async fn do_download_index_part( .await?; let index_part: IndexPart = serde_json::from_slice(&index_part_bytes) - .with_context(|| format!("download index part file at {remote_path:?}")) + .with_context(|| format!("deserialize index part file at {remote_path:?}")) .map_err(DownloadError::Other)?; Ok(index_part) @@ -305,7 +303,7 @@ pub(super) async fn download_index_part( tenant_shard_id: &TenantShardId, timeline_id: &TimelineId, my_generation: Generation, - cancel: CancellationToken, + cancel: &CancellationToken, ) -> Result { debug_assert_current_span_has_tenant_and_timeline_id(); @@ -325,14 +323,8 @@ pub(super) async fn download_index_part( // index in our generation. // // This is an optimization to avoid doing the listing for the general case below. - let res = do_download_index_part( - storage, - tenant_shard_id, - timeline_id, - my_generation, - cancel.clone(), - ) - .await; + let res = + do_download_index_part(storage, tenant_shard_id, timeline_id, my_generation, cancel).await; match res { Ok(index_part) => { tracing::debug!( @@ -357,7 +349,7 @@ pub(super) async fn download_index_part( tenant_shard_id, timeline_id, my_generation.previous(), - cancel.clone(), + cancel, ) .await; match res { @@ -379,18 +371,13 @@ pub(super) async fn download_index_part( // objects, and select the highest one with a generation <= my_generation. Constructing the prefix is equivalent // to constructing a full index path with no generation, because the generation is a suffix. let index_prefix = remote_index_path(tenant_shard_id, timeline_id, Generation::none()); - let indices = backoff::retry( + + let indices = download_retry( || async { storage.list_files(Some(&index_prefix)).await }, - |_| false, - FAILED_DOWNLOAD_WARN_THRESHOLD, - FAILED_REMOTE_OP_RETRIES, - "listing index_part files", - &cancel, + "list index_part files", + cancel, ) - .await - .ok_or_else(|| anyhow::anyhow!("Cancelled")) - .and_then(|x| x) - .map_err(DownloadError::Other)?; + .await?; // General case logic for which index to use: the latest index whose generation // is <= our own. See "Finding the remote indices for timelines" in docs/rfcs/025-generation-numbers.md @@ -447,8 +434,6 @@ pub(crate) async fn download_initdb_tar_zst( "{INITDB_PATH}.download-{timeline_id}.{TEMP_FILE_SUFFIX}" )); - let cancel_inner = cancel.clone(); - let file = download_retry( || async { let file = OpenOptions::new() @@ -461,13 +446,11 @@ pub(crate) async fn download_initdb_tar_zst( .with_context(|| format!("tempfile creation {temp_path}")) .map_err(DownloadError::Other)?; - let download = match download_cancellable(&cancel_inner, storage.download(&remote_path)) - .await + let download = match download_cancellable(cancel, storage.download(&remote_path)).await { Ok(dl) => dl, Err(DownloadError::NotFound) => { - download_cancellable(&cancel_inner, storage.download(&remote_preserved_path)) - .await? + download_cancellable(cancel, storage.download(&remote_preserved_path)).await? } Err(other) => Err(other)?, }; @@ -516,7 +499,7 @@ pub(crate) async fn download_initdb_tar_zst( /// with backoff. /// /// (See similar logic for uploads in `perform_upload_task`) -async fn download_retry( +pub(super) async fn download_retry( op: O, description: &str, cancel: &CancellationToken, @@ -527,7 +510,7 @@ where { backoff::retry( op, - |e| matches!(e, DownloadError::BadInput(_) | DownloadError::NotFound), + DownloadError::is_permanent, FAILED_DOWNLOAD_WARN_THRESHOLD, FAILED_REMOTE_OP_RETRIES, description, @@ -541,7 +524,7 @@ where async fn download_retry_forever( op: O, description: &str, - cancel: CancellationToken, + cancel: &CancellationToken, ) -> Result where O: FnMut() -> F, @@ -549,11 +532,11 @@ where { backoff::retry( op, - |e| matches!(e, DownloadError::BadInput(_) | DownloadError::NotFound), + DownloadError::is_permanent, FAILED_DOWNLOAD_WARN_THRESHOLD, u32::MAX, description, - &cancel, + cancel, ) .await .ok_or_else(|| DownloadError::Cancelled) diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 9330edf946..0666e104f8 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -533,7 +533,7 @@ impl<'a> TenantDownloader<'a> { .map_err(UpdateError::from)?; let mut heatmap_bytes = Vec::new(); let mut body = tokio_util::io::StreamReader::new(download.download_stream); - let _size = tokio::io::copy(&mut body, &mut heatmap_bytes).await?; + let _size = tokio::io::copy_buf(&mut body, &mut heatmap_bytes).await?; Ok(heatmap_bytes) }, |e| matches!(e, UpdateError::NoData | UpdateError::Cancelled), From 8d98981fe580fcdfb7066a5698c2448af0cbc61d Mon Sep 17 00:00:00 2001 From: John Spray Date: Fri, 9 Feb 2024 13:20:04 +0000 Subject: [PATCH 29/81] tests: deflake test_sharding_split_unsharded (#6699) ## Problem This test was a subset of the larger sharding test, and it missed the validate() call on workload that was implicitly waiting for a tenant to become active before trying to split it. It could therefore fail to split due to tenant not yet being active. ## Summary of changes - Insert .validate() call, and move the Workload setup to after the check of shard ID (as the shard ID check should pass immediately) --- test_runner/regress/test_sharding.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index 27d1cf2f34..fa40219d0e 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -95,14 +95,15 @@ def test_sharding_split_unsharded( tenant_id = env.initial_tenant timeline_id = env.initial_timeline - workload = Workload(env, tenant_id, timeline_id, branch_name="main") - workload.init() - workload.write_rows(256) - # Check that we created with an unsharded TenantShardId: this is the default, # but check it in case we change the default in future assert env.attachment_service.inspect(TenantShardId(tenant_id, 0, 0)) is not None + workload = Workload(env, tenant_id, timeline_id, branch_name="main") + workload.init() + workload.write_rows(256) + workload.validate() + # Split one shard into two env.attachment_service.tenant_shard_split(tenant_id, shard_count=2) From 84a0e7b022e37b041004e7d9299060a3777c63eb Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Fri, 9 Feb 2024 11:07:42 +0200 Subject: [PATCH 30/81] tests: Allow setting shutdown mode separately from 'destroy' flag In neon_local, the default mode is now always 'fast', regardless of 'destroy'. You can override it with the "neon_local endpoint stop --mode=immediate" flag. In python tests, we still default to 'immediate' mode when using the stop_and_destroy() function, and 'fast' with plain stop(). I kept that to avoid changing behavior in existing tests. I don't think existing tests depend on it, but I wasn't 100% certain. --- control_plane/src/bin/neon_local.rs | 16 +++++++++++++--- control_plane/src/endpoint.rs | 18 ++---------------- test_runner/fixtures/neon_fixtures.py | 11 +++++++---- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index b9af467fdf..d71cdf02c0 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -1014,12 +1014,13 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re .get_one::("endpoint_id") .ok_or_else(|| anyhow!("No endpoint ID was provided to stop"))?; let destroy = sub_args.get_flag("destroy"); + let mode = sub_args.get_one::("mode").expect("has a default"); let endpoint = cplane .endpoints .get(endpoint_id.as_str()) .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; - endpoint.stop(destroy)?; + endpoint.stop(mode, destroy)?; } _ => bail!("Unexpected endpoint subcommand '{sub_name}'"), @@ -1303,7 +1304,7 @@ async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) { match ComputeControlPlane::load(env.clone()) { Ok(cplane) => { for (_k, node) in cplane.endpoints { - if let Err(e) = node.stop(false) { + if let Err(e) = node.stop(if immediate { "immediate" } else { "fast " }, false) { eprintln!("postgres stop failed: {e:#}"); } } @@ -1652,7 +1653,16 @@ fn cli() -> Command { .long("destroy") .action(ArgAction::SetTrue) .required(false) - ) + ) + .arg( + Arg::new("mode") + .help("Postgres shutdown mode, passed to \"pg_ctl -m \"") + .long("mode") + .action(ArgAction::Set) + .required(false) + .value_parser(["smart", "fast", "immediate"]) + .default_value("fast") + ) ) ) diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index b19a6a1a18..f1fe12e05f 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -761,22 +761,8 @@ impl Endpoint { } } - pub fn stop(&self, destroy: bool) -> Result<()> { - // If we are going to destroy data directory, - // use immediate shutdown mode, otherwise, - // shutdown gracefully to leave the data directory sane. - // - // Postgres is always started from scratch, so stop - // without destroy only used for testing and debugging. - // - self.pg_ctl( - if destroy { - &["-m", "immediate", "stop"] - } else { - &["stop"] - }, - &None, - )?; + pub fn stop(&self, mode: &str, destroy: bool) -> Result<()> { + self.pg_ctl(&["-m", mode, "stop"], &None)?; // Also wait for the compute_ctl process to die. It might have some // cleanup work to do after postgres stops, like syncing safekeepers, diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index a6aff77ddf..9996853525 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1816,6 +1816,7 @@ class NeonCli(AbstractNeonCli): endpoint_id: str, destroy=False, check_return_code=True, + mode: Optional[str] = None, ) -> "subprocess.CompletedProcess[str]": args = [ "endpoint", @@ -1823,6 +1824,8 @@ class NeonCli(AbstractNeonCli): ] if destroy: args.append("--destroy") + if mode is not None: + args.append(f"--mode={mode}") if endpoint_id is not None: args.append(endpoint_id) @@ -3162,7 +3165,7 @@ class Endpoint(PgProtocol): with open(remote_extensions_spec_path, "w") as file: json.dump(spec, file, indent=4) - def stop(self) -> "Endpoint": + def stop(self, mode: str = "fast") -> "Endpoint": """ Stop the Postgres instance if it's running. Returns self. @@ -3171,13 +3174,13 @@ class Endpoint(PgProtocol): if self.running: assert self.endpoint_id is not None self.env.neon_cli.endpoint_stop( - self.endpoint_id, check_return_code=self.check_stop_result + self.endpoint_id, check_return_code=self.check_stop_result, mode=mode ) self.running = False return self - def stop_and_destroy(self) -> "Endpoint": + def stop_and_destroy(self, mode: str = "immediate") -> "Endpoint": """ Stop the Postgres instance, then destroy the endpoint. Returns self. @@ -3185,7 +3188,7 @@ class Endpoint(PgProtocol): assert self.endpoint_id is not None self.env.neon_cli.endpoint_stop( - self.endpoint_id, True, check_return_code=self.check_stop_result + self.endpoint_id, True, check_return_code=self.check_stop_result, mode=mode ) self.endpoint_id = None self.running = False From 5239cdc29fdfe8458798cefad51f8871108f9811 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Fri, 9 Feb 2024 11:07:47 +0200 Subject: [PATCH 31/81] Fix test_vm_bit_clear_on_heap_lock test The test was supposed to reproduce the bug fixed in commit 66fa176cc8, i.e. that the clearing of the VM bit was not replayed in the pageserver on HEAP_LOCK records. But it was broken in many ways and failed to reproduce the original problem if you reverted the fix: - The comparison of XIDs was broken. The test read the XID in to a variable in python, but it was treated as a string rather than an integer. As a result, e.g. "999" > "1000". - The test accessed the locked tuple too early, in the loop. Accessing it early, before the pg_xact page had been removed, set the hint bits. That masked the problem on subsequent accesses. - The on-demand SLRU download that was introduced in commit 9a9d9beaee hid the issue. Even though an SLRU segment was removed by Postgres, when it later tried to access it, it could still download it from the pageserver. To ensure that doesn't happen, shorten the GC period and compact and GC aggressively in the test. I also added a more direct check that the VM page is updated, using the get_page_at_lsn() debugging function. Right after locking the row, we now fetch the VM page from pageserver and directly compare it with the VM page in the page cache. They should match. That assertion is more robust to things like on-demand SLRU download that could mask the bug. --- test_runner/regress/test_vm_bits.py | 118 +++++++++++++++++----------- 1 file changed, 72 insertions(+), 46 deletions(-) diff --git a/test_runner/regress/test_vm_bits.py b/test_runner/regress/test_vm_bits.py index 415f086bd3..06c30b8d81 100644 --- a/test_runner/regress/test_vm_bits.py +++ b/test_runner/regress/test_vm_bits.py @@ -1,6 +1,7 @@ -import pytest +import time + from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnv, fork_at_current_lsn +from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder, fork_at_current_lsn # @@ -118,12 +119,20 @@ def test_vm_bit_clear(neon_simple_env: NeonEnv): # Test that the ALL_FROZEN VM bit is cleared correctly at a HEAP_LOCK # record. # -# FIXME: This test is broken -@pytest.mark.skip("See https://github.com/neondatabase/neon/pull/6412#issuecomment-1902072541") -def test_vm_bit_clear_on_heap_lock(neon_simple_env: NeonEnv): - env = neon_simple_env +def test_vm_bit_clear_on_heap_lock(neon_env_builder: NeonEnvBuilder): + tenant_conf = { + "checkpoint_distance": f"{128 * 1024}", + "compaction_target_size": f"{128 * 1024}", + "compaction_threshold": "1", + # create image layers eagerly, so that GC can remove some layers + "image_creation_threshold": "1", + # set PITR interval to be small, so we can do GC + "pitr_interval": "0 s", + } + env = neon_env_builder.init_start(initial_tenant_conf=tenant_conf) - env.neon_cli.create_branch("test_vm_bit_clear_on_heap_lock", "empty") + tenant_id = env.initial_tenant + timeline_id = env.neon_cli.create_branch("test_vm_bit_clear_on_heap_lock") endpoint = env.endpoints.create_start( "test_vm_bit_clear_on_heap_lock", config_lines=[ @@ -139,72 +148,88 @@ def test_vm_bit_clear_on_heap_lock(neon_simple_env: NeonEnv): # Install extension containing function needed for test cur.execute("CREATE EXTENSION neon_test_utils") - - cur.execute("SELECT pg_switch_wal()") + cur.execute("CREATE EXTENSION pageinspect") # Create a test table and freeze it to set the all-frozen VM bit on all pages. cur.execute("CREATE TABLE vmtest_lock (id integer PRIMARY KEY)") cur.execute("INSERT INTO vmtest_lock SELECT g FROM generate_series(1, 50000) g") - cur.execute("VACUUM FREEZE vmtest_lock") + + cur.execute("VACUUM (FREEZE, DISABLE_PAGE_SKIPPING true) vmtest_lock") # Lock a row. This clears the all-frozen VM bit for that page. + cur.execute("BEGIN") cur.execute("SELECT * FROM vmtest_lock WHERE id = 40000 FOR UPDATE") # Remember the XID. We will use it later to verify that we have consumed a lot of # XIDs after this. cur.execute("select pg_current_xact_id()") - locking_xid = cur.fetchall()[0][0] + locking_xid = int(cur.fetchall()[0][0]) - # Stop and restart postgres, to clear the buffer cache. + cur.execute("COMMIT") + + # The VM page in shared buffer cache, and the same page as reconstructed + # by the pageserver, should be equal. + cur.execute("select get_raw_page( 'vmtest_lock', 'vm', 0 )") + vm_page_in_cache = (cur.fetchall()[0][0])[:100].hex() + cur.execute("select get_raw_page_at_lsn( 'vmtest_lock', 'vm', 0, pg_current_wal_insert_lsn() )") + vm_page_at_pageserver = (cur.fetchall()[0][0])[:100].hex() + + assert vm_page_at_pageserver == vm_page_in_cache + + # The above assert is enough to verify the bug that was fixed in + # commit 66fa176cc8. But for good measure, we also reproduce the + # original problem that the missing VM page update caused. The + # rest of the test does that. + + # Kill and restart postgres, to clear the buffer cache. # # NOTE: clear_buffer_cache() will not do, because it evicts the dirty pages # in a "clean" way. Our neon extension will write a full-page image of the VM - # page, and we want to avoid that. - endpoint.stop() + # page, and we want to avoid that. A clean shutdown will also not do, for the + # same reason. + endpoint.stop(mode="immediate") + endpoint.start() pg_conn = endpoint.connect() cur = pg_conn.cursor() - cur.execute("select xmin, xmax, * from vmtest_lock where id = 40000 ") - tup = cur.fetchall() - xmax_before = tup[0][1] - # Consume a lot of XIDs, so that anti-wraparound autovacuum kicks # in and the clog gets truncated. We set autovacuum_freeze_max_age to a very # low value, so it doesn't take all that many XIDs for autovacuum to kick in. - for i in range(1000): + # + # We could use test_consume_xids() to consume XIDs much faster, + # but it wouldn't speed up the overall test, because we'd still + # need to wait for autovacuum to run. + for _ in range(1000): + cur.execute("select test_consume_xids(10000);") + for _ in range(1000): cur.execute( - """ - CREATE TEMP TABLE othertable (i int) ON COMMIT DROP; - do $$ - begin - for i in 1..100000 loop - -- Use a begin-exception block to generate a new subtransaction on each iteration - begin - insert into othertable values (i); - exception when others then - raise 'not expected %', sqlerrm; - end; - end loop; - end; - $$; - """ + "select get_raw_page_at_lsn( 'vmtest_lock', 'vm', 0, pg_current_wal_insert_lsn() )" ) - cur.execute("select xmin, xmax, * from vmtest_lock where id = 40000 ") - tup = cur.fetchall() - log.info(f"tuple = {tup}") - xmax = tup[0][1] - assert xmax == xmax_before + page = (cur.fetchall()[0][0])[:100].hex() + log.info(f"VM page contents: {page}") - if i % 50 == 0: - cur.execute("select datfrozenxid from pg_database where datname='postgres'") - datfrozenxid = cur.fetchall()[0][0] - if datfrozenxid > locking_xid: - break + cur.execute("select get_raw_page( 'vmtest_lock', 'vm', 0 )") + page = (cur.fetchall()[0][0])[:100].hex() + log.info(f"VM page contents in cache: {page}") + + cur.execute("select min(datfrozenxid::text::int) from pg_database") + datfrozenxid = int(cur.fetchall()[0][0]) + log.info(f"datfrozenxid {datfrozenxid} locking_xid: {locking_xid}") + if datfrozenxid > locking_xid + 3000000: + break + time.sleep(0.5) cur.execute("select pg_current_xact_id()") - curr_xid = cur.fetchall()[0][0] - assert int(curr_xid) - int(locking_xid) >= 100000 + curr_xid = int(cur.fetchall()[0][0]) + assert curr_xid - locking_xid >= 100000 + + # Perform GC in the pageserver. Otherwise the compute might still + # be able to download the already-deleted SLRU segment from the + # pageserver. That masks the original bug. + env.pageserver.http_client().timeline_checkpoint(tenant_id, timeline_id) + env.pageserver.http_client().timeline_compact(tenant_id, timeline_id) + env.pageserver.http_client().timeline_gc(tenant_id, timeline_id, 0) # Now, if the VM all-frozen bit was not correctly cleared on # replay, we will try to fetch the status of the XID that was @@ -214,3 +239,4 @@ def test_vm_bit_clear_on_heap_lock(neon_simple_env: NeonEnv): cur.execute("select xmin, xmax, * from vmtest_lock where id = 40000 for update") tup = cur.fetchall() log.info(f"tuple = {tup}") + cur.execute("commit transaction") From 89a5c654bfc688babcdfa6c9dcda68876c0d6f98 Mon Sep 17 00:00:00 2001 From: John Spray Date: Fri, 9 Feb 2024 14:26:50 +0000 Subject: [PATCH 32/81] control_plane: follow up for embedded migrations (#6647) ## Problem In https://github.com/neondatabase/neon/pull/6637, we remove the need to run migrations externally, but for compat tests to work we can't remove those invocations from the neon_local binary. Once that previous PR merges, we can make the followup changes without upsetting compat tests. --- Cargo.lock | 4 - control_plane/Cargo.toml | 2 - control_plane/src/attachment_service.rs | 118 +++++------------------- workspace_hack/Cargo.toml | 2 - 4 files changed, 22 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c0c319cd89..a2939e6c75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1329,8 +1329,6 @@ dependencies = [ "clap", "comfy-table", "compute_api", - "diesel", - "diesel_migrations", "futures", "git-version", "hex", @@ -6832,8 +6830,6 @@ dependencies = [ "clap", "clap_builder", "crossbeam-utils", - "diesel", - "diesel_derives", "either", "fail", "futures-channel", diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index 09c171f1d3..75e5dcb7f8 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -10,8 +10,6 @@ async-trait.workspace = true camino.workspace = true clap.workspace = true comfy-table.workspace = true -diesel = { version = "2.1.4", features = ["postgres"]} -diesel_migrations = { version = "2.1.0", features = ["postgres"]} futures.workspace = true git-version.workspace = true nix.workspace = true diff --git a/control_plane/src/attachment_service.rs b/control_plane/src/attachment_service.rs index c3e071aa71..14bfda47c3 100644 --- a/control_plane/src/attachment_service.rs +++ b/control_plane/src/attachment_service.rs @@ -1,11 +1,5 @@ use crate::{background_process, local_env::LocalEnv}; use camino::{Utf8Path, Utf8PathBuf}; -use diesel::{ - backend::Backend, - query_builder::{AstPass, QueryFragment, QueryId}, - Connection, PgConnection, QueryResult, RunQueryDsl, -}; -use diesel_migrations::{HarnessWithOutput, MigrationHarness}; use hyper::Method; use pageserver_api::{ models::{ @@ -17,7 +11,7 @@ use pageserver_api::{ use pageserver_client::mgmt_api::ResponseErrorMessageExt; use postgres_backend::AuthType; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use std::{env, str::FromStr}; +use std::str::FromStr; use tokio::process::Command; use tracing::instrument; use url::Url; @@ -273,37 +267,6 @@ impl AttachmentService { .expect("non-Unicode path") } - /// In order to access database migrations, we need to find the Neon source tree - async fn find_source_root(&self) -> anyhow::Result { - // We assume that either prd or our binary is in the source tree. The former is usually - // true for automated test runners, the latter is usually true for developer workstations. Often - // both are true, which is fine. - let candidate_start_points = [ - // Current working directory - Utf8PathBuf::from_path_buf(std::env::current_dir()?).unwrap(), - // Directory containing the binary we're running inside - Utf8PathBuf::from_path_buf(env::current_exe()?.parent().unwrap().to_owned()).unwrap(), - ]; - - // For each candidate start point, search through ancestors looking for a neon.git source tree root - for start_point in &candidate_start_points { - // Start from the build dir: assumes we are running out of a built neon source tree - for path in start_point.ancestors() { - // A crude approximation: the root of the source tree is whatever contains a "control_plane" - // subdirectory. - let control_plane = path.join("control_plane"); - if tokio::fs::try_exists(&control_plane).await? { - return Ok(path.to_owned()); - } - } - } - - // Fall-through - Err(anyhow::anyhow!( - "Could not find control_plane src dir, after searching ancestors of {candidate_start_points:?}" - )) - } - /// Find the directory containing postgres binaries, such as `initdb` and `pg_ctl` /// /// This usually uses ATTACHMENT_SERVICE_POSTGRES_VERSION of postgres, but will fall back @@ -343,69 +306,32 @@ impl AttachmentService { /// /// Returns the database url pub async fn setup_database(&self) -> anyhow::Result { - let database_url = format!( - "postgresql://localhost:{}/attachment_service", - self.postgres_port - ); - println!("Running attachment service database setup..."); - fn change_database_of_url(database_url: &str, default_database: &str) -> (String, String) { - let base = ::url::Url::parse(database_url).unwrap(); - let database = base.path_segments().unwrap().last().unwrap().to_owned(); - let mut new_url = base.join(default_database).unwrap(); - new_url.set_query(base.query()); - (database, new_url.into()) - } + const DB_NAME: &str = "attachment_service"; + let database_url = format!("postgresql://localhost:{}/{DB_NAME}", self.postgres_port); - #[derive(Debug, Clone)] - pub struct CreateDatabaseStatement { - db_name: String, - } + let pg_bin_dir = self.get_pg_bin_dir().await?; + let createdb_path = pg_bin_dir.join("createdb"); + let output = Command::new(&createdb_path) + .args([ + "-h", + "localhost", + "-p", + &format!("{}", self.postgres_port), + &DB_NAME, + ]) + .output() + .await + .expect("Failed to spawn createdb"); - impl CreateDatabaseStatement { - pub fn new(db_name: &str) -> Self { - CreateDatabaseStatement { - db_name: db_name.to_owned(), - } + if !output.status.success() { + let stderr = String::from_utf8(output.stderr).expect("Non-UTF8 output from createdb"); + if stderr.contains("already exists") { + tracing::info!("Database {DB_NAME} already exists"); + } else { + anyhow::bail!("createdb failed with status {}: {stderr}", output.status); } } - impl QueryFragment for CreateDatabaseStatement { - fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> { - out.push_sql("CREATE DATABASE "); - out.push_identifier(&self.db_name)?; - Ok(()) - } - } - - impl RunQueryDsl for CreateDatabaseStatement {} - - impl QueryId for CreateDatabaseStatement { - type QueryId = (); - - const HAS_STATIC_QUERY_ID: bool = false; - } - if PgConnection::establish(&database_url).is_err() { - let (database, postgres_url) = change_database_of_url(&database_url, "postgres"); - println!("Creating database: {database}"); - let mut conn = PgConnection::establish(&postgres_url)?; - CreateDatabaseStatement::new(&database).execute(&mut conn)?; - } - let mut conn = PgConnection::establish(&database_url)?; - - let migrations_dir = self - .find_source_root() - .await? - .join("control_plane/attachment_service/migrations"); - - let migrations = diesel_migrations::FileBasedMigrations::from_path(migrations_dir)?; - println!("Running migrations in {}", migrations.path().display()); - HarnessWithOutput::write_to_stdout(&mut conn) - .run_pending_migrations(migrations) - .map(|_| ()) - .map_err(|e| anyhow::anyhow!(e))?; - - println!("Migrations complete"); - Ok(database_url) } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 70b238913d..8e9cc43152 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -29,7 +29,6 @@ chrono = { version = "0.4", default-features = false, features = ["clock", "serd clap = { version = "4", features = ["derive", "string"] } clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] } crossbeam-utils = { version = "0.8" } -diesel = { version = "2", features = ["postgres", "r2d2", "serde_json"] } either = { version = "1" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } futures-channel = { version = "0.3", features = ["sink"] } @@ -90,7 +89,6 @@ anyhow = { version = "1", features = ["backtrace"] } bytes = { version = "1", features = ["serde"] } cc = { version = "1", default-features = false, features = ["parallel"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } -diesel_derives = { version = "2", features = ["32-column-tables", "postgres", "r2d2", "with-deprecated"] } either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } From 96d89cde5108850d1f0f41c23ff175552297ab9d Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 9 Feb 2024 15:50:51 +0000 Subject: [PATCH 33/81] Proxy error reworking (#6453) ## Problem Taking my ideas from https://github.com/neondatabase/neon/pull/6283 and doing a bit less radical changes. smaller commits. We currently don't report error classifications in proxy as the current error handling made it hard to do so. ## Summary of changes 1. Add a `ReportableError` trait that all errors will implement. This provides the error classification functionality. 2. Handle Client requests a strongly typed error * this error is a `ReportableError` and is logged appropriately 3. The handle client error only has a few possible error types, to account for the fact that at this point errors should be returned to the user. --- proxy/src/auth.rs | 37 ++++++++- proxy/src/auth/backend/classic.rs | 4 +- proxy/src/auth/backend/link.rs | 18 ++-- proxy/src/auth/credentials.rs | 14 +++- proxy/src/bin/pg_sni_router.rs | 11 ++- proxy/src/cancellation.rs | 37 +++++++-- proxy/src/compute.rs | 22 ++++- proxy/src/console/provider.rs | 31 ++++++- proxy/src/context.rs | 18 +++- proxy/src/context/parquet.rs | 2 +- proxy/src/error.rs | 38 +++++++-- proxy/src/metrics.rs | 19 +++++ proxy/src/proxy.rs | 95 ++++++++++++++++++---- proxy/src/proxy/handshake.rs | 76 +++++++++++++---- proxy/src/proxy/passthrough.rs | 23 ++++-- proxy/src/proxy/tests.rs | 8 +- proxy/src/proxy/tests/mitm.rs | 10 +-- proxy/src/sasl.rs | 14 +++- proxy/src/serverless.rs | 14 ++-- proxy/src/serverless/backend.rs | 29 +++++-- proxy/src/serverless/conn_pool.rs | 4 +- proxy/src/serverless/json.rs | 32 ++++++-- proxy/src/serverless/sql_over_http.rs | 113 ++++++++++++-------------- proxy/src/serverless/websocket.rs | 30 +++++-- proxy/src/stream.rs | 75 ++++++++++++++--- 25 files changed, 588 insertions(+), 186 deletions(-) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 8d1b861a66..48de4e2353 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -5,7 +5,8 @@ pub use backend::BackendType; mod credentials; pub use credentials::{ - check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern, + check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, + ComputeUserInfoParseError, IpPattern, }; mod password_hack; @@ -14,8 +15,12 @@ use password_hack::PasswordHackPayload; mod flow; pub use flow::*; +use tokio::time::error::Elapsed; -use crate::{console, error::UserFacingError}; +use crate::{ + console, + error::{ReportableError, UserFacingError}, +}; use std::io; use thiserror::Error; @@ -67,6 +72,9 @@ pub enum AuthErrorImpl { #[error("Too many connections to this endpoint. Please try again later.")] TooManyConnections, + + #[error("Authentication timed out")] + UserTimeout(Elapsed), } #[derive(Debug, Error)] @@ -93,6 +101,10 @@ impl AuthError { pub fn is_auth_failed(&self) -> bool { matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_)) } + + pub fn user_timeout(elapsed: Elapsed) -> Self { + AuthErrorImpl::UserTimeout(elapsed).into() + } } impl> From for AuthError { @@ -116,6 +128,27 @@ impl UserFacingError for AuthError { Io(_) => "Internal error".to_string(), IpAddressNotAllowed => self.to_string(), TooManyConnections => self.to_string(), + UserTimeout(_) => self.to_string(), + } + } +} + +impl ReportableError for AuthError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + use AuthErrorImpl::*; + match self.0.as_ref() { + Link(e) => e.get_error_kind(), + GetAuthInfo(e) => e.get_error_kind(), + WakeCompute(e) => e.get_error_kind(), + Sasl(e) => e.get_error_kind(), + AuthFailed(_) => crate::error::ErrorKind::User, + BadAuthMethod(_) => crate::error::ErrorKind::User, + MalformedPassword(_) => crate::error::ErrorKind::User, + MissingEndpointName => crate::error::ErrorKind::User, + Io(_) => crate::error::ErrorKind::ClientDisconnect, + IpAddressNotAllowed => crate::error::ErrorKind::User, + TooManyConnections => crate::error::ErrorKind::RateLimit, + UserTimeout(_) => crate::error::ErrorKind::User, } } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 384063ceae..745dd75107 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -45,9 +45,9 @@ pub(super) async fn authenticate( } ) .await - .map_err(|error| { + .map_err(|e| { warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs()); - auth::io::Error::new(auth::io::ErrorKind::TimedOut, error) + auth::AuthError::user_timeout(e) })??; let client_key = match auth_outcome { diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index d8ae362c03..c71637dd1a 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -2,7 +2,7 @@ use crate::{ auth, compute, console::{self, provider::NodeInfo}, context::RequestMonitoring, - error::UserFacingError, + error::{ReportableError, UserFacingError}, stream::PqStream, waiters, }; @@ -14,10 +14,6 @@ use tracing::{info, info_span}; #[derive(Debug, Error)] pub enum LinkAuthError { - /// Authentication error reported by the console. - #[error("Authentication failed: {0}")] - AuthFailed(String), - #[error(transparent)] WaiterRegister(#[from] waiters::RegisterError), @@ -30,10 +26,16 @@ pub enum LinkAuthError { impl UserFacingError for LinkAuthError { fn to_string_client(&self) -> String { - use LinkAuthError::*; + "Internal error".to_string() + } +} + +impl ReportableError for LinkAuthError { + fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - AuthFailed(_) => self.to_string(), - _ => "Internal error".to_string(), + LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service, + LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service, + LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect, } } } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 875baaec47..d32609e44c 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,8 +1,12 @@ //! User credentials used in authentication. use crate::{ - auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError, - metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI, + auth::password_hack::parse_endpoint_param, + context::RequestMonitoring, + error::{ReportableError, UserFacingError}, + metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, + proxy::NeonOptions, + serverless::SERVERLESS_DRIVER_SNI, EndpointId, RoleName, }; use itertools::Itertools; @@ -39,6 +43,12 @@ pub enum ComputeUserInfoParseError { impl UserFacingError for ComputeUserInfoParseError {} +impl ReportableError for ComputeUserInfoParseError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + crate::error::ErrorKind::User + } +} + /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 471be7af25..43b805e8a1 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -240,7 +240,9 @@ async fn ssl_handshake( ?unexpected, "unexpected startup packet, rejecting connection" ); - stream.throw_error_str(ERR_INSECURE_CONNECTION).await? + stream + .throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User) + .await? } } } @@ -272,5 +274,10 @@ async fn handle_client( let client = tokio::net::TcpStream::connect(destination).await?; let metrics_aux: MetricsAuxInfo = Default::default(); - proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await + + // doesn't yet matter as pg-sni-router doesn't report analytics logs + ctx.set_success(); + ctx.log(); + + proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index d4ee657144..fe614628d8 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,24 +1,45 @@ -use anyhow::Context; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; +use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::{CancelToken, NoTls}; use tracing::info; +use crate::error::ReportableError; + /// Enables serving `CancelRequest`s. #[derive(Default)] pub struct CancelMap(DashMap>); +#[derive(Debug, Error)] +pub enum CancelError { + #[error("{0}")] + IO(#[from] std::io::Error), + #[error("{0}")] + Postgres(#[from] tokio_postgres::Error), +} + +impl ReportableError for CancelError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + CancelError::IO(_) => crate::error::ErrorKind::Compute, + CancelError::Postgres(e) if e.as_db_error().is_some() => { + crate::error::ErrorKind::Postgres + } + CancelError::Postgres(_) => crate::error::ErrorKind::Compute, + } + } +} + impl CancelMap { /// Cancel a running query for the corresponding connection. - pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> { + pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> { // NB: we should immediately release the lock after cloning the token. - let cancel_closure = self - .0 - .get(&key) - .and_then(|x| x.clone()) - .with_context(|| format!("query cancellation key not found: {key}"))?; + let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else { + tracing::warn!("query cancellation key not found: {key}"); + return Ok(()); + }; info!("cancelling query per user's request using key {key}"); cancel_closure.try_cancel_query().await @@ -81,7 +102,7 @@ impl CancelClosure { } /// Cancels the query running on user's compute node. - pub async fn try_cancel_query(self) -> anyhow::Result<()> { + async fn try_cancel_query(self) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; self.cancel_token.cancel_query_raw(socket, NoTls).await?; diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index aef1aab733..83940d80ec 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,6 +1,10 @@ use crate::{ - auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError, - context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE, + auth::parse_endpoint_param, + cancellation::CancelClosure, + console::errors::WakeComputeError, + context::RequestMonitoring, + error::{ReportableError, UserFacingError}, + metrics::NUM_DB_CONNECTIONS_GAUGE, proxy::neon_option, }; use futures::{FutureExt, TryFutureExt}; @@ -58,6 +62,20 @@ impl UserFacingError for ConnectionError { } } +impl ReportableError for ConnectionError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + ConnectionError::Postgres(e) if e.as_db_error().is_some() => { + crate::error::ErrorKind::Postgres + } + ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, + ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute, + ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, + ConnectionError::WakeComputeError(e) => e.get_error_kind(), + } + } +} + /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index c53d929470..e5cad42753 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -20,7 +20,7 @@ use tracing::info; pub mod errors { use crate::{ - error::{io_error, UserFacingError}, + error::{io_error, ReportableError, UserFacingError}, http, proxy::retry::ShouldRetry, }; @@ -81,6 +81,15 @@ pub mod errors { } } + impl ReportableError for ApiError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane, + ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane, + } + } + } + impl ShouldRetry for ApiError { fn could_retry(&self) -> bool { match self { @@ -150,6 +159,16 @@ pub mod errors { } } } + + impl ReportableError for GetAuthInfoError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane, + GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane, + } + } + } + #[derive(Debug, Error)] pub enum WakeComputeError { #[error("Console responded with a malformed compute address: {0}")] @@ -194,6 +213,16 @@ pub mod errors { } } } + + impl ReportableError for WakeComputeError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane, + WakeComputeError::ApiError(e) => e.get_error_kind(), + WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit, + } + } + } } /// Auth secret which is managed by the cloud. diff --git a/proxy/src/context.rs b/proxy/src/context.rs index fe204534b7..d2bf3f68d3 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -8,8 +8,10 @@ use tokio::sync::mpsc; use uuid::Uuid; use crate::{ - console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId, - EndpointId, ProjectId, RoleName, + console::messages::MetricsAuxInfo, + error::ErrorKind, + metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND}, + BranchId, EndpointId, ProjectId, RoleName, }; pub mod parquet; @@ -108,6 +110,18 @@ impl RequestMonitoring { self.user = Some(user); } + pub fn set_error_kind(&mut self, kind: ErrorKind) { + ERROR_BY_KIND + .with_label_values(&[kind.to_metric_label()]) + .inc(); + if let Some(ep) = &self.endpoint_id { + ENDPOINT_ERRORS_BY_KIND + .with_label_values(&[kind.to_metric_label()]) + .measure(ep); + } + self.error_kind = Some(kind); + } + pub fn set_success(&mut self) { self.success = true; } diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 8510c5c586..0fe46915bc 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -108,7 +108,7 @@ impl From for RequestData { branch: value.branch.as_deref().map(String::from), protocol: value.protocol, region: value.region, - error: value.error_kind.as_ref().map(|e| e.to_str()), + error: value.error_kind.as_ref().map(|e| e.to_metric_label()), success: value.success, duration_us: SystemTime::from(value.first_packet) .elapsed() diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 5b2dd7ecfd..eafe92bf48 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -17,7 +17,7 @@ pub fn log_error(e: E) -> E { /// NOTE: This trait should not be implemented for [`anyhow::Error`], since it /// is way too convenient and tends to proliferate all across the codebase, /// ultimately leading to accidental leaks of sensitive data. -pub trait UserFacingError: fmt::Display { +pub trait UserFacingError: ReportableError { /// Format the error for client, stripping all sensitive info. /// /// Although this might be a no-op for many types, it's highly @@ -29,13 +29,13 @@ pub trait UserFacingError: fmt::Display { } } -#[derive(Clone)] +#[derive(Copy, Clone, Debug)] pub enum ErrorKind { /// Wrong password, unknown endpoint, protocol violation, etc... User, /// Network error between user and proxy. Not necessarily user error - Disconnect, + ClientDisconnect, /// Proxy self-imposed rate limits RateLimit, @@ -46,6 +46,9 @@ pub enum ErrorKind { /// Error communicating with control plane ControlPlane, + /// Postgres error + Postgres, + /// Error communicating with compute Compute, } @@ -54,11 +57,36 @@ impl ErrorKind { pub fn to_str(&self) -> &'static str { match self { ErrorKind::User => "request failed due to user error", - ErrorKind::Disconnect => "client disconnected", + ErrorKind::ClientDisconnect => "client disconnected", ErrorKind::RateLimit => "request cancelled due to rate limit", ErrorKind::Service => "internal service error", ErrorKind::ControlPlane => "non-retryable control plane error", - ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)", + ErrorKind::Postgres => "postgres error", + ErrorKind::Compute => { + "non-retryable compute connection error (or exhausted retry capacity)" + } + } + } + + pub fn to_metric_label(&self) -> &'static str { + match self { + ErrorKind::User => "user", + ErrorKind::ClientDisconnect => "clientdisconnect", + ErrorKind::RateLimit => "ratelimit", + ErrorKind::Service => "service", + ErrorKind::ControlPlane => "controlplane", + ErrorKind::Postgres => "postgres", + ErrorKind::Compute => "compute", } } } + +pub trait ReportableError: fmt::Display + Send + 'static { + fn get_error_kind(&self) -> ErrorKind; +} + +impl ReportableError for tokio::time::error::Elapsed { + fn get_error_kind(&self) -> ErrorKind { + ErrorKind::RateLimit + } +} diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index e2d96a9c27..ccf89f9b05 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -274,3 +274,22 @@ pub static CONNECTING_ENDPOINTS: Lazy> = Lazy::new(|| { ) .unwrap() }); + +pub static ERROR_BY_KIND: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "proxy_errors_total", + "Number of errors by a given classification", + &["type"], + ) + .unwrap() +}); + +pub static ENDPOINT_ERRORS_BY_KIND: Lazy> = Lazy::new(|| { + register_hll_vec!( + 32, + "proxy_endpoints_affected_by_errors", + "Number of endpoints affected by errors of a given classification", + &["type"], + ) + .unwrap() +}); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index b3b221d3e2..50e22ec72a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -13,9 +13,10 @@ use crate::{ compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, + error::ReportableError, metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE}, protocol2::WithClientIp, - proxy::{handshake::handshake, passthrough::proxy_pass}, + proxy::handshake::{handshake, HandshakeData}, rate_limiter::EndpointRateLimiter, stream::{PqStream, Stream}, EndpointCacheKey, @@ -28,14 +29,17 @@ use pq_proto::{BeMessage as Be, StartupMessageParams}; use regex::Regex; use smol_str::{format_smolstr, SmolStr}; use std::sync::Arc; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, Instrument}; -use self::connect_compute::{connect_to_compute, TcpMechanism}; +use self::{ + connect_compute::{connect_to_compute, TcpMechanism}, + passthrough::ProxyPassthrough, +}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; -const ERR_PROTO_VIOLATION: &str = "protocol violation"; pub async fn run_until_cancelled( f: F, @@ -98,14 +102,14 @@ pub async fn task_main( bail!("missing required client IP"); } - let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region); - socket .inner .set_nodelay(true) .context("failed to set socket option")?; - handle_client( + let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region); + + let res = handle_client( config, &mut ctx, cancel_map, @@ -113,7 +117,26 @@ pub async fn task_main( ClientMode::Tcp, endpoint_rate_limiter, ) - .await + .await; + + match res { + Err(e) => { + // todo: log and push to ctx the error kind + ctx.set_error_kind(e.get_error_kind()); + ctx.log(); + Err(e.into()) + } + Ok(None) => { + ctx.set_success(); + ctx.log(); + Ok(()) + } + Ok(Some(p)) => { + ctx.set_success(); + ctx.log(); + p.proxy_pass().await + } + } } .unwrap_or_else(move |e| { // Acknowledge that the task has finished with an error. @@ -169,6 +192,37 @@ impl ClientMode { } } +#[derive(Debug, Error)] +// almost all errors should be reported to the user, but there's a few cases where we cannot +// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons +// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation, +// we cannot be sure the client even understands our error message +// 3. PrepareClient: The client disconnected, so we can't tell them anyway... +pub enum ClientRequestError { + #[error("{0}")] + Cancellation(#[from] cancellation::CancelError), + #[error("{0}")] + Handshake(#[from] handshake::HandshakeError), + #[error("{0}")] + HandshakeTimeout(#[from] tokio::time::error::Elapsed), + #[error("{0}")] + PrepareClient(#[from] std::io::Error), + #[error("{0}")] + ReportedError(#[from] crate::stream::ReportedError), +} + +impl ReportableError for ClientRequestError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + ClientRequestError::Cancellation(e) => e.get_error_kind(), + ClientRequestError::Handshake(e) => e.get_error_kind(), + ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit, + ClientRequestError::ReportedError(e) => e.get_error_kind(), + ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect, + } + } +} + pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, @@ -176,7 +230,7 @@ pub async fn handle_client( stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, -) -> anyhow::Result<()> { +) -> Result>, ClientRequestError> { info!( protocol = ctx.protocol, "handling interactive connection from client" @@ -193,11 +247,16 @@ pub async fn handle_client( let tls = config.tls_config.as_ref(); let pause = ctx.latency_timer.pause(); - let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map); + let do_handshake = handshake(stream, mode.handshake_tls(tls)); let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { - Some(x) => x, - None => return Ok(()), // it's a cancellation request + HandshakeData::Startup(stream, params) => (stream, params), + HandshakeData::Cancel(cancel_key_data) => { + return Ok(cancel_map + .cancel_session(cancel_key_data) + .await + .map(|()| None)?) + } }; drop(pause); @@ -222,7 +281,7 @@ pub async fn handle_client( if !endpoint_rate_limiter.check(ep) { return stream .throw_error(auth::AuthError::too_many_connections()) - .await; + .await?; } } @@ -242,7 +301,7 @@ pub async fn handle_client( let app = params.get("application_name"); let params_span = tracing::info_span!("", ?user, ?db, ?app); - return stream.throw_error(e).instrument(params_span).await; + return stream.throw_error(e).instrument(params_span).await?; } }; @@ -268,7 +327,13 @@ pub async fn handle_client( let (stream, read_buf) = stream.into_inner(); node.stream.write_all(&read_buf).await?; - proxy_pass(ctx, stream, node.stream, aux).await + Ok(Some(ProxyPassthrough { + client: stream, + compute: node, + aux, + req: _request_gauge, + conn: _client_gauge, + })) } /// Finish client connection initialization: confirm auth success, send params, etc. @@ -277,7 +342,7 @@ async fn prepare_client_connection( node: &compute::PostgresConnection, session: &cancellation::Session, stream: &mut PqStream, -) -> anyhow::Result<()> { +) -> Result<(), std::io::Error> { // Register compute's query cancellation token and produce a new, unique one. // The new token (cancel_key_data) will be sent to the client. let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone()); diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 1ad8da20d7..4665e07d23 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,15 +1,60 @@ -use anyhow::{bail, Context}; -use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; +use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams}; +use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; use crate::{ - cancellation::CancelMap, config::TlsConfig, - proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION}, - stream::{PqStream, Stream}, + error::ReportableError, + proxy::ERR_INSECURE_CONNECTION, + stream::{PqStream, Stream, StreamUpgradeError}, }; +#[derive(Error, Debug)] +pub enum HandshakeError { + #[error("data is sent before server replied with EncryptionResponse")] + EarlyData, + + #[error("protocol violation")] + ProtocolViolation, + + #[error("missing certificate")] + MissingCertificate, + + #[error("{0}")] + StreamUpgradeError(#[from] StreamUpgradeError), + + #[error("{0}")] + Io(#[from] std::io::Error), + + #[error("{0}")] + ReportedError(#[from] crate::stream::ReportedError), +} + +impl ReportableError for HandshakeError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + HandshakeError::EarlyData => crate::error::ErrorKind::User, + HandshakeError::ProtocolViolation => crate::error::ErrorKind::User, + // This error should not happen, but will if we have no default certificate and + // the client sends no SNI extension. + // If they provide SNI then we can be sure there is a certificate that matches. + HandshakeError::MissingCertificate => crate::error::ErrorKind::Service, + HandshakeError::StreamUpgradeError(upgrade) => match upgrade { + StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service, + StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect, + }, + HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect, + HandshakeError::ReportedError(e) => e.get_error_kind(), + } + } +} + +pub enum HandshakeData { + Startup(PqStream>, StartupMessageParams), + Cancel(CancelKeyData), +} + /// Establish a (most probably, secure) connection with the client. /// For better testing experience, `stream` can be any object satisfying the traits. /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; @@ -18,8 +63,7 @@ use crate::{ pub async fn handshake( stream: S, mut tls: Option<&TlsConfig>, - cancel_map: &CancelMap, -) -> anyhow::Result>, StartupMessageParams)>> { +) -> Result, HandshakeError> { // Client may try upgrading to each protocol only once let (mut tried_ssl, mut tried_gss) = (false, false); @@ -49,14 +93,14 @@ pub async fn handshake( // pipelining in our node js driver. We should probably // support that by chaining read_buf with the stream. if !read_buf.is_empty() { - bail!("data is sent before server replied with EncryptionResponse"); + return Err(HandshakeError::EarlyData); } let tls_stream = raw.upgrade(tls.to_server_config()).await?; let (_, tls_server_end_point) = tls .cert_resolver .resolve(tls_stream.get_ref().1.server_name()) - .context("missing certificate")?; + .ok_or(HandshakeError::MissingCertificate)?; stream = PqStream::new(Stream::Tls { tls: Box::new(tls_stream), @@ -64,7 +108,7 @@ pub async fn handshake( }); } } - _ => bail!(ERR_PROTO_VIOLATION), + _ => return Err(HandshakeError::ProtocolViolation), }, GssEncRequest => match stream.get_ref() { Stream::Raw { .. } if !tried_gss => { @@ -73,23 +117,23 @@ pub async fn handshake( // Currently, we don't support GSSAPI stream.write_message(&Be::EncryptionResponse(false)).await?; } - _ => bail!(ERR_PROTO_VIOLATION), + _ => return Err(HandshakeError::ProtocolViolation), }, StartupMessage { params, .. } => { // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). if tls.is_some() { - stream.throw_error_str(ERR_INSECURE_CONNECTION).await?; + return stream + .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User) + .await?; } info!(session_type = "normal", "successful handshake"); - break Ok(Some((stream, params))); + break Ok(HandshakeData::Startup(stream, params)); } CancelRequest(cancel_key_data) => { - cancel_map.cancel_session(cancel_key_data).await?; - info!(session_type = "cancellation", "successful handshake"); - break Ok(None); + break Ok(HandshakeData::Cancel(cancel_key_data)); } } } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 53e0c3c8f3..b7018c6fb5 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,9 +1,11 @@ use crate::{ + compute::PostgresConnection, console::messages::MetricsAuxInfo, - context::RequestMonitoring, metrics::NUM_BYTES_PROXIED_COUNTER, + stream::Stream, usage_metrics::{Ids, USAGE_METRICS}, }; +use metrics::IntCounterPairGuard; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; use utils::measured_stream::MeasuredStream; @@ -11,14 +13,10 @@ use utils::measured_stream::MeasuredStream; /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(skip_all)] pub async fn proxy_pass( - ctx: &mut RequestMonitoring, client: impl AsyncRead + AsyncWrite + Unpin, compute: impl AsyncRead + AsyncWrite + Unpin, aux: MetricsAuxInfo, ) -> anyhow::Result<()> { - ctx.set_success(); - ctx.log(); - let usage = USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id.clone(), branch_id: aux.branch_id.clone(), @@ -51,3 +49,18 @@ pub async fn proxy_pass( Ok(()) } + +pub struct ProxyPassthrough { + pub client: Stream, + pub compute: PostgresConnection, + pub aux: MetricsAuxInfo, + + pub req: IntCounterPairGuard, + pub conn: IntCounterPairGuard, +} + +impl ProxyPassthrough { + pub async fn proxy_pass(self) -> anyhow::Result<()> { + proxy_pass(self.client, self.compute.stream, self.aux).await + } +} diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 656cabac75..3e961afb41 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -163,11 +163,11 @@ async fn dummy_proxy( tls: Option, auth: impl TestAuth + Send, ) -> anyhow::Result<()> { - let cancel_map = CancelMap::default(); let client = WithClientIp::new(client); - let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map) - .await? - .context("handshake failed")?; + let mut stream = match handshake(client, tls.as_ref()).await? { + HandshakeData::Startup(stream, _) => stream, + HandshakeData::Cancel(_) => bail!("cancellation not supported"), + }; auth.authenticate(&mut stream).await?; diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index a0a84a1dc0..ed89e51754 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -35,12 +35,10 @@ async fn proxy_mitm( tokio::spawn(async move { // begin handshake with end_server let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await; - // process handshake with end_client - let (end_client, startup) = - handshake(client1, Some(&server_config1), &CancelMap::default()) - .await - .unwrap() - .unwrap(); + let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() { + HandshakeData::Startup(stream, params) => (stream, params), + HandshakeData::Cancel(_) => panic!("cancellation not supported"), + }; let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame); let (end_client, buf) = end_client.framed.into_inner(); diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index da1cf21c6a..1cf8b53e11 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -10,7 +10,7 @@ mod channel_binding; mod messages; mod stream; -use crate::error::UserFacingError; +use crate::error::{ReportableError, UserFacingError}; use std::io; use thiserror::Error; @@ -48,6 +48,18 @@ impl UserFacingError for Error { } } +impl ReportableError for Error { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, + Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, + Error::BadClientMessage(_) => crate::error::ErrorKind::User, + Error::MissingBinding => crate::error::ErrorKind::Service, + Error::Io(_) => crate::error::ErrorKind::ClientDisconnect, + } + } +} + /// A convenient result type for SASL exchange. pub type Result = std::result::Result; diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 58aa925a6a..a20600b94a 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -109,10 +109,9 @@ pub async fn task_main( let make_svc = hyper::service::make_service_fn( |stream: &tokio_rustls::server::TlsStream>| { - let (io, tls) = stream.get_ref(); + let (io, _) = stream.get_ref(); let client_addr = io.client_addr(); let remote_addr = io.inner.remote_addr(); - let sni_name = tls.server_name().map(|s| s.to_string()); let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); @@ -125,7 +124,6 @@ pub async fn task_main( }; Ok(MetricService::new(hyper::service::service_fn( move |req: Request| { - let sni_name = sni_name.clone(); let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); @@ -141,7 +139,6 @@ pub async fn task_main( ws_connections, cancel_map, session_id, - sni_name, peer_addr.ip(), endpoint_rate_limiter, ) @@ -210,7 +207,6 @@ async fn request_handler( ws_connections: TaskTracker, cancel_map: Arc, session_id: uuid::Uuid, - sni_hostname: Option, peer_addr: IpAddr, endpoint_rate_limiter: Arc, ) -> Result, ApiError> { @@ -230,11 +226,11 @@ async fn request_handler( ws_connections.spawn( async move { - let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region); + let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region); if let Err(e) = websocket::serve_websocket( config, - &mut ctx, + ctx, websocket, cancel_map, host, @@ -251,9 +247,9 @@ async fn request_handler( // Return the response so the spawned future can continue. Ok(response) } else if request.uri().path() == "/sql" && request.method() == Method::POST { - let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); + let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); - sql_over_http::handle(config, &mut ctx, request, sni_hostname, backend).await + sql_over_http::handle(config, ctx, request, backend).await } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { Response::builder() .header("Allow", "OPTIONS, POST") diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 466a74f0ea..03257e9161 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,6 +1,5 @@ use std::{sync::Arc, time::Duration}; -use anyhow::Context; use async_trait::async_trait; use tracing::info; @@ -8,7 +7,10 @@ use crate::{ auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, compute, config::ProxyConfig, - console::CachedNodeInfo, + console::{ + errors::{GetAuthInfoError, WakeComputeError}, + CachedNodeInfo, + }, context::RequestMonitoring, proxy::connect_compute::ConnectMechanism, }; @@ -66,7 +68,7 @@ impl PoolingBackend { conn_info: ConnInfo, keys: ComputeCredentialKeys, force_new: bool, - ) -> anyhow::Result> { + ) -> Result, HttpConnError> { let maybe_client = if !force_new { info!("pool: looking for an existing connection"); self.pool.get(ctx, &conn_info).await? @@ -90,7 +92,7 @@ impl PoolingBackend { let mut node_info = backend .wake_compute(ctx) .await? - .context("missing cache entry from wake_compute")?; + .ok_or(HttpConnError::NoComputeInfo)?; match keys { #[cfg(any(test, feature = "testing"))] @@ -114,6 +116,23 @@ impl PoolingBackend { } } +#[derive(Debug, thiserror::Error)] +pub enum HttpConnError { + #[error("pooled connection closed at inconsistent state")] + ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), + #[error("could not connection to compute")] + ConnectionError(#[from] tokio_postgres::Error), + + #[error("could not get auth info")] + GetAuthInfo(#[from] GetAuthInfoError), + #[error("user not authenticated")] + AuthError(#[from] AuthError), + #[error("wake_compute returned error")] + WakeCompute(#[from] WakeComputeError), + #[error("wake_compute returned nothing")] + NoComputeInfo, +} + struct TokioMechanism { pool: Arc>, conn_info: ConnInfo, @@ -124,7 +143,7 @@ struct TokioMechanism { impl ConnectMechanism for TokioMechanism { type Connection = Client; type ConnectError = tokio_postgres::Error; - type Error = anyhow::Error; + type Error = HttpConnError; async fn connect_once( &self, diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index a7b2c532d2..f92793096b 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -28,6 +28,8 @@ use crate::{ use tracing::{debug, error, warn, Span}; use tracing::{info, info_span, Instrument}; +use super::backend::HttpConnError; + pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http"); #[derive(Debug, Clone)] @@ -358,7 +360,7 @@ impl GlobalConnPool { self: &Arc, ctx: &mut RequestMonitoring, conn_info: &ConnInfo, - ) -> anyhow::Result>> { + ) -> Result>, HttpConnError> { let mut client: Option> = None; let endpoint_pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index a089d34040..c22c63e85b 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -60,6 +60,20 @@ fn json_array_to_pg_array(value: &Value) -> Option { } } +#[derive(Debug, thiserror::Error)] +pub enum JsonConversionError { + #[error("internal error compute returned invalid data: {0}")] + AsTextError(tokio_postgres::Error), + #[error("parse int error: {0}")] + ParseIntError(#[from] std::num::ParseIntError), + #[error("parse float error: {0}")] + ParseFloatError(#[from] std::num::ParseFloatError), + #[error("parse json error: {0}")] + ParseJsonError(#[from] serde_json::Error), + #[error("unbalanced array")] + UnbalancedArray, +} + // // Convert postgres row with text-encoded values to JSON object // @@ -68,7 +82,7 @@ pub fn pg_text_row_to_json( columns: &[Type], raw_output: bool, array_mode: bool, -) -> Result { +) -> Result { let iter = row .columns() .iter() @@ -76,7 +90,7 @@ pub fn pg_text_row_to_json( .enumerate() .map(|(i, (column, typ))| { let name = column.name(); - let pg_value = row.as_text(i)?; + let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?; let json_value = if raw_output { match pg_value { Some(v) => Value::String(v.to_string()), @@ -92,10 +106,10 @@ pub fn pg_text_row_to_json( // drop keys and aggregate into array let arr = iter .map(|r| r.map(|(_key, val)| val)) - .collect::, anyhow::Error>>()?; + .collect::, JsonConversionError>>()?; Ok(Value::Array(arr)) } else { - let obj = iter.collect::, anyhow::Error>>()?; + let obj = iter.collect::, JsonConversionError>>()?; Ok(Value::Object(obj)) } } @@ -103,7 +117,7 @@ pub fn pg_text_row_to_json( // // Convert postgres text-encoded value to JSON value // -fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result { +fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result { if let Some(val) = pg_value { if let Kind::Array(elem_type) = pg_type.kind() { return pg_array_parse(val, elem_type); @@ -142,7 +156,7 @@ fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result Result { +fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result { _pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v) } @@ -150,7 +164,7 @@ fn _pg_array_parse( pg_array: &str, elem_type: &Type, nested: bool, -) -> Result<(Value, usize), anyhow::Error> { +) -> Result<(Value, usize), JsonConversionError> { let mut pg_array_chr = pg_array.char_indices(); let mut level = 0; let mut quote = false; @@ -170,7 +184,7 @@ fn _pg_array_parse( entry: &mut String, entries: &mut Vec, elem_type: &Type, - ) -> Result<(), anyhow::Error> { + ) -> Result<(), JsonConversionError> { if !entry.is_empty() { // While in usual postgres response we get nulls as None and everything else // as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while @@ -234,7 +248,7 @@ fn _pg_array_parse( } if level != 0 { - return Err(anyhow::anyhow!("unbalanced array")); + return Err(JsonConversionError::UnbalancedArray); } Ok((Value::Array(entries), 0)) diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 25e8813625..401022347e 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use anyhow::bail; -use anyhow::Context; use futures::pin_mut; use futures::StreamExt; use hyper::body::HttpBody; @@ -29,9 +28,11 @@ use utils::http::json::json_response; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; +use crate::auth::ComputeUserInfoParseError; use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; +use crate::error::ReportableError; use crate::metrics::HTTP_CONTENT_LENGTH; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; @@ -41,7 +42,6 @@ use super::backend::PoolingBackend; use super::conn_pool::ConnInfo; use super::json::json_to_pg_text; use super::json::pg_text_row_to_json; -use super::SERVERLESS_DRIVER_SNI; #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] @@ -86,67 +86,70 @@ where Ok(json_to_pg_text(json)) } +#[derive(Debug, thiserror::Error)] +pub enum ConnInfoError { + #[error("invalid header: {0}")] + InvalidHeader(&'static str), + #[error("invalid connection string: {0}")] + UrlParseError(#[from] url::ParseError), + #[error("incorrect scheme")] + IncorrectScheme, + #[error("missing database name")] + MissingDbName, + #[error("invalid database name")] + InvalidDbName, + #[error("missing username")] + MissingUsername, + #[error("missing password")] + MissingPassword, + #[error("missing hostname")] + MissingHostname, + #[error("invalid hostname: {0}")] + InvalidEndpoint(#[from] ComputeUserInfoParseError), + #[error("malformed endpoint")] + MalformedEndpoint, +} + fn get_conn_info( ctx: &mut RequestMonitoring, headers: &HeaderMap, - sni_hostname: Option, tls: &TlsConfig, -) -> Result { +) -> Result { let connection_string = headers .get("Neon-Connection-String") - .ok_or(anyhow::anyhow!("missing connection string"))? - .to_str()?; + .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))? + .to_str() + .map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?; let connection_url = Url::parse(connection_string)?; let protocol = connection_url.scheme(); if protocol != "postgres" && protocol != "postgresql" { - return Err(anyhow::anyhow!( - "connection string must start with postgres: or postgresql:" - )); + return Err(ConnInfoError::IncorrectScheme); } let mut url_path = connection_url .path_segments() - .ok_or(anyhow::anyhow!("missing database name"))?; + .ok_or(ConnInfoError::MissingDbName)?; - let dbname = url_path - .next() - .ok_or(anyhow::anyhow!("invalid database name"))?; + let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?; let username = RoleName::from(connection_url.username()); if username.is_empty() { - return Err(anyhow::anyhow!("missing username")); + return Err(ConnInfoError::MissingUsername); } ctx.set_user(username.clone()); let password = connection_url .password() - .ok_or(anyhow::anyhow!("no password"))?; - - // TLS certificate selector now based on SNI hostname, so if we are running here - // we are sure that SNI hostname is set to one of the configured domain names. - let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?; + .ok_or(ConnInfoError::MissingPassword)?; let hostname = connection_url .host_str() - .ok_or(anyhow::anyhow!("no host"))?; + .ok_or(ConnInfoError::MissingHostname)?; - let host_header = headers - .get("host") - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.split(':').next()); - - // sni_hostname has to be either the same as hostname or the one used in serverless driver. - if !check_matches(&sni_hostname, hostname)? { - return Err(anyhow::anyhow!("mismatched SNI hostname and hostname")); - } else if let Some(h) = host_header { - if h != sni_hostname { - return Err(anyhow::anyhow!("mismatched host header and hostname")); - } - } - - let endpoint = endpoint_sni(hostname, &tls.common_names)?.context("malformed endpoint")?; + let endpoint = + endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?; ctx.set_endpoint_id(endpoint.clone()); let pairs = connection_url.query_pairs(); @@ -173,36 +176,27 @@ fn get_conn_info( }) } -fn check_matches(sni_hostname: &str, hostname: &str) -> Result { - if sni_hostname == hostname { - return Ok(true); - } - let (sni_hostname_first, sni_hostname_rest) = sni_hostname - .split_once('.') - .ok_or_else(|| anyhow::anyhow!("Unexpected sni format."))?; - let (_, hostname_rest) = hostname - .split_once('.') - .ok_or_else(|| anyhow::anyhow!("Unexpected hostname format."))?; - Ok(sni_hostname_rest == hostname_rest && sni_hostname_first == SERVERLESS_DRIVER_SNI) -} - // TODO: return different http error codes pub async fn handle( config: &'static ProxyConfig, - ctx: &mut RequestMonitoring, + mut ctx: RequestMonitoring, request: Request, - sni_hostname: Option, backend: Arc, ) -> Result, ApiError> { let result = tokio::time::timeout( config.http_config.request_timeout, - handle_inner(config, ctx, request, sni_hostname, backend), + handle_inner(config, &mut ctx, request, backend), ) .await; let mut response = match result { Ok(r) => match r { - Ok(r) => r, + Ok(r) => { + ctx.set_success(); + r + } Err(e) => { + // TODO: ctx.set_error_kind(e.get_error_type()); + let mut message = format!("{:?}", e); let db_error = e .downcast_ref::() @@ -278,7 +272,9 @@ pub async fn handle( )? } }, - Err(_) => { + Err(e) => { + ctx.set_error_kind(e.get_error_kind()); + let message = format!( "HTTP-Connection timed out, execution time exeeded {} seconds", config.http_config.request_timeout.as_secs() @@ -290,6 +286,7 @@ pub async fn handle( )? } }; + response.headers_mut().insert( "Access-Control-Allow-Origin", hyper::http::HeaderValue::from_static("*"), @@ -302,7 +299,6 @@ async fn handle_inner( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, request: Request, - sni_hostname: Option, backend: Arc, ) -> anyhow::Result> { let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE @@ -318,12 +314,7 @@ async fn handle_inner( // let headers = request.headers(); // TLS config should be there. - let conn_info = get_conn_info( - ctx, - headers, - sni_hostname, - config.tls_config.as_ref().unwrap(), - )?; + let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?; info!( user = conn_info.user_info.user.as_str(), project = conn_info.user_info.endpoint.as_str(), @@ -487,8 +478,6 @@ async fn handle_inner( } }; - ctx.set_success(); - ctx.log(); let metrics = client.metrics(); // how could this possibly fail diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index f68b35010a..062dd440b2 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -2,7 +2,7 @@ use crate::{ cancellation::CancelMap, config::ProxyConfig, context::RequestMonitoring, - error::io_error, + error::{io_error, ReportableError}, proxy::{handle_client, ClientMode}, rate_limiter::EndpointRateLimiter, }; @@ -131,23 +131,41 @@ impl AsyncBufRead for WebSocketRw { pub async fn serve_websocket( config: &'static ProxyConfig, - ctx: &mut RequestMonitoring, + mut ctx: RequestMonitoring, websocket: HyperWebsocket, cancel_map: Arc, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { let websocket = websocket.await?; - handle_client( + let res = handle_client( config, - ctx, + &mut ctx, cancel_map, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, endpoint_rate_limiter, ) - .await?; - Ok(()) + .await; + + match res { + Err(e) => { + // todo: log and push to ctx the error kind + ctx.set_error_kind(e.get_error_kind()); + ctx.log(); + Err(e.into()) + } + Ok(None) => { + ctx.set_success(); + ctx.log(); + Ok(()) + } + Ok(Some(p)) => { + ctx.set_success(); + ctx.log(); + p.proxy_pass().await + } + } } #[cfg(test)] diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index f48b3fe39f..0d639d2c07 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -1,6 +1,5 @@ use crate::config::TlsServerEndPoint; -use crate::error::UserFacingError; -use anyhow::bail; +use crate::error::{ErrorKind, ReportableError, UserFacingError}; use bytes::BytesMut; use pq_proto::framed::{ConnectionError, Framed}; @@ -73,6 +72,30 @@ impl PqStream { } } +#[derive(Debug)] +pub struct ReportedError { + source: anyhow::Error, + error_kind: ErrorKind, +} + +impl std::fmt::Display for ReportedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.source.fmt(f) + } +} + +impl std::error::Error for ReportedError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source.source() + } +} + +impl ReportableError for ReportedError { + fn get_error_kind(&self) -> ErrorKind { + self.error_kind + } +} + impl PqStream { /// Write the message into an internal buffer, but don't flush the underlying stream. pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { @@ -98,24 +121,52 @@ impl PqStream { /// Write the error message using [`Self::write_message`], then re-throw it. /// Allowing string literals is safe under the assumption they might not contain any runtime info. /// This method exists due to `&str` not implementing `Into`. - pub async fn throw_error_str(&mut self, error: &'static str) -> anyhow::Result { - tracing::info!("forwarding error to user: {error}"); - self.write_message(&BeMessage::ErrorResponse(error, None)) - .await?; - bail!(error) + pub async fn throw_error_str( + &mut self, + msg: &'static str, + error_kind: ErrorKind, + ) -> Result { + tracing::info!( + kind = error_kind.to_metric_label(), + msg, + "forwarding error to user" + ); + + // already error case, ignore client IO error + let _: Result<_, std::io::Error> = self + .write_message(&BeMessage::ErrorResponse(msg, None)) + .await; + + Err(ReportedError { + source: anyhow::anyhow!(msg), + error_kind, + }) } /// Write the error message using [`Self::write_message`], then re-throw it. /// Trait [`UserFacingError`] acts as an allowlist for error types. - pub async fn throw_error(&mut self, error: E) -> anyhow::Result + pub async fn throw_error(&mut self, error: E) -> Result where E: UserFacingError + Into, { + let error_kind = error.get_error_kind(); let msg = error.to_string_client(); - tracing::info!("forwarding error to user: {msg}"); - self.write_message(&BeMessage::ErrorResponse(&msg, None)) - .await?; - bail!(error) + tracing::info!( + kind=error_kind.to_metric_label(), + error=%error, + msg, + "forwarding error to user" + ); + + // already error case, ignore client IO error + let _: Result<_, std::io::Error> = self + .write_message(&BeMessage::ErrorResponse(&msg, None)) + .await; + + Err(ReportedError { + source: anyhow::anyhow!(error), + error_kind, + }) } } From 1bb9abebf2cc380fa5ef0b876280afd2d120c257 Mon Sep 17 00:00:00 2001 From: Arseny Sher Date: Fri, 9 Feb 2024 16:41:43 +0300 Subject: [PATCH 34/81] Remove WAL segments from s3 in batches. Do list-delete operations in batches instead of doing full list first, to ensure deletion makes progress even if there are a lot of files to remove. To this end, add max_keys limit to remote storage list_files. --- libs/remote_storage/src/azure_blob.rs | 16 +++++++- libs/remote_storage/src/lib.rs | 38 +++++++++++++------ libs/remote_storage/src/local_fs.rs | 13 +++++-- libs/remote_storage/src/s3_bucket.rs | 21 +++++++++- libs/remote_storage/src/simulate_failures.rs | 7 +++- libs/remote_storage/tests/common/tests.rs | 15 ++++++-- libs/remote_storage/tests/test_real_s3.rs | 2 +- .../src/tenant/remote_timeline_client.rs | 2 +- .../tenant/remote_timeline_client/download.rs | 4 +- safekeeper/src/wal_backup.rs | 29 ++++++++++++-- 10 files changed, 119 insertions(+), 28 deletions(-) diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs index c6d5224706..df6d45dde1 100644 --- a/libs/remote_storage/src/azure_blob.rs +++ b/libs/remote_storage/src/azure_blob.rs @@ -191,6 +191,7 @@ impl RemoteStorage for AzureBlobStorage { &self, prefix: Option<&RemotePath>, mode: ListingMode, + max_keys: Option, ) -> anyhow::Result { // get the passed prefix or if it is not set use prefix_in_bucket value let list_prefix = prefix @@ -223,6 +224,8 @@ impl RemoteStorage for AzureBlobStorage { let mut response = builder.into_stream(); let mut res = Listing::default(); + // NonZeroU32 doesn't support subtraction apparently + let mut max_keys = max_keys.map(|mk| mk.get()); while let Some(l) = response.next().await { let entry = l.map_err(to_download_error)?; let prefix_iter = entry @@ -235,7 +238,18 @@ impl RemoteStorage for AzureBlobStorage { .blobs .blobs() .map(|k| self.name_to_relative_path(&k.name)); - res.keys.extend(blob_iter); + + for key in blob_iter { + res.keys.push(key); + if let Some(mut mk) = max_keys { + assert!(mk > 0); + mk -= 1; + if mk == 0 { + return Ok(res); // limit reached + } + max_keys = Some(mk); + } + } } Ok(res) } diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index b6648931ac..5a0b74e406 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -16,7 +16,12 @@ mod simulate_failures; mod support; use std::{ - collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc, time::SystemTime, + collections::HashMap, + fmt::Debug, + num::{NonZeroU32, NonZeroUsize}, + pin::Pin, + sync::Arc, + time::SystemTime, }; use anyhow::{bail, Context}; @@ -155,7 +160,7 @@ pub trait RemoteStorage: Send + Sync + 'static { prefix: Option<&RemotePath>, ) -> Result, DownloadError> { let result = self - .list(prefix, ListingMode::WithDelimiter) + .list(prefix, ListingMode::WithDelimiter, None) .await? .prefixes; Ok(result) @@ -171,11 +176,17 @@ pub trait RemoteStorage: Send + Sync + 'static { /// whereas, /// list_prefixes("foo/bar/") = ["cat", "dog"] /// See `test_real_s3.rs` for more details. + /// + /// max_keys limits max number of keys returned; None means unlimited. async fn list_files( &self, prefix: Option<&RemotePath>, + max_keys: Option, ) -> Result, DownloadError> { - let result = self.list(prefix, ListingMode::NoDelimiter).await?.keys; + let result = self + .list(prefix, ListingMode::NoDelimiter, max_keys) + .await? + .keys; Ok(result) } @@ -183,6 +194,7 @@ pub trait RemoteStorage: Send + Sync + 'static { &self, prefix: Option<&RemotePath>, _mode: ListingMode, + max_keys: Option, ) -> Result; /// Streams the local file contents into remote into the remote storage entry. @@ -341,27 +353,31 @@ impl GenericRemoteStorage> { &self, prefix: Option<&RemotePath>, mode: ListingMode, + max_keys: Option, ) -> anyhow::Result { match self { - Self::LocalFs(s) => s.list(prefix, mode).await, - Self::AwsS3(s) => s.list(prefix, mode).await, - Self::AzureBlob(s) => s.list(prefix, mode).await, - Self::Unreliable(s) => s.list(prefix, mode).await, + Self::LocalFs(s) => s.list(prefix, mode, max_keys).await, + Self::AwsS3(s) => s.list(prefix, mode, max_keys).await, + Self::AzureBlob(s) => s.list(prefix, mode, max_keys).await, + Self::Unreliable(s) => s.list(prefix, mode, max_keys).await, } } // A function for listing all the files in a "directory" // Example: // list_files("foo/bar") = ["foo/bar/a.txt", "foo/bar/b.txt"] + // + // max_keys limits max number of keys returned; None means unlimited. pub async fn list_files( &self, folder: Option<&RemotePath>, + max_keys: Option, ) -> Result, DownloadError> { match self { - Self::LocalFs(s) => s.list_files(folder).await, - Self::AwsS3(s) => s.list_files(folder).await, - Self::AzureBlob(s) => s.list_files(folder).await, - Self::Unreliable(s) => s.list_files(folder).await, + Self::LocalFs(s) => s.list_files(folder, max_keys).await, + Self::AwsS3(s) => s.list_files(folder, max_keys).await, + Self::AzureBlob(s) => s.list_files(folder, max_keys).await, + Self::Unreliable(s) => s.list_files(folder, max_keys).await, } } diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index 3ebea76181..f53ba9db07 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -4,7 +4,9 @@ //! This storage used in tests, but can also be used in cases when a certain persistent //! volume is mounted to the local FS. -use std::{borrow::Cow, future::Future, io::ErrorKind, pin::Pin, time::SystemTime}; +use std::{ + borrow::Cow, future::Future, io::ErrorKind, num::NonZeroU32, pin::Pin, time::SystemTime, +}; use anyhow::{bail, ensure, Context}; use bytes::Bytes; @@ -162,6 +164,7 @@ impl RemoteStorage for LocalFs { &self, prefix: Option<&RemotePath>, mode: ListingMode, + max_keys: Option, ) -> Result { let mut result = Listing::default(); @@ -178,6 +181,9 @@ impl RemoteStorage for LocalFs { !path.is_dir() }) .collect(); + if let Some(max_keys) = max_keys { + result.keys.truncate(max_keys.get() as usize); + } return Ok(result); } @@ -790,12 +796,12 @@ mod fs_tests { let child = upload_dummy_file(&storage, "grandparent/parent/child", None).await?; let uncle = upload_dummy_file(&storage, "grandparent/uncle", None).await?; - let listing = storage.list(None, ListingMode::NoDelimiter).await?; + let listing = storage.list(None, ListingMode::NoDelimiter, None).await?; assert!(listing.prefixes.is_empty()); assert_eq!(listing.keys, [uncle.clone(), child.clone()].to_vec()); // Delimiter: should only go one deep - let listing = storage.list(None, ListingMode::WithDelimiter).await?; + let listing = storage.list(None, ListingMode::WithDelimiter, None).await?; assert_eq!( listing.prefixes, @@ -808,6 +814,7 @@ mod fs_tests { .list( Some(&RemotePath::from_string("timelines/some_timeline/grandparent").unwrap()), ListingMode::WithDelimiter, + None, ) .await?; assert_eq!( diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs index 2b33a6ffd1..dee5750cac 100644 --- a/libs/remote_storage/src/s3_bucket.rs +++ b/libs/remote_storage/src/s3_bucket.rs @@ -7,6 +7,7 @@ use std::{ borrow::Cow, collections::HashMap, + num::NonZeroU32, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -408,8 +409,11 @@ impl RemoteStorage for S3Bucket { &self, prefix: Option<&RemotePath>, mode: ListingMode, + max_keys: Option, ) -> Result { let kind = RequestKind::List; + // s3 sdk wants i32 + let mut max_keys = max_keys.map(|mk| mk.get() as i32); let mut result = Listing::default(); // get the passed prefix or if it is not set use prefix_in_bucket value @@ -433,13 +437,20 @@ impl RemoteStorage for S3Bucket { let _guard = self.permit(kind).await; let started_at = start_measuring_requests(kind); + // min of two Options, returning Some if one is value and another is + // None (None is smaller than anything, so plain min doesn't work). + let request_max_keys = self + .max_keys_per_list_response + .into_iter() + .chain(max_keys.into_iter()) + .min(); let mut request = self .client .list_objects_v2() .bucket(self.bucket_name.clone()) .set_prefix(list_prefix.clone()) .set_continuation_token(continuation_token) - .set_max_keys(self.max_keys_per_list_response); + .set_max_keys(request_max_keys); if let ListingMode::WithDelimiter = mode { request = request.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string()); @@ -469,6 +480,14 @@ impl RemoteStorage for S3Bucket { let object_path = object.key().expect("response does not contain a key"); let remote_path = self.s3_object_to_relative_path(object_path); result.keys.push(remote_path); + if let Some(mut mk) = max_keys { + assert!(mk > 0); + mk -= 1; + if mk == 0 { + return Ok(result); // limit reached + } + max_keys = Some(mk); + } } result.prefixes.extend( diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index 14bdb5ed4d..3dfa16b64e 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -4,6 +4,7 @@ use bytes::Bytes; use futures::stream::Stream; use std::collections::HashMap; +use std::num::NonZeroU32; use std::sync::Mutex; use std::time::SystemTime; use std::{collections::hash_map::Entry, sync::Arc}; @@ -113,20 +114,22 @@ impl RemoteStorage for UnreliableWrapper { async fn list_files( &self, folder: Option<&RemotePath>, + max_keys: Option, ) -> Result, DownloadError> { self.attempt(RemoteOp::ListPrefixes(folder.cloned())) .map_err(DownloadError::Other)?; - self.inner.list_files(folder).await + self.inner.list_files(folder, max_keys).await } async fn list( &self, prefix: Option<&RemotePath>, mode: ListingMode, + max_keys: Option, ) -> Result { self.attempt(RemoteOp::ListPrefixes(prefix.cloned())) .map_err(DownloadError::Other)?; - self.inner.list(prefix, mode).await + self.inner.list(prefix, mode, max_keys).await } async fn upload( diff --git a/libs/remote_storage/tests/common/tests.rs b/libs/remote_storage/tests/common/tests.rs index abccc24c97..6d062f3898 100644 --- a/libs/remote_storage/tests/common/tests.rs +++ b/libs/remote_storage/tests/common/tests.rs @@ -1,8 +1,8 @@ use anyhow::Context; use camino::Utf8Path; use remote_storage::RemotePath; -use std::collections::HashSet; use std::sync::Arc; +use std::{collections::HashSet, num::NonZeroU32}; use test_context::test_context; use tracing::debug; @@ -103,7 +103,7 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a let base_prefix = RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?; let root_files = test_client - .list_files(None) + .list_files(None, None) .await .context("client list root files failure")? .into_iter() @@ -113,8 +113,17 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a ctx.remote_blobs.clone(), "remote storage list_files on root mismatches with the uploads." ); + + // Test that max_keys limit works. In total there are about 21 files (see + // upload_simple_remote_data call in test_real_s3.rs). + let limited_root_files = test_client + .list_files(None, Some(NonZeroU32::new(2).unwrap())) + .await + .context("client list root files failure")?; + assert_eq!(limited_root_files.len(), 2); + let nested_remote_files = test_client - .list_files(Some(&base_prefix)) + .list_files(Some(&base_prefix), None) .await .context("client list nested files failure")? .into_iter() diff --git a/libs/remote_storage/tests/test_real_s3.rs b/libs/remote_storage/tests/test_real_s3.rs index fc52dabc36..3dc8347c83 100644 --- a/libs/remote_storage/tests/test_real_s3.rs +++ b/libs/remote_storage/tests/test_real_s3.rs @@ -70,7 +70,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: } async fn list_files(client: &Arc) -> anyhow::Result> { - Ok(retry(|| client.list_files(None)) + Ok(retry(|| client.list_files(None, None)) .await .context("list root files failure")? .into_iter() diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index 0c7dd68c3f..e17dea01a8 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -1151,7 +1151,7 @@ impl RemoteTimelineClient { let remaining = download_retry( || async { self.storage_impl - .list_files(Some(&timeline_storage_path)) + .list_files(Some(&timeline_storage_path), None) .await }, "list remaining files", diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index 33287fc8f4..e755cd08f3 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -220,7 +220,7 @@ pub async fn list_remote_timelines( || { download_cancellable( &cancel, - storage.list(Some(&remote_path), ListingMode::WithDelimiter), + storage.list(Some(&remote_path), ListingMode::WithDelimiter, None), ) }, &format!("list timelines for {tenant_shard_id}"), @@ -373,7 +373,7 @@ pub(super) async fn download_index_part( let index_prefix = remote_index_path(tenant_shard_id, timeline_id, Generation::none()); let indices = download_retry( - || async { storage.list_files(Some(&index_prefix)).await }, + || async { storage.list_files(Some(&index_prefix), None).await }, "list index_part files", cancel, ) diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index df99244770..dbdc742d26 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -10,6 +10,7 @@ use utils::id::NodeId; use std::cmp::min; use std::collections::{HashMap, HashSet}; +use std::num::NonZeroU32; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -546,6 +547,10 @@ pub async fn delete_timeline(ttid: &TenantTimelineId) -> Result<()> { let ttid_path = Utf8Path::new(&ttid.tenant_id.to_string()).join(ttid.timeline_id.to_string()); let remote_path = RemotePath::new(&ttid_path)?; + // see DEFAULT_MAX_KEYS_PER_LIST_RESPONSE + // const Option unwrap is not stable, otherwise it would be const. + let batch_size: NonZeroU32 = NonZeroU32::new(1000).unwrap(); + // A backoff::retry is used here for two reasons: // - To provide a backoff rather than busy-polling the API on errors // - To absorb transient 429/503 conditions without hitting our error @@ -557,8 +562,26 @@ pub async fn delete_timeline(ttid: &TenantTimelineId) -> Result<()> { let token = CancellationToken::new(); // not really used backoff::retry( || async { - let files = storage.list_files(Some(&remote_path)).await?; - storage.delete_objects(&files).await + // Do list-delete in batch_size batches to make progress even if there a lot of files. + // Alternatively we could make list_files return iterator, but it is more complicated and + // I'm not sure deleting while iterating is expected in s3. + loop { + let files = storage + .list_files(Some(&remote_path), Some(batch_size)) + .await?; + if files.is_empty() { + return Ok(()); // done + } + // (at least) s3 results are sorted, so can log min/max: + // "List results are always returned in UTF-8 binary order." + info!( + "deleting batch of {} WAL segments [{}-{}]", + files.len(), + files.first().unwrap().object_name().unwrap_or(""), + files.last().unwrap().object_name().unwrap_or("") + ); + storage.delete_objects(&files).await?; + } }, |_| false, 3, @@ -594,7 +617,7 @@ pub async fn copy_s3_segments( let remote_path = RemotePath::new(&relative_dst_path)?; - let files = storage.list_files(Some(&remote_path)).await?; + let files = storage.list_files(Some(&remote_path), None).await?; let uploaded_segments = &files .iter() .filter_map(|file| file.object_name().map(ToOwned::to_owned)) From ca818c8bd76d815f0d41eb61fdb8fb9b826ffe54 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 9 Feb 2024 20:09:37 +0100 Subject: [PATCH 35/81] fix(test_ondemand_download_timetravel): occasionally fails with slightly higher physical size (#6687) --- test_runner/regress/test_ondemand_download.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test_runner/regress/test_ondemand_download.py b/test_runner/regress/test_ondemand_download.py index af2d7aae88..3a197875dd 100644 --- a/test_runner/regress/test_ondemand_download.py +++ b/test_runner/regress/test_ondemand_download.py @@ -197,6 +197,14 @@ def test_ondemand_download_timetravel(neon_env_builder: NeonEnvBuilder): ##### Stop the first pageserver instance, erase all its data env.endpoints.stop_all() + # Stop safekeepers and take another checkpoint. The endpoints might + # have written a few more bytes during shutdown. + for sk in env.safekeepers: + sk.stop() + + client.timeline_checkpoint(tenant_id, timeline_id) + current_lsn = Lsn(client.timeline_detail(tenant_id, timeline_id)["last_record_lsn"]) + # wait until pageserver has successfully uploaded all the data to remote storage wait_for_upload(client, tenant_id, timeline_id, current_lsn) From cbd3a32d4d4275338c851dd158e0cb950d64ee91 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 9 Feb 2024 19:22:23 +0000 Subject: [PATCH 36/81] proxy: decode username and password (#6700) ## Problem usernames and passwords can be URL 'percent' encoded in the connection string URL provided by serverless driver. ## Summary of changes Decode the parameters when getting conn info --- Cargo.lock | 2 ++ Cargo.toml | 1 + proxy/Cargo.toml | 4 +++- proxy/src/serverless/backend.rs | 2 +- proxy/src/serverless/conn_pool.rs | 7 ++++--- proxy/src/serverless/sql_over_http.rs | 10 ++++++++-- test_runner/fixtures/neon_fixtures.py | 6 +++--- test_runner/regress/test_proxy.py | 12 ++++++++++++ 8 files changed, 34 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a2939e6c75..83afdaf66f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4125,6 +4125,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "smallvec", "smol_str", "socket2 0.5.5", "sync_wrapper", @@ -4143,6 +4144,7 @@ dependencies = [ "tracing-subscriber", "tracing-utils", "url", + "urlencoding", "utils", "uuid", "walkdir", diff --git a/Cargo.toml b/Cargo.toml index 6a2c3fa563..ebc3dfa7b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -171,6 +171,7 @@ tracing-opentelemetry = "0.20.0" tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] } twox-hash = { version = "1.6.3", default-features = false } url = "2.2" +urlencoding = "2.1" uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] } walkdir = "2.3.2" webpki-roots = "0.25" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 83cab381b3..0777d361d2 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -60,6 +60,8 @@ scopeguard.workspace = true serde.workspace = true serde_json.workspace = true sha2.workspace = true +smol_str.workspace = true +smallvec.workspace = true socket2.workspace = true sync_wrapper.workspace = true task-local-extensions.workspace = true @@ -76,6 +78,7 @@ tracing-subscriber.workspace = true tracing-utils.workspace = true tracing.workspace = true url.workspace = true +urlencoding.workspace = true utils.workspace = true uuid.workspace = true webpki-roots.workspace = true @@ -84,7 +87,6 @@ native-tls.workspace = true postgres-native-tls.workspace = true postgres-protocol.workspace = true redis.workspace = true -smol_str.workspace = true workspace_hack.workspace = true diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 03257e9161..8285da68d7 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -48,7 +48,7 @@ impl PoolingBackend { } }; let auth_outcome = - crate::auth::validate_password_and_exchange(conn_info.password.as_bytes(), secret)?; + crate::auth::validate_password_and_exchange(&conn_info.password, secret)?; match auth_outcome { crate::sasl::Outcome::Success(key) => Ok(key), crate::sasl::Outcome::Failure(reason) => { diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index f92793096b..f4e5b145c5 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -3,6 +3,7 @@ use futures::{future::poll_fn, Future}; use metrics::IntCounterPairGuard; use parking_lot::RwLock; use rand::Rng; +use smallvec::SmallVec; use smol_str::SmolStr; use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; use std::{ @@ -36,7 +37,7 @@ pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http"); pub struct ConnInfo { pub user_info: ComputeUserInfo, pub dbname: DbName, - pub password: SmolStr, + pub password: SmallVec<[u8; 16]>, } impl ConnInfo { @@ -731,7 +732,7 @@ mod tests { options: Default::default(), }, dbname: "dbname".into(), - password: "password".into(), + password: "password".as_bytes().into(), }; let ep_pool = Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key())); @@ -788,7 +789,7 @@ mod tests { options: Default::default(), }, dbname: "dbname".into(), - password: "password".into(), + password: "password".as_bytes().into(), }; let ep_pool = Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key())); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 401022347e..54424360c4 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -100,6 +100,8 @@ pub enum ConnInfoError { InvalidDbName, #[error("missing username")] MissingUsername, + #[error("invalid username: {0}")] + InvalidUsername(#[from] std::string::FromUtf8Error), #[error("missing password")] MissingPassword, #[error("missing hostname")] @@ -134,7 +136,7 @@ fn get_conn_info( let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?; - let username = RoleName::from(connection_url.username()); + let username = RoleName::from(urlencoding::decode(connection_url.username())?); if username.is_empty() { return Err(ConnInfoError::MissingUsername); } @@ -143,6 +145,7 @@ fn get_conn_info( let password = connection_url .password() .ok_or(ConnInfoError::MissingPassword)?; + let password = urlencoding::decode_binary(password.as_bytes()); let hostname = connection_url .host_str() @@ -172,7 +175,10 @@ fn get_conn_info( Ok(ConnInfo { user_info, dbname: dbname.into(), - password: password.into(), + password: match password { + std::borrow::Cow::Borrowed(b) => b.into(), + std::borrow::Cow::Owned(b) => b.into(), + }, }) } diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 9996853525..231eebff52 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -23,7 +23,7 @@ from itertools import chain, product from pathlib import Path from types import TracebackType from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast -from urllib.parse import urlparse +from urllib.parse import quote, urlparse import asyncpg import backoff @@ -2822,8 +2822,8 @@ class NeonProxy(PgProtocol): def http_query(self, query, args, **kwargs): # TODO maybe use default values if not provided - user = kwargs["user"] - password = kwargs["password"] + user = quote(kwargs["user"]) + password = quote(kwargs["password"]) expected_code = kwargs.get("expected_code") connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres" diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 49a0450f0c..884643cef0 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -462,6 +462,18 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): assert "password authentication failed for user" in res["message"] +def test_sql_over_http_urlencoding(static_proxy: NeonProxy): + static_proxy.safe_psql("create user \"http+auth$$\" with password '%+$^&*@!' superuser") + + static_proxy.http_query( + "select 1", + [], + user="http+auth$$", + password="%+$^&*@!", + expected_code=200, + ) + + # Beginning a transaction should not impact the next query, # which might come from a completely different client. def test_http_pool_begin(static_proxy: NeonProxy): From 1a4dd58b70ad1bf82c4daae520f4550612f91120 Mon Sep 17 00:00:00 2001 From: Sasha Krassovsky Date: Fri, 9 Feb 2024 11:22:53 -0900 Subject: [PATCH 37/81] Grant pg_monitor to neon_superuser (#6691) ## Problem The people want pg_monitor https://github.com/neondatabase/neon/issues/6682 ## Summary of changes Gives the people pg_monitor --- compute_tools/src/spec.rs | 1 + test_runner/regress/test_migrations.py | 4 ++-- test_runner/regress/test_neon_superuser.py | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index 3df5f10e23..9c731f257c 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -776,6 +776,7 @@ BEGIN END IF; END $$;"#, + "GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION", ]; let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration"; diff --git a/test_runner/regress/test_migrations.py b/test_runner/regress/test_migrations.py index 8954810451..7cc3024ec6 100644 --- a/test_runner/regress/test_migrations.py +++ b/test_runner/regress/test_migrations.py @@ -15,7 +15,7 @@ def test_migrations(neon_simple_env: NeonEnv): endpoint.wait_for_migrations() - num_migrations = 3 + num_migrations = 4 with endpoint.cursor() as cur: cur.execute("SELECT id FROM neon_migration.migration_id") @@ -24,7 +24,7 @@ def test_migrations(neon_simple_env: NeonEnv): with open(log_path, "r") as log_file: logs = log_file.read() - assert "INFO handle_migrations: Ran 3 migrations" in logs + assert f"INFO handle_migrations: Ran {num_migrations} migrations" in logs endpoint.stop() endpoint.start() diff --git a/test_runner/regress/test_neon_superuser.py b/test_runner/regress/test_neon_superuser.py index 34f1e64b34..ca8ada4ddb 100644 --- a/test_runner/regress/test_neon_superuser.py +++ b/test_runner/regress/test_neon_superuser.py @@ -76,3 +76,21 @@ def test_neon_superuser(neon_simple_env: NeonEnv, pg_version: PgVersion): assert [r[0] for r in res] == [10, 20, 30, 40] wait_until(10, 0.5, check_that_changes_propagated) + + # Test that pg_monitor is working for neon_superuser role + cur.execute("SELECT query from pg_stat_activity LIMIT 1") + assert cur.fetchall()[0][0] != "" + # Test that pg_monitor is not working for non neon_superuser role without grant + cur.execute("CREATE ROLE not_a_superuser LOGIN PASSWORD 'Password42!'") + cur.execute("GRANT not_a_superuser TO neon_superuser WITH ADMIN OPTION") + cur.execute("SET ROLE not_a_superuser") + cur.execute("SELECT query from pg_stat_activity LIMIT 1") + assert cur.fetchall()[0][0] == "" + cur.execute("RESET ROLE") + # Test that pg_monitor is working for non neon_superuser role with grant + cur.execute("GRANT pg_monitor TO not_a_superuser") + cur.execute("SET ROLE not_a_superuser") + cur.execute("SELECT query from pg_stat_activity LIMIT 1") + assert cur.fetchall()[0][0] != "" + cur.execute("RESET ROLE") + cur.execute("DROP ROLE not_a_superuser") From 5779c7908abaadb0c96a5087423e2082101924b9 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 9 Feb 2024 23:22:40 +0100 Subject: [PATCH 38/81] revert two recent `heavier_once_cell` changes (#6704) This PR reverts - https://github.com/neondatabase/neon/pull/6589 - https://github.com/neondatabase/neon/pull/6652 because there's a performance regression that's particularly visible at high layer counts. Most likely it's because the switch to RwLock inflates the ``` inner: heavier_once_cell::OnceCell, ``` size from 48 to 88 bytes, which, by itself is almost a doubling of the cache footprint, and probably the fact that it's now larger than a cache line also doesn't help. See this chat on the Neon discord for more context: https://discord.com/channels/1176467419317940276/1204714372295958548/1205541184634617906 I'm reverting 6652 as well because it might also have perf implications, and we're getting close to the next release. We should re-do its changes after the next release, though. cc @koivunej cc @ivaxer --- libs/utils/src/sync/heavier_once_cell.rs | 322 ++++--------------- pageserver/src/tenant/storage_layer/layer.rs | 24 +- pageserver/src/tenant/timeline.rs | 2 +- 3 files changed, 81 insertions(+), 267 deletions(-) diff --git a/libs/utils/src/sync/heavier_once_cell.rs b/libs/utils/src/sync/heavier_once_cell.rs index 81625b907e..0ccaf4e716 100644 --- a/libs/utils/src/sync/heavier_once_cell.rs +++ b/libs/utils/src/sync/heavier_once_cell.rs @@ -1,6 +1,6 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, - Arc, + Arc, Mutex, MutexGuard, }; use tokio::sync::Semaphore; @@ -12,7 +12,7 @@ use tokio::sync::Semaphore; /// /// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit pub struct OnceCell { - inner: tokio::sync::RwLock>, + inner: Mutex>, initializers: AtomicUsize, } @@ -50,7 +50,7 @@ impl OnceCell { let sem = Semaphore::new(1); sem.close(); Self { - inner: tokio::sync::RwLock::new(Inner { + inner: Mutex::new(Inner { init_semaphore: Arc::new(sem), value: Some(value), }), @@ -61,113 +61,56 @@ impl OnceCell { /// Returns a guard to an existing initialized value, or uniquely initializes the value before /// returning the guard. /// - /// Initializing might wait on any existing [`GuardMut::take_and_deinit`] deinitialization. + /// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization. /// /// Initialization is panic-safe and cancellation-safe. - pub async fn get_mut_or_init(&self, factory: F) -> Result, E> + pub async fn get_or_init(&self, factory: F) -> Result, E> where F: FnOnce(InitPermit) -> Fut, Fut: std::future::Future>, { - loop { - let sem = { - let guard = self.inner.write().await; - if guard.value.is_some() { - return Ok(GuardMut(guard)); - } - guard.init_semaphore.clone() - }; - - { - let permit = { - // increment the count for the duration of queued - let _guard = CountWaitingInitializers::start(self); - sem.acquire().await - }; - - let Ok(permit) = permit else { - let guard = self.inner.write().await; - if !Arc::ptr_eq(&sem, &guard.init_semaphore) { - // there was a take_and_deinit in between - continue; - } - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(GuardMut(guard)); - }; - - permit.forget(); + let sem = { + let guard = self.inner.lock().unwrap(); + if guard.value.is_some() { + return Ok(Guard(guard)); } + guard.init_semaphore.clone() + }; - let permit = InitPermit(sem); - let (value, _permit) = factory(permit).await?; + let permit = { + // increment the count for the duration of queued + let _guard = CountWaitingInitializers::start(self); + sem.acquire_owned().await + }; - let guard = self.inner.write().await; + match permit { + Ok(permit) => { + let permit = InitPermit(permit); + let (value, _permit) = factory(permit).await?; - return Ok(Self::set0(value, guard)); + let guard = self.inner.lock().unwrap(); + + Ok(Self::set0(value, guard)) + } + Err(_closed) => { + let guard = self.inner.lock().unwrap(); + assert!( + guard.value.is_some(), + "semaphore got closed, must be initialized" + ); + return Ok(Guard(guard)); + } } } - /// Returns a guard to an existing initialized value, or uniquely initializes the value before - /// returning the guard. - /// - /// Initialization is panic-safe and cancellation-safe. - pub async fn get_or_init(&self, factory: F) -> Result, E> - where - F: FnOnce(InitPermit) -> Fut, - Fut: std::future::Future>, - { - loop { - let sem = { - let guard = self.inner.read().await; - if guard.value.is_some() { - return Ok(GuardRef(guard)); - } - guard.init_semaphore.clone() - }; - - { - let permit = { - // increment the count for the duration of queued - let _guard = CountWaitingInitializers::start(self); - sem.acquire().await - }; - - let Ok(permit) = permit else { - let guard = self.inner.read().await; - if !Arc::ptr_eq(&sem, &guard.init_semaphore) { - // there was a take_and_deinit in between - continue; - } - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(GuardRef(guard)); - }; - - permit.forget(); - } - - let permit = InitPermit(sem); - let (value, _permit) = factory(permit).await?; - - let guard = self.inner.write().await; - - return Ok(Self::set0(value, guard).downgrade()); - } - } - - /// Assuming a permit is held after previous call to [`GuardMut::take_and_deinit`], it can be used + /// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used /// to complete initializing the inner value. /// /// # Panics /// /// If the inner has already been initialized. - pub async fn set(&self, value: T, _permit: InitPermit) -> GuardMut<'_, T> { - let guard = self.inner.write().await; + pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> { + let guard = self.inner.lock().unwrap(); // cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot // give more permits right now. @@ -179,31 +122,21 @@ impl OnceCell { Self::set0(value, guard) } - fn set0(value: T, mut guard: tokio::sync::RwLockWriteGuard<'_, Inner>) -> GuardMut<'_, T> { + fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner>) -> Guard<'_, T> { if guard.value.is_some() { drop(guard); unreachable!("we won permit, must not be initialized"); } guard.value = Some(value); guard.init_semaphore.close(); - GuardMut(guard) + Guard(guard) } /// Returns a guard to an existing initialized value, if any. - pub async fn get_mut(&self) -> Option> { - let guard = self.inner.write().await; + pub fn get(&self) -> Option> { + let guard = self.inner.lock().unwrap(); if guard.value.is_some() { - Some(GuardMut(guard)) - } else { - None - } - } - - /// Returns a guard to an existing initialized value, if any. - pub async fn get(&self) -> Option> { - let guard = self.inner.read().await; - if guard.value.is_some() { - Some(GuardRef(guard)) + Some(Guard(guard)) } else { None } @@ -235,9 +168,9 @@ impl<'a, T> Drop for CountWaitingInitializers<'a, T> { /// Uninteresting guard object to allow short-lived access to inspect or clone the held, /// initialized value. #[derive(Debug)] -pub struct GuardMut<'a, T>(tokio::sync::RwLockWriteGuard<'a, Inner>); +pub struct Guard<'a, T>(MutexGuard<'a, Inner>); -impl std::ops::Deref for GuardMut<'_, T> { +impl std::ops::Deref for Guard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { @@ -248,7 +181,7 @@ impl std::ops::Deref for GuardMut<'_, T> { } } -impl std::ops::DerefMut for GuardMut<'_, T> { +impl std::ops::DerefMut for Guard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.0 .value @@ -257,59 +190,34 @@ impl std::ops::DerefMut for GuardMut<'_, T> { } } -impl<'a, T> GuardMut<'a, T> { +impl<'a, T> Guard<'a, T> { /// Take the current value, and a new permit for it's deinitialization. /// /// The permit will be on a semaphore part of the new internal value, and any following /// [`OnceCell::get_or_init`] will wait on it to complete. pub fn take_and_deinit(&mut self) -> (T, InitPermit) { let mut swapped = Inner::default(); - let sem = swapped.init_semaphore.clone(); - sem.try_acquire().expect("we just created this").forget(); + let permit = swapped + .init_semaphore + .clone() + .try_acquire_owned() + .expect("we just created this"); std::mem::swap(&mut *self.0, &mut swapped); swapped .value - .map(|v| (v, InitPermit(sem))) - .expect("guard is not created unless value has been initialized") - } - - pub fn downgrade(self) -> GuardRef<'a, T> { - GuardRef(self.0.downgrade()) - } -} - -#[derive(Debug)] -pub struct GuardRef<'a, T>(tokio::sync::RwLockReadGuard<'a, Inner>); - -impl std::ops::Deref for GuardRef<'_, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - self.0 - .value - .as_ref() + .map(|v| (v, InitPermit(permit))) .expect("guard is not created unless value has been initialized") } } /// Type held by OnceCell (de)initializing task. -pub struct InitPermit(Arc); - -impl Drop for InitPermit { - fn drop(&mut self) { - debug_assert_eq!(self.0.available_permits(), 0); - self.0.add_permits(1); - } -} +pub struct InitPermit(tokio::sync::OwnedSemaphorePermit); #[cfg(test)] mod tests { - use futures::Future; - use super::*; use std::{ convert::Infallible, - pin::{pin, Pin}, sync::atomic::{AtomicUsize, Ordering}, time::Duration, }; @@ -340,7 +248,7 @@ mod tests { barrier.wait().await; let won = { let g = cell - .get_mut_or_init(|permit| { + .get_or_init(|permit| { counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed); async { counters.future_polled.fetch_add(1, Ordering::Relaxed); @@ -387,11 +295,7 @@ mod tests { let cell = cell.clone(); let deinitialization_started = deinitialization_started.clone(); async move { - let (answer, _permit) = cell - .get_mut() - .await - .expect("initialized to value") - .take_and_deinit(); + let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit(); assert_eq!(answer, initial); deinitialization_started.wait().await; @@ -402,7 +306,7 @@ mod tests { deinitialization_started.wait().await; let started_at = tokio::time::Instant::now(); - cell.get_mut_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) }) + cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) }) .await .unwrap(); @@ -414,21 +318,21 @@ mod tests { jh.await.unwrap(); - assert_eq!(*cell.get_mut().await.unwrap(), reinit); + assert_eq!(*cell.get().unwrap(), reinit); } - #[tokio::test] - async fn reinit_with_deinit_permit() { + #[test] + fn reinit_with_deinit_permit() { let cell = Arc::new(OnceCell::new(42)); - let (mol, permit) = cell.get_mut().await.unwrap().take_and_deinit(); - cell.set(5, permit).await; - assert_eq!(*cell.get_mut().await.unwrap(), 5); + let (mol, permit) = cell.get().unwrap().take_and_deinit(); + cell.set(5, permit); + assert_eq!(*cell.get().unwrap(), 5); - let (five, permit) = cell.get_mut().await.unwrap().take_and_deinit(); + let (five, permit) = cell.get().unwrap().take_and_deinit(); assert_eq!(5, five); - cell.set(mol, permit).await; - assert_eq!(*cell.get_mut().await.unwrap(), 42); + cell.set(mol, permit); + assert_eq!(*cell.get().unwrap(), 42); } #[tokio::test] @@ -436,13 +340,13 @@ mod tests { let cell = OnceCell::default(); for _ in 0..10 { - cell.get_mut_or_init(|_permit| async { Err("whatever error") }) + cell.get_or_init(|_permit| async { Err("whatever error") }) .await .unwrap_err(); } let g = cell - .get_mut_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) }) + .get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) }) .await .unwrap(); assert_eq!(*g, "finally success"); @@ -454,7 +358,7 @@ mod tests { let barrier = tokio::sync::Barrier::new(2); - let initializer = cell.get_mut_or_init(|permit| async { + let initializer = cell.get_or_init(|permit| async { barrier.wait().await; futures::future::pending::<()>().await; @@ -468,102 +372,12 @@ mod tests { // now initializer is dropped - assert!(cell.get_mut().await.is_none()); + assert!(cell.get().is_none()); let g = cell - .get_mut_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) }) + .get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) }) .await .unwrap(); assert_eq!(*g, "now initialized"); } - - #[tokio::test(start_paused = true)] - async fn reproduce_init_take_deinit_race() { - init_take_deinit_scenario(|cell, factory| { - Box::pin(async { - cell.get_or_init(factory).await.unwrap(); - }) - }) - .await; - } - - #[tokio::test(start_paused = true)] - async fn reproduce_init_take_deinit_race_mut() { - init_take_deinit_scenario(|cell, factory| { - Box::pin(async { - cell.get_mut_or_init(factory).await.unwrap(); - }) - }) - .await; - } - - type BoxedInitFuture = Pin>>>; - type BoxedInitFunction = Box BoxedInitFuture>; - - /// Reproduce an assertion failure with both initialization methods. - /// - /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`. - /// Alternative would be a macro_rules! but that is the last resort. - async fn init_take_deinit_scenario(init_way: F) - where - F: for<'a> Fn( - &'a OnceCell<&'static str>, - BoxedInitFunction<&'static str, Infallible>, - ) -> Pin + 'a>>, - { - let cell = OnceCell::default(); - - // acquire the init_semaphore only permit to drive initializing tasks in order to waiting - // on the same semaphore. - let permit = cell - .inner - .read() - .await - .init_semaphore - .clone() - .try_acquire_owned() - .unwrap(); - - let mut t1 = pin!(init_way( - &cell, - Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })), - )); - - let mut t2 = pin!(init_way( - &cell, - Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })), - )); - - // drive t2 first to the init_semaphore - tokio::select! { - _ = &mut t2 => unreachable!("it cannot get permit"), - _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} - } - - // followed by t1 in the init_semaphore - tokio::select! { - _ = &mut t1 => unreachable!("it cannot get permit"), - _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} - } - - // now let t2 proceed and initialize - drop(permit); - t2.await; - - let (s, permit) = { cell.get_mut().await.unwrap().take_and_deinit() }; - assert_eq!("t2", s); - - // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from - // the new one. - tokio::select! { - _ = &mut t1 => unreachable!("it cannot get permit"), - _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} - } - - // only now we get to initialize it - drop(permit); - t1.await; - - assert_eq!("t1", *cell.get().await.unwrap()); - } } diff --git a/pageserver/src/tenant/storage_layer/layer.rs b/pageserver/src/tenant/storage_layer/layer.rs index 52c0f8abdc..dd9de99477 100644 --- a/pageserver/src/tenant/storage_layer/layer.rs +++ b/pageserver/src/tenant/storage_layer/layer.rs @@ -300,8 +300,8 @@ impl Layer { }) } - pub(crate) async fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo { - self.0.info(reset).await + pub(crate) fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo { + self.0.info(reset) } pub(crate) fn access_stats(&self) -> &LayerAccessStats { @@ -612,10 +612,10 @@ impl LayerInner { let mut rx = self.status.subscribe(); let strong = { - match self.inner.get_mut().await { + match self.inner.get() { Some(mut either) => { self.wanted_evicted.store(true, Ordering::Relaxed); - ResidentOrWantedEvicted::downgrade(&mut either) + either.downgrade() } None => return Err(EvictionError::NotFound), } @@ -641,7 +641,7 @@ impl LayerInner { // use however late (compared to the initial expressing of wanted) as the // "outcome" now LAYER_IMPL_METRICS.inc_broadcast_lagged(); - match self.inner.get_mut().await { + match self.inner.get() { Some(_) => Err(EvictionError::Downloaded), None => Ok(()), } @@ -759,7 +759,7 @@ impl LayerInner { // use the already held initialization permit because it is impossible to hit the // below paths anymore essentially limiting the max loop iterations to 2. let (value, init_permit) = download(init_permit).await?; - let mut guard = self.inner.set(value, init_permit).await; + let mut guard = self.inner.set(value, init_permit); let (strong, _upgraded) = guard .get_and_upgrade() .expect("init creates strong reference, we held the init permit"); @@ -767,7 +767,7 @@ impl LayerInner { } let (weak, permit) = { - let mut locked = self.inner.get_mut_or_init(download).await?; + let mut locked = self.inner.get_or_init(download).await?; if let Some((strong, upgraded)) = locked.get_and_upgrade() { if upgraded { @@ -989,12 +989,12 @@ impl LayerInner { } } - async fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo { + fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo { let layer_file_name = self.desc.filename().file_name(); // this is not accurate: we could have the file locally but there was a cancellation // and now we are not in sync, or we are currently downloading it. - let remote = self.inner.get_mut().await.is_none(); + let remote = self.inner.get().is_none(); let access_stats = self.access_stats.as_api_model(reset); @@ -1053,7 +1053,7 @@ impl LayerInner { LAYER_IMPL_METRICS.inc_eviction_cancelled(EvictionCancelled::LayerGone); return; }; - match tokio::runtime::Handle::current().block_on(this.evict_blocking(version)) { + match this.evict_blocking(version) { Ok(()) => LAYER_IMPL_METRICS.inc_completed_evictions(), Err(reason) => LAYER_IMPL_METRICS.inc_eviction_cancelled(reason), } @@ -1061,7 +1061,7 @@ impl LayerInner { } } - async fn evict_blocking(&self, only_version: usize) -> Result<(), EvictionCancelled> { + fn evict_blocking(&self, only_version: usize) -> Result<(), EvictionCancelled> { // deleted or detached timeline, don't do anything. let Some(timeline) = self.timeline.upgrade() else { return Err(EvictionCancelled::TimelineGone); @@ -1070,7 +1070,7 @@ impl LayerInner { // to avoid starting a new download while we evict, keep holding on to the // permit. let _permit = { - let maybe_downloaded = self.inner.get_mut().await; + let maybe_downloaded = self.inner.get(); let (_weak, permit) = match maybe_downloaded { Some(mut guard) => { diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 735b8003b4..f96679ca69 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -1268,7 +1268,7 @@ impl Timeline { let mut historic_layers = Vec::new(); for historic_layer in layer_map.iter_historic_layers() { let historic_layer = guard.get_from_desc(&historic_layer); - historic_layers.push(historic_layer.info(reset).await); + historic_layers.push(historic_layer.info(reset)); } LayerMapInfo { From 0fd3cd27cb7ac66df5938bf219e9f12ce7b78c8a Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Fri, 9 Feb 2024 17:37:30 +0200 Subject: [PATCH 39/81] Tighten up the check for garbage after end-of-tar. Turn the warning into an error, if there is garbage after the end of imported tar file. However, it's normal for 'tar' to append extra empty blocks to the end, so tolerate those without warnings or errors. --- pageserver/src/page_service.rs | 17 ++++++++++++----- test_runner/regress/test_import.py | 10 +++------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 6fc38a76d4..7b660b5eca 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -91,8 +91,8 @@ const ACTIVE_TENANT_TIMEOUT: Duration = Duration::from_millis(30000); /// `tokio_tar` already read the first such block. Read the second all-zeros block, /// and check that there is no more data after the EOF marker. /// -/// XXX: Currently, any trailing data after the EOF marker prints a warning. -/// Perhaps it should be a hard error? +/// 'tar' command can also write extra blocks of zeros, up to a record +/// size, controlled by the --record-size argument. Ignore them too. async fn read_tar_eof(mut reader: (impl AsyncRead + Unpin)) -> anyhow::Result<()> { use tokio::io::AsyncReadExt; let mut buf = [0u8; 512]; @@ -113,17 +113,24 @@ async fn read_tar_eof(mut reader: (impl AsyncRead + Unpin)) -> anyhow::Result<() anyhow::bail!("invalid tar EOF marker"); } - // Drain any data after the EOF marker + // Drain any extra zero-blocks after the EOF marker let mut trailing_bytes = 0; + let mut seen_nonzero_bytes = false; loop { let nbytes = reader.read(&mut buf).await?; trailing_bytes += nbytes; + if !buf.iter().all(|&x| x == 0) { + seen_nonzero_bytes = true; + } if nbytes == 0 { break; } } - if trailing_bytes > 0 { - warn!("ignored {trailing_bytes} unexpected bytes after the tar archive"); + if seen_nonzero_bytes { + anyhow::bail!("unexpected non-zero bytes after the tar archive"); + } + if trailing_bytes % 512 != 0 { + anyhow::bail!("unexpected number of zeros ({trailing_bytes}), not divisible by tar block size (512 bytes), after the tar archive"); } Ok(()) } diff --git a/test_runner/regress/test_import.py b/test_runner/regress/test_import.py index 3519cbbaab..7942f5cc9b 100644 --- a/test_runner/regress/test_import.py +++ b/test_runner/regress/test_import.py @@ -95,7 +95,6 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build ".*InternalServerError.*Tenant .* not found.*", ".*InternalServerError.*Timeline .* not found.*", ".*InternalServerError.*Cannot delete timeline which has child timelines.*", - ".*ignored .* unexpected bytes after the tar archive.*", ] ) @@ -142,12 +141,9 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build with pytest.raises(RuntimeError): import_tar(corrupt_base_tar, wal_tar) - # A tar with trailing garbage is currently accepted. It prints a warnings - # to the pageserver log, however. Check that. - import_tar(base_plus_garbage_tar, wal_tar) - assert env.pageserver.log_contains( - ".*WARN.*ignored .* unexpected bytes after the tar archive.*" - ) + # Importing a tar with trailing garbage fails + with pytest.raises(RuntimeError): + import_tar(base_plus_garbage_tar, wal_tar) client = env.pageserver.http_client() timeline_delete_wait_completed(client, tenant, timeline) From df5e2729a9ac3ddd80876e0d40e3ba55b95ebf0c Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Fri, 9 Feb 2024 17:37:34 +0200 Subject: [PATCH 40/81] Remove now unused allowlisted errors. I'm not sure when we stopped emitting these, but they don't seem to be needed anymore. --- test_runner/regress/test_import.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test_runner/regress/test_import.py b/test_runner/regress/test_import.py index 7942f5cc9b..db385b3e73 100644 --- a/test_runner/regress/test_import.py +++ b/test_runner/regress/test_import.py @@ -98,15 +98,6 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build ] ) - env.pageserver.allowed_errors.extend( - [ - # FIXME: we should clean up pageserver to not print this - ".*exited with error: unexpected message type: CopyData.*", - # FIXME: Is this expected? - ".*init_tenant_mgr: marking .* as locally complete, while it doesnt exist in remote index.*", - ] - ) - def import_tar(base, wal): env.neon_cli.raw_cli( [ From 12b39c9db95ec52353ab2bb3e21bc4a12306ce2b Mon Sep 17 00:00:00 2001 From: John Spray Date: Sat, 10 Feb 2024 11:56:52 +0000 Subject: [PATCH 41/81] control_plane: add debug APIs for force-dropping tenant/node (#6702) ## Problem When debugging/supporting this service, we sometimes need it to just forget about a tenant or node, e.g. because of an issue cleanly tearing them down. For example, if I create a tenant with a PlacementPolicy that can't be scheduled on the nodes we have, we would never be able to schedule it for a DELETE to work. ## Summary of changes - Add APIs for dropping nodes and tenants that do no teardown other than removing the entity from the DB and removing any references to it. --- control_plane/attachment_service/src/http.rs | 19 +++++++++ .../attachment_service/src/persistence.rs | 13 ++++++- .../attachment_service/src/service.rs | 39 +++++++++++++++++++ .../attachment_service/src/tenant_state.rs | 14 +++++++ test_runner/regress/test_sharding_service.py | 24 ++++++++++++ 5 files changed, 108 insertions(+), 1 deletion(-) diff --git a/control_plane/attachment_service/src/http.rs b/control_plane/attachment_service/src/http.rs index 8501e4980f..38785d3a98 100644 --- a/control_plane/attachment_service/src/http.rs +++ b/control_plane/attachment_service/src/http.rs @@ -280,6 +280,12 @@ async fn handle_node_list(req: Request) -> Result, ApiError json_response(StatusCode::OK, state.service.node_list().await?) } +async fn handle_node_drop(req: Request) -> Result, ApiError> { + let state = get_state(&req); + let node_id: NodeId = parse_request_param(&req, "node_id")?; + json_response(StatusCode::OK, state.service.node_drop(node_id).await?) +} + async fn handle_node_configure(mut req: Request) -> Result, ApiError> { let node_id: NodeId = parse_request_param(&req, "node_id")?; let config_req = json_request::(&mut req).await?; @@ -320,6 +326,13 @@ async fn handle_tenant_shard_migrate( ) } +async fn handle_tenant_drop(req: Request) -> Result, ApiError> { + let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; + let state = get_state(&req); + + json_response(StatusCode::OK, state.service.tenant_drop(tenant_id).await?) +} + /// Status endpoint is just used for checking that our HTTP listener is up async fn handle_status(_req: Request) -> Result, ApiError> { json_response(StatusCode::OK, ()) @@ -402,6 +415,12 @@ pub fn make_router( request_span(r, handle_attach_hook) }) .post("/debug/v1/inspect", |r| request_span(r, handle_inspect)) + .post("/debug/v1/tenant/:tenant_id/drop", |r| { + request_span(r, handle_tenant_drop) + }) + .post("/debug/v1/node/:node_id/drop", |r| { + request_span(r, handle_node_drop) + }) .get("/control/v1/tenant/:tenant_id/locate", |r| { tenant_service_handler(r, handle_tenant_locate) }) diff --git a/control_plane/attachment_service/src/persistence.rs b/control_plane/attachment_service/src/persistence.rs index 623d625767..457dc43232 100644 --- a/control_plane/attachment_service/src/persistence.rs +++ b/control_plane/attachment_service/src/persistence.rs @@ -260,7 +260,6 @@ impl Persistence { /// Ordering: call this _after_ deleting the tenant on pageservers, but _before_ dropping state for /// the tenant from memory on this server. - #[allow(unused)] pub(crate) async fn delete_tenant(&self, del_tenant_id: TenantId) -> DatabaseResult<()> { use crate::schema::tenant_shards::dsl::*; self.with_conn(move |conn| -> DatabaseResult<()> { @@ -273,6 +272,18 @@ impl Persistence { .await } + pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> { + use crate::schema::nodes::dsl::*; + self.with_conn(move |conn| -> DatabaseResult<()> { + diesel::delete(nodes) + .filter(node_id.eq(del_node_id.0 as i64)) + .execute(conn)?; + + Ok(()) + }) + .await + } + /// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient /// batched increment of the generations of all tenants whose generation_pageserver is equal to /// the node that called /re-attach. diff --git a/control_plane/attachment_service/src/service.rs b/control_plane/attachment_service/src/service.rs index 0331087e0d..95efa8ecd7 100644 --- a/control_plane/attachment_service/src/service.rs +++ b/control_plane/attachment_service/src/service.rs @@ -1804,6 +1804,45 @@ impl Service { Ok(TenantShardMigrateResponse {}) } + /// This is for debug/support only: we simply drop all state for a tenant, without + /// detaching or deleting it on pageservers. + pub(crate) async fn tenant_drop(&self, tenant_id: TenantId) -> Result<(), ApiError> { + self.persistence.delete_tenant(tenant_id).await?; + + let mut locked = self.inner.write().unwrap(); + let mut shards = Vec::new(); + for (tenant_shard_id, _) in locked.tenants.range(TenantShardId::tenant_range(tenant_id)) { + shards.push(*tenant_shard_id); + } + + for shard in shards { + locked.tenants.remove(&shard); + } + + Ok(()) + } + + /// This is for debug/support only: we simply drop all state for a tenant, without + /// detaching or deleting it on pageservers. We do not try and re-schedule any + /// tenants that were on this node. + /// + /// TODO: proper node deletion API that unhooks things more gracefully + pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> { + self.persistence.delete_node(node_id).await?; + + let mut locked = self.inner.write().unwrap(); + + for shard in locked.tenants.values_mut() { + shard.deref_node(node_id); + } + + let mut nodes = (*locked.nodes).clone(); + nodes.remove(&node_id); + locked.nodes = Arc::new(nodes); + + Ok(()) + } + pub(crate) async fn node_list(&self) -> Result, ApiError> { // It is convenient to avoid taking the big lock and converting Node to a serializable // structure, by fetching from storage instead of reading in-memory state. diff --git a/control_plane/attachment_service/src/tenant_state.rs b/control_plane/attachment_service/src/tenant_state.rs index c0ab076a55..1646ed9fcd 100644 --- a/control_plane/attachment_service/src/tenant_state.rs +++ b/control_plane/attachment_service/src/tenant_state.rs @@ -534,4 +534,18 @@ impl TenantState { seq: self.sequence, }) } + + // If we had any state at all referring to this node ID, drop it. Does not + // attempt to reschedule. + pub(crate) fn deref_node(&mut self, node_id: NodeId) { + if self.intent.attached == Some(node_id) { + self.intent.attached = None; + } + + self.intent.secondary.retain(|n| n != &node_id); + + self.observed.locations.remove(&node_id); + + debug_assert!(!self.intent.all_pageservers().contains(&node_id)); + } } diff --git a/test_runner/regress/test_sharding_service.py b/test_runner/regress/test_sharding_service.py index babb0d261c..248d992851 100644 --- a/test_runner/regress/test_sharding_service.py +++ b/test_runner/regress/test_sharding_service.py @@ -387,3 +387,27 @@ def test_sharding_service_compute_hook( assert notifications[1] == expect wait_until(10, 1, received_restart_notification) + + +def test_sharding_service_debug_apis(neon_env_builder: NeonEnvBuilder): + """ + Verify that occasional-use debug APIs work as expected. This is a lightweight test + that just hits the endpoints to check that they don't bitrot. + """ + + neon_env_builder.num_pageservers = 2 + env = neon_env_builder.init_start() + + tenant_id = TenantId.generate() + env.attachment_service.tenant_create(tenant_id, shard_count=2, shard_stripe_size=8192) + + # These APIs are intentionally not implemented as methods on NeonAttachmentService, as + # they're just for use in unanticipated circumstances. + env.attachment_service.request( + "POST", f"{env.attachment_service_api}/debug/v1/node/{env.pageservers[1].id}/drop" + ) + assert len(env.attachment_service.node_list()) == 1 + + env.attachment_service.request( + "POST", f"{env.attachment_service_api}/debug/v1/tenant/{tenant_id}/drop" + ) From da626fb1facd77b1159e55c5aaa39cc28ed3ed41 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Sat, 10 Feb 2024 10:48:11 +0200 Subject: [PATCH 42/81] tests: Remove "postgres is running on ... branch" messages It seems like useless chatter. The endpoint.start() itself prints a "Running command ... neon_local endpoint start" message too. --- test_runner/regress/test_ancestor_branch.py | 2 -- test_runner/regress/test_backpressure.py | 1 - test_runner/regress/test_branch_behind.py | 1 - test_runner/regress/test_clog_truncate.py | 2 -- test_runner/regress/test_config.py | 2 -- test_runner/regress/test_createdropdb.py | 2 -- test_runner/regress/test_createuser.py | 2 -- test_runner/regress/test_ddl_forwarding.py | 1 - test_runner/regress/test_fullbackup.py | 1 - test_runner/regress/test_gc_aggressive.py | 1 - test_runner/regress/test_layer_bloating.py | 1 - test_runner/regress/test_lfc_resize.py | 1 - test_runner/regress/test_logical_replication.py | 2 -- test_runner/regress/test_lsn_mapping.py | 2 -- test_runner/regress/test_multixact.py | 3 --- test_runner/regress/test_neon_extension.py | 3 --- test_runner/regress/test_old_request_lsn.py | 1 - test_runner/regress/test_parallel_copy.py | 2 -- test_runner/regress/test_pitr_gc.py | 1 - test_runner/regress/test_read_validation.py | 2 -- test_runner/regress/test_readonly_node.py | 1 - test_runner/regress/test_recovery.py | 1 - test_runner/regress/test_subxacts.py | 8 +------- test_runner/regress/test_timeline_size.py | 6 ------ test_runner/regress/test_twophase.py | 1 - test_runner/regress/test_vm_bits.py | 2 -- 26 files changed, 1 insertion(+), 51 deletions(-) diff --git a/test_runner/regress/test_ancestor_branch.py b/test_runner/regress/test_ancestor_branch.py index 0e390ba9e5..d16d2d6a24 100644 --- a/test_runner/regress/test_ancestor_branch.py +++ b/test_runner/regress/test_ancestor_branch.py @@ -45,7 +45,6 @@ def test_ancestor_branch(neon_env_builder: NeonEnvBuilder): # Create branch1. env.neon_cli.create_branch("branch1", "main", tenant_id=tenant, ancestor_start_lsn=lsn_100) endpoint_branch1 = env.endpoints.create_start("branch1", tenant_id=tenant) - log.info("postgres is running on 'branch1' branch") branch1_cur = endpoint_branch1.connect().cursor() branch1_timeline = TimelineId(query_scalar(branch1_cur, "SHOW neon.timeline_id")) @@ -68,7 +67,6 @@ def test_ancestor_branch(neon_env_builder: NeonEnvBuilder): # Create branch2. env.neon_cli.create_branch("branch2", "branch1", tenant_id=tenant, ancestor_start_lsn=lsn_200) endpoint_branch2 = env.endpoints.create_start("branch2", tenant_id=tenant) - log.info("postgres is running on 'branch2' branch") branch2_cur = endpoint_branch2.connect().cursor() branch2_timeline = TimelineId(query_scalar(branch2_cur, "SHOW neon.timeline_id")) diff --git a/test_runner/regress/test_backpressure.py b/test_runner/regress/test_backpressure.py index bc3faf9271..819912dd05 100644 --- a/test_runner/regress/test_backpressure.py +++ b/test_runner/regress/test_backpressure.py @@ -107,7 +107,6 @@ def test_backpressure_received_lsn_lag(neon_env_builder: NeonEnvBuilder): # which is needed for backpressure_lsns() to work endpoint.respec(skip_pg_catalog_updates=False) endpoint.start() - log.info("postgres is running on 'test_backpressure' branch") # setup check thread check_stop_event = threading.Event() diff --git a/test_runner/regress/test_branch_behind.py b/test_runner/regress/test_branch_behind.py index 9879254897..46c74a26b8 100644 --- a/test_runner/regress/test_branch_behind.py +++ b/test_runner/regress/test_branch_behind.py @@ -21,7 +21,6 @@ def test_branch_behind(neon_env_builder: NeonEnvBuilder): # Branch at the point where only 100 rows were inserted branch_behind_timeline_id = env.neon_cli.create_branch("test_branch_behind") endpoint_main = env.endpoints.create_start("test_branch_behind") - log.info("postgres is running on 'test_branch_behind' branch") main_cur = endpoint_main.connect().cursor() diff --git a/test_runner/regress/test_clog_truncate.py b/test_runner/regress/test_clog_truncate.py index f22eca02cc..26e6e336b9 100644 --- a/test_runner/regress/test_clog_truncate.py +++ b/test_runner/regress/test_clog_truncate.py @@ -25,7 +25,6 @@ def test_clog_truncate(neon_simple_env: NeonEnv): ] endpoint = env.endpoints.create_start("test_clog_truncate", config_lines=config) - log.info("postgres is running on test_clog_truncate branch") # Install extension containing function needed for test endpoint.safe_psql("CREATE EXTENSION neon_test_utils") @@ -62,7 +61,6 @@ def test_clog_truncate(neon_simple_env: NeonEnv): "test_clog_truncate_new", "test_clog_truncate", ancestor_start_lsn=lsn_after_truncation ) endpoint2 = env.endpoints.create_start("test_clog_truncate_new") - log.info("postgres is running on test_clog_truncate_new branch") # check that new node doesn't contain truncated segment pg_xact_0000_path_new = os.path.join(endpoint2.pg_xact_dir_path(), "0000") diff --git a/test_runner/regress/test_config.py b/test_runner/regress/test_config.py index 0ea5784b67..4bb7df1e6a 100644 --- a/test_runner/regress/test_config.py +++ b/test_runner/regress/test_config.py @@ -1,6 +1,5 @@ from contextlib import closing -from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv @@ -13,7 +12,6 @@ def test_config(neon_simple_env: NeonEnv): # change config endpoint = env.endpoints.create_start("test_config", config_lines=["log_min_messages=debug1"]) - log.info("postgres is running on test_config branch") with closing(endpoint.connect()) as conn: with conn.cursor() as cur: diff --git a/test_runner/regress/test_createdropdb.py b/test_runner/regress/test_createdropdb.py index 500d19cf31..f741a9fc87 100644 --- a/test_runner/regress/test_createdropdb.py +++ b/test_runner/regress/test_createdropdb.py @@ -20,7 +20,6 @@ def test_createdb(neon_simple_env: NeonEnv, strategy: str): env.neon_cli.create_branch("test_createdb", "empty") endpoint = env.endpoints.create_start("test_createdb") - log.info("postgres is running on 'test_createdb' branch") with endpoint.cursor() as cur: # Cause a 'relmapper' change in the original branch @@ -65,7 +64,6 @@ def test_dropdb(neon_simple_env: NeonEnv, test_output_dir): env = neon_simple_env env.neon_cli.create_branch("test_dropdb", "empty") endpoint = env.endpoints.create_start("test_dropdb") - log.info("postgres is running on 'test_dropdb' branch") with endpoint.cursor() as cur: cur.execute("CREATE DATABASE foodb") diff --git a/test_runner/regress/test_createuser.py b/test_runner/regress/test_createuser.py index f1bc405287..17d9824f52 100644 --- a/test_runner/regress/test_createuser.py +++ b/test_runner/regress/test_createuser.py @@ -1,4 +1,3 @@ -from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv from fixtures.utils import query_scalar @@ -10,7 +9,6 @@ def test_createuser(neon_simple_env: NeonEnv): env = neon_simple_env env.neon_cli.create_branch("test_createuser", "empty") endpoint = env.endpoints.create_start("test_createuser") - log.info("postgres is running on 'test_createuser' branch") with endpoint.cursor() as cur: # Cause a 'relmapper' change in the original branch diff --git a/test_runner/regress/test_ddl_forwarding.py b/test_runner/regress/test_ddl_forwarding.py index 7174487e68..50da673d87 100644 --- a/test_runner/regress/test_ddl_forwarding.py +++ b/test_runner/regress/test_ddl_forwarding.py @@ -296,7 +296,6 @@ def test_ddl_forwarding_invalid_db(neon_simple_env: NeonEnv): # Some non-existent url config_lines=["neon.console_url=http://localhost:9999/unknown/api/v0/roles_and_databases"], ) - log.info("postgres is running on 'test_ddl_forwarding_invalid_db' branch") with endpoint.cursor() as cur: cur.execute("SET neon.forward_ddl = false") diff --git a/test_runner/regress/test_fullbackup.py b/test_runner/regress/test_fullbackup.py index a456c06862..9a22084671 100644 --- a/test_runner/regress/test_fullbackup.py +++ b/test_runner/regress/test_fullbackup.py @@ -26,7 +26,6 @@ def test_fullbackup( env.neon_cli.create_branch("test_fullbackup") endpoint_main = env.endpoints.create_start("test_fullbackup") - log.info("postgres is running on 'test_fullbackup' branch") with endpoint_main.cursor() as cur: timeline = TimelineId(query_scalar(cur, "SHOW neon.timeline_id")) diff --git a/test_runner/regress/test_gc_aggressive.py b/test_runner/regress/test_gc_aggressive.py index ef68049ee7..c5070ee815 100644 --- a/test_runner/regress/test_gc_aggressive.py +++ b/test_runner/regress/test_gc_aggressive.py @@ -71,7 +71,6 @@ def test_gc_aggressive(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() timeline = env.neon_cli.create_branch("test_gc_aggressive", "main") endpoint = env.endpoints.create_start("test_gc_aggressive") - log.info("postgres is running on test_gc_aggressive branch") with endpoint.cursor() as cur: # Create table, and insert the first 100 rows diff --git a/test_runner/regress/test_layer_bloating.py b/test_runner/regress/test_layer_bloating.py index 70b115ad61..bf5834b665 100644 --- a/test_runner/regress/test_layer_bloating.py +++ b/test_runner/regress/test_layer_bloating.py @@ -21,7 +21,6 @@ def test_layer_bloating(neon_simple_env: NeonEnv, vanilla_pg): "test_logical_replication", config_lines=["log_statement=all"] ) - log.info("postgres is running on 'test_logical_replication' branch") pg_conn = endpoint.connect() cur = pg_conn.cursor() diff --git a/test_runner/regress/test_lfc_resize.py b/test_runner/regress/test_lfc_resize.py index 5c68a63d06..2a3442448a 100644 --- a/test_runner/regress/test_lfc_resize.py +++ b/test_runner/regress/test_lfc_resize.py @@ -23,7 +23,6 @@ def test_lfc_resize(neon_simple_env: NeonEnv, pg_bin: PgBin): ) n_resize = 10 scale = 10 - log.info("postgres is running on 'test_lfc_resize' branch") def run_pgbench(connstr: str): log.info(f"Start a pgbench workload on pg {connstr}") diff --git a/test_runner/regress/test_logical_replication.py b/test_runner/regress/test_logical_replication.py index 059ddf79ec..eff0b124d3 100644 --- a/test_runner/regress/test_logical_replication.py +++ b/test_runner/regress/test_logical_replication.py @@ -26,7 +26,6 @@ def test_logical_replication(neon_simple_env: NeonEnv, vanilla_pg): "test_logical_replication", config_lines=["log_statement=all"] ) - log.info("postgres is running on 'test_logical_replication' branch") pg_conn = endpoint.connect() cur = pg_conn.cursor() @@ -315,7 +314,6 @@ def test_slots_and_branching(neon_simple_env: NeonEnv): # Create branch ws. env.neon_cli.create_branch("ws", "main", tenant_id=tenant) ws_branch = env.endpoints.create_start("ws", tenant_id=tenant) - log.info("postgres is running on 'ws' branch") # Check that we can create slot with the same name ws_cur = ws_branch.connect().cursor() diff --git a/test_runner/regress/test_lsn_mapping.py b/test_runner/regress/test_lsn_mapping.py index 50d7c74af0..5813231aab 100644 --- a/test_runner/regress/test_lsn_mapping.py +++ b/test_runner/regress/test_lsn_mapping.py @@ -28,7 +28,6 @@ def test_lsn_mapping(neon_env_builder: NeonEnvBuilder): timeline_id = env.neon_cli.create_branch("test_lsn_mapping", tenant_id=tenant_id) endpoint_main = env.endpoints.create_start("test_lsn_mapping", tenant_id=tenant_id) timeline_id = endpoint_main.safe_psql("show neon.timeline_id")[0][0] - log.info("postgres is running on 'main' branch") cur = endpoint_main.connect().cursor() @@ -114,7 +113,6 @@ def test_ts_of_lsn_api(neon_env_builder: NeonEnvBuilder): new_timeline_id = env.neon_cli.create_branch("test_ts_of_lsn_api") endpoint_main = env.endpoints.create_start("test_ts_of_lsn_api") - log.info("postgres is running on 'test_ts_of_lsn_api' branch") cur = endpoint_main.connect().cursor() # Create table, and insert rows, each in a separate transaction diff --git a/test_runner/regress/test_multixact.py b/test_runner/regress/test_multixact.py index 9db463dc4a..88f7a5db59 100644 --- a/test_runner/regress/test_multixact.py +++ b/test_runner/regress/test_multixact.py @@ -1,4 +1,3 @@ -from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv, check_restored_datadir_content from fixtures.utils import query_scalar @@ -18,7 +17,6 @@ def test_multixact(neon_simple_env: NeonEnv, test_output_dir): env.neon_cli.create_branch("test_multixact", "empty") endpoint = env.endpoints.create_start("test_multixact") - log.info("postgres is running on 'test_multixact' branch") cur = endpoint.connect().cursor() cur.execute( """ @@ -78,7 +76,6 @@ def test_multixact(neon_simple_env: NeonEnv, test_output_dir): env.neon_cli.create_branch("test_multixact_new", "test_multixact", ancestor_start_lsn=lsn) endpoint_new = env.endpoints.create_start("test_multixact_new") - log.info("postgres is running on 'test_multixact_new' branch") next_multixact_id_new = endpoint_new.safe_psql( "SELECT next_multixact_id FROM pg_control_checkpoint()" )[0][0] diff --git a/test_runner/regress/test_neon_extension.py b/test_runner/regress/test_neon_extension.py index 998f84f968..62225e7b92 100644 --- a/test_runner/regress/test_neon_extension.py +++ b/test_runner/regress/test_neon_extension.py @@ -1,6 +1,5 @@ from contextlib import closing -from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder @@ -14,8 +13,6 @@ def test_neon_extension(neon_env_builder: NeonEnvBuilder): endpoint_main.respec(skip_pg_catalog_updates=False) endpoint_main.start() - log.info("postgres is running on 'test_create_extension_neon' branch") - with closing(endpoint_main.connect()) as conn: with conn.cursor() as cur: cur.execute("SELECT extversion from pg_extension where extname='neon'") diff --git a/test_runner/regress/test_old_request_lsn.py b/test_runner/regress/test_old_request_lsn.py index 9b0bab5125..391305c58a 100644 --- a/test_runner/regress/test_old_request_lsn.py +++ b/test_runner/regress/test_old_request_lsn.py @@ -20,7 +20,6 @@ def test_old_request_lsn(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() env.neon_cli.create_branch("test_old_request_lsn", "main") endpoint = env.endpoints.create_start("test_old_request_lsn") - log.info("postgres is running on test_old_request_lsn branch") pg_conn = endpoint.connect() cur = pg_conn.cursor() diff --git a/test_runner/regress/test_parallel_copy.py b/test_runner/regress/test_parallel_copy.py index 6f74d50b92..b33e387a66 100644 --- a/test_runner/regress/test_parallel_copy.py +++ b/test_runner/regress/test_parallel_copy.py @@ -1,7 +1,6 @@ import asyncio from io import BytesIO -from fixtures.log_helper import log from fixtures.neon_fixtures import Endpoint, NeonEnv @@ -44,7 +43,6 @@ def test_parallel_copy(neon_simple_env: NeonEnv, n_parallel=5): env = neon_simple_env env.neon_cli.create_branch("test_parallel_copy", "empty") endpoint = env.endpoints.create_start("test_parallel_copy") - log.info("postgres is running on 'test_parallel_copy' branch") # Create test table conn = endpoint.connect() diff --git a/test_runner/regress/test_pitr_gc.py b/test_runner/regress/test_pitr_gc.py index c2ea5b332a..539ef3eda7 100644 --- a/test_runner/regress/test_pitr_gc.py +++ b/test_runner/regress/test_pitr_gc.py @@ -16,7 +16,6 @@ def test_pitr_gc(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() endpoint_main = env.endpoints.create_start("main") - log.info("postgres is running on 'main' branch") main_pg_conn = endpoint_main.connect() main_cur = main_pg_conn.cursor() diff --git a/test_runner/regress/test_read_validation.py b/test_runner/regress/test_read_validation.py index d695410efc..effb7e83f9 100644 --- a/test_runner/regress/test_read_validation.py +++ b/test_runner/regress/test_read_validation.py @@ -18,7 +18,6 @@ def test_read_validation(neon_simple_env: NeonEnv): env.neon_cli.create_branch("test_read_validation", "empty") endpoint = env.endpoints.create_start("test_read_validation") - log.info("postgres is running on 'test_read_validation' branch") with closing(endpoint.connect()) as con: with con.cursor() as c: @@ -145,7 +144,6 @@ def test_read_validation_neg(neon_simple_env: NeonEnv): env.pageserver.allowed_errors.append(".*invalid LSN\\(0\\) in request.*") endpoint = env.endpoints.create_start("test_read_validation_neg") - log.info("postgres is running on 'test_read_validation_neg' branch") with closing(endpoint.connect()) as con: with con.cursor() as c: diff --git a/test_runner/regress/test_readonly_node.py b/test_runner/regress/test_readonly_node.py index 2d641e36a7..b7c8f36107 100644 --- a/test_runner/regress/test_readonly_node.py +++ b/test_runner/regress/test_readonly_node.py @@ -16,7 +16,6 @@ def test_readonly_node(neon_simple_env: NeonEnv): env = neon_simple_env env.neon_cli.create_branch("test_readonly_node", "empty") endpoint_main = env.endpoints.create_start("test_readonly_node") - log.info("postgres is running on 'test_readonly_node' branch") env.pageserver.allowed_errors.append(".*basebackup .* failed: invalid basebackup lsn.*") diff --git a/test_runner/regress/test_recovery.py b/test_runner/regress/test_recovery.py index 9d7a4a8fd6..6aac1e1d84 100644 --- a/test_runner/regress/test_recovery.py +++ b/test_runner/regress/test_recovery.py @@ -19,7 +19,6 @@ def test_pageserver_recovery(neon_env_builder: NeonEnvBuilder): env.neon_cli.create_branch("test_pageserver_recovery", "main") endpoint = env.endpoints.create_start("test_pageserver_recovery") - log.info("postgres is running on 'test_pageserver_recovery' branch") with closing(endpoint.connect()) as conn: with conn.cursor() as cur: diff --git a/test_runner/regress/test_subxacts.py b/test_runner/regress/test_subxacts.py index eb96a8faa4..10cb00c780 100644 --- a/test_runner/regress/test_subxacts.py +++ b/test_runner/regress/test_subxacts.py @@ -1,4 +1,3 @@ -from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv, check_restored_datadir_content @@ -13,15 +12,10 @@ def test_subxacts(neon_simple_env: NeonEnv, test_output_dir): env.neon_cli.create_branch("test_subxacts", "empty") endpoint = env.endpoints.create_start("test_subxacts") - log.info("postgres is running on 'test_subxacts' branch") pg_conn = endpoint.connect() cur = pg_conn.cursor() - cur.execute( - """ - CREATE TABLE t1(i int, j int); - """ - ) + cur.execute("CREATE TABLE t1(i int, j int);") cur.execute("select pg_switch_wal();") diff --git a/test_runner/regress/test_timeline_size.py b/test_runner/regress/test_timeline_size.py index cd7203bba6..a3f99948d3 100644 --- a/test_runner/regress/test_timeline_size.py +++ b/test_runner/regress/test_timeline_size.py @@ -43,7 +43,6 @@ def test_timeline_size(neon_simple_env: NeonEnv): client.timeline_wait_logical_size(env.initial_tenant, new_timeline_id) endpoint_main = env.endpoints.create_start("test_timeline_size") - log.info("postgres is running on 'test_timeline_size' branch") with closing(endpoint_main.connect()) as conn: with conn.cursor() as cur: @@ -79,7 +78,6 @@ def test_timeline_size_createdropdb(neon_simple_env: NeonEnv): ) endpoint_main = env.endpoints.create_start("test_timeline_size_createdropdb") - log.info("postgres is running on 'test_timeline_size_createdropdb' branch") with closing(endpoint_main.connect()) as conn: with conn.cursor() as cur: @@ -162,8 +160,6 @@ def test_timeline_size_quota_on_startup(neon_env_builder: NeonEnvBuilder): ) endpoint_main.start() - log.info("postgres is running on 'test_timeline_size_quota_on_startup' branch") - with closing(endpoint_main.connect()) as conn: with conn.cursor() as cur: cur.execute("CREATE TABLE foo (t text)") @@ -231,8 +227,6 @@ def test_timeline_size_quota(neon_env_builder: NeonEnvBuilder): endpoint_main.respec(skip_pg_catalog_updates=False) endpoint_main.start() - log.info("postgres is running on 'test_timeline_size_quota' branch") - with closing(endpoint_main.connect()) as conn: with conn.cursor() as cur: cur.execute("CREATE TABLE foo (t text)") diff --git a/test_runner/regress/test_twophase.py b/test_runner/regress/test_twophase.py index 305271c715..dd76689008 100644 --- a/test_runner/regress/test_twophase.py +++ b/test_runner/regress/test_twophase.py @@ -13,7 +13,6 @@ def test_twophase(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "test_twophase", config_lines=["max_prepared_transactions=5"] ) - log.info("postgres is running on 'test_twophase' branch") conn = endpoint.connect() cur = conn.cursor() diff --git a/test_runner/regress/test_vm_bits.py b/test_runner/regress/test_vm_bits.py index 06c30b8d81..1377bed6f6 100644 --- a/test_runner/regress/test_vm_bits.py +++ b/test_runner/regress/test_vm_bits.py @@ -14,7 +14,6 @@ def test_vm_bit_clear(neon_simple_env: NeonEnv): env.neon_cli.create_branch("test_vm_bit_clear", "empty") endpoint = env.endpoints.create_start("test_vm_bit_clear") - log.info("postgres is running on 'test_vm_bit_clear' branch") pg_conn = endpoint.connect() cur = pg_conn.cursor() @@ -93,7 +92,6 @@ def test_vm_bit_clear(neon_simple_env: NeonEnv): # server at the right point-in-time avoids that full-page image. endpoint_new = env.endpoints.create_start("test_vm_bit_clear_new") - log.info("postgres is running on 'test_vm_bit_clear_new' branch") pg_new_conn = endpoint_new.connect() cur_new = pg_new_conn.cursor() From 241dcbf70ce117a8b956fb990f13fee67029a197 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Sat, 10 Feb 2024 10:50:52 +0200 Subject: [PATCH 43/81] tests: Remove "Running in ..." log message from every CLI call It's always the same directory, the test's "repo" directory. --- test_runner/fixtures/neon_fixtures.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 231eebff52..31acb045ae 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1400,7 +1400,6 @@ class AbstractNeonCli(abc.ABC): args = [bin_neon] + arguments log.info('Running command "{}"'.format(" ".join(args))) - log.info(f'Running in "{self.env.repo_dir}"') env_vars = os.environ.copy() env_vars["NEON_REPO_DIR"] = str(self.env.repo_dir) From d77583c86ab3cf4d5b555d86a7b665c1457f97c8 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Sat, 10 Feb 2024 11:10:48 +0200 Subject: [PATCH 44/81] tests: Remove obsolete allowlist entries Commit 9a6c0be823 removed the code that printed these warnings: marking {} as locally complete, while it doesnt exist in remote index No timelines to attach received Remove those warnings from all the allowlists in tests. --- test_runner/regress/test_compatibility.py | 5 ----- test_runner/regress/test_import.py | 5 ----- test_runner/regress/test_remote_storage.py | 3 --- test_runner/regress/test_tenant_relocation.py | 2 -- test_runner/regress/test_tenants.py | 1 - .../regress/test_tenants_with_remote_storage.py | 16 ---------------- test_runner/regress/test_wal_acceptor.py | 10 ---------- 7 files changed, 42 deletions(-) diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index d5d70951be..826821e52b 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -112,11 +112,6 @@ def test_create_snapshot( env = neon_env_builder.init_start() endpoint = env.endpoints.create_start("main") - # FIXME: Is this expected? - env.pageserver.allowed_errors.append( - ".*init_tenant_mgr: marking .* as locally complete, while it doesnt exist in remote index.*" - ) - pg_bin.run_capture(["pgbench", "--initialize", "--scale=10", endpoint.connstr()]) pg_bin.run_capture(["pgbench", "--time=60", "--progress=2", endpoint.connstr()]) pg_bin.run_capture( diff --git a/test_runner/regress/test_import.py b/test_runner/regress/test_import.py index db385b3e73..ec57860033 100644 --- a/test_runner/regress/test_import.py +++ b/test_runner/regress/test_import.py @@ -159,11 +159,6 @@ def test_import_from_pageserver_small( neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) env = neon_env_builder.init_start() - # FIXME: Is this expected? - env.pageserver.allowed_errors.append( - ".*init_tenant_mgr: marking .* as locally complete, while it doesnt exist in remote index.*" - ) - timeline = env.neon_cli.create_branch("test_import_from_pageserver_small") endpoint = env.endpoints.create_start("test_import_from_pageserver_small") diff --git a/test_runner/regress/test_remote_storage.py b/test_runner/regress/test_remote_storage.py index 98b2e856ec..32b4f54fbd 100644 --- a/test_runner/regress/test_remote_storage.py +++ b/test_runner/regress/test_remote_storage.py @@ -73,9 +73,6 @@ def test_remote_storage_backup_and_restore( env.pageserver.allowed_errors.extend( [ - # FIXME: Is this expected? - ".*marking .* as locally complete, while it doesnt exist in remote index.*", - ".*No timelines to attach received.*", ".*Failed to get local tenant state.*", # FIXME retry downloads without throwing errors ".*failed to load remote timeline.*", diff --git a/test_runner/regress/test_tenant_relocation.py b/test_runner/regress/test_tenant_relocation.py index 80b4fab1d3..f4eb6b092d 100644 --- a/test_runner/regress/test_tenant_relocation.py +++ b/test_runner/regress/test_tenant_relocation.py @@ -213,8 +213,6 @@ def test_tenant_relocation( env.pageservers[0].allowed_errors.extend( [ - # FIXME: Is this expected? - ".*init_tenant_mgr: marking .* as locally complete, while it doesnt exist in remote index.*", # Needed for detach polling on the original pageserver f".*NotFound: tenant {tenant_id}.*", # We will dual-attach in this test, so stale generations are expected diff --git a/test_runner/regress/test_tenants.py b/test_runner/regress/test_tenants.py index ba391a69d8..bf317808ee 100644 --- a/test_runner/regress/test_tenants.py +++ b/test_runner/regress/test_tenants.py @@ -285,7 +285,6 @@ def test_pageserver_with_empty_tenants(neon_env_builder: NeonEnvBuilder): env.pageserver.allowed_errors.extend( [ - ".*marking .* as locally complete, while it doesnt exist in remote index.*", ".*load failed.*list timelines directory.*", ] ) diff --git a/test_runner/regress/test_tenants_with_remote_storage.py b/test_runner/regress/test_tenants_with_remote_storage.py index 6f05d7f7cb..1c693a0df5 100644 --- a/test_runner/regress/test_tenants_with_remote_storage.py +++ b/test_runner/regress/test_tenants_with_remote_storage.py @@ -61,11 +61,6 @@ async def all_tenants_workload(env: NeonEnv, tenants_endpoints): def test_tenants_many(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() - # FIXME: Is this expected? - env.pageserver.allowed_errors.append( - ".*init_tenant_mgr: marking .* as locally complete, while it doesnt exist in remote index.*" - ) - tenants_endpoints: List[Tuple[TenantId, Endpoint]] = [] for _ in range(1, 5): @@ -117,14 +112,6 @@ def test_tenants_attached_after_download(neon_env_builder: NeonEnvBuilder): ##### First start, insert secret data and upload it to the remote storage env = neon_env_builder.init_start() - env.pageserver.allowed_errors.extend( - [ - # FIXME: Are these expected? - ".*No timelines to attach received.*", - ".*marking .* as locally complete, while it doesnt exist in remote index.*", - ] - ) - pageserver_http = env.pageserver.http_client() endpoint = env.endpoints.create_start("main") @@ -223,9 +210,6 @@ def test_tenant_redownloads_truncated_file_on_startup( env.pageserver.allowed_errors.extend( [ ".*removing local file .* because .*", - # FIXME: Are these expected? - ".*init_tenant_mgr: marking .* as locally complete, while it doesnt exist in remote index.*", - ".*No timelines to attach received.*", ] ) diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index dab446fcfd..3d7bba6153 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -280,11 +280,6 @@ def test_broker(neon_env_builder: NeonEnvBuilder): tenant_id = env.initial_tenant timeline_id = env.neon_cli.create_branch("test_broker", "main") - # FIXME: Is this expected? - env.pageserver.allowed_errors.append( - ".*init_tenant_mgr: marking .* as locally complete, while it doesnt exist in remote index.*" - ) - endpoint = env.endpoints.create_start("test_broker") endpoint.safe_psql("CREATE TABLE t(key int primary key, value text)") @@ -342,11 +337,6 @@ def test_wal_removal(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): neon_env_builder.auth_enabled = auth_enabled env = neon_env_builder.init_start() - # FIXME: Is this expected? - env.pageserver.allowed_errors.append( - ".*init_tenant_mgr: marking .* as locally complete, while it doesnt exist in remote index.*" - ) - tenant_id = env.initial_tenant timeline_id = env.neon_cli.create_branch("test_safekeepers_wal_removal") endpoint = env.endpoints.create_start("test_safekeepers_wal_removal") From e5daf366ac92a5398c09ea956ba03ac03848d3f8 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Sat, 10 Feb 2024 11:25:47 +0200 Subject: [PATCH 45/81] tests: Remove unnecessary port config with VanillaPostgres class VanillaPostgres constructor prints the "port={port}" line to the config file, no need to do it in the callers. The TODO comment that it would be nice if VanillaPostgres could pick the port by itself is still valid though. --- test_runner/fixtures/neon_fixtures.py | 1 + test_runner/regress/test_fullbackup.py | 6 ------ test_runner/regress/test_timeline_size.py | 1 - 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 31acb045ae..faa8effe10 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2458,6 +2458,7 @@ def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: PgVersion) - return PgBin(test_output_dir, pg_distrib_dir, pg_version) +# TODO make port an optional argument class VanillaPostgres(PgProtocol): def __init__(self, pgdatadir: Path, pg_bin: PgBin, port: int, init: bool = True): super().__init__(host="localhost", port=port, dbname="postgres") diff --git a/test_runner/regress/test_fullbackup.py b/test_runner/regress/test_fullbackup.py index 9a22084671..d5f898492b 100644 --- a/test_runner/regress/test_fullbackup.py +++ b/test_runner/regress/test_fullbackup.py @@ -66,12 +66,6 @@ def test_fullbackup( # Restore from the backup and find the data we inserted port = port_distributor.get_port() with VanillaPostgres(restored_dir_path, pg_bin, port, init=False) as vanilla_pg: - # TODO make port an optional argument - vanilla_pg.configure( - [ - f"port={port}", - ] - ) vanilla_pg.start() num_rows_found = vanilla_pg.safe_psql("select count(*) from tbl;", user="cloud_admin")[0][0] assert num_rows == num_rows_found diff --git a/test_runner/regress/test_timeline_size.py b/test_runner/regress/test_timeline_size.py index a3f99948d3..0788c49c7b 100644 --- a/test_runner/regress/test_timeline_size.py +++ b/test_runner/regress/test_timeline_size.py @@ -579,7 +579,6 @@ def test_timeline_size_metrics( pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version) port = port_distributor.get_port() with VanillaPostgres(pgdatadir, pg_bin, port) as vanilla_pg: - vanilla_pg.configure([f"port={port}"]) vanilla_pg.start() # Create database based on template0 because we can't connect to template0 From aeda82a0105f18393e8d56d7ff2f6202059edde6 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Mon, 12 Feb 2024 11:57:29 +0200 Subject: [PATCH 46/81] fix(heavier_once_cell): assertion failure can be hit (#6722) @problame noticed that the `tokio::sync::AcquireError` branch assertion can be hit like in the added test. We haven't seen this yet in production, but I'd prefer not to see it there. There `take_and_deinit` is being used, but this race must be quite timing sensitive. Rework of earlier: #6652. --- libs/utils/src/sync/heavier_once_cell.rs | 174 ++++++++++++++++++----- 1 file changed, 138 insertions(+), 36 deletions(-) diff --git a/libs/utils/src/sync/heavier_once_cell.rs b/libs/utils/src/sync/heavier_once_cell.rs index 0ccaf4e716..0773abba2d 100644 --- a/libs/utils/src/sync/heavier_once_cell.rs +++ b/libs/utils/src/sync/heavier_once_cell.rs @@ -69,37 +69,44 @@ impl OnceCell { F: FnOnce(InitPermit) -> Fut, Fut: std::future::Future>, { - let sem = { + loop { + let sem = { + let guard = self.inner.lock().unwrap(); + if guard.value.is_some() { + return Ok(Guard(guard)); + } + guard.init_semaphore.clone() + }; + + { + let permit = { + // increment the count for the duration of queued + let _guard = CountWaitingInitializers::start(self); + sem.acquire().await + }; + + let Ok(permit) = permit else { + let guard = self.inner.lock().unwrap(); + if !Arc::ptr_eq(&sem, &guard.init_semaphore) { + // there was a take_and_deinit in between + continue; + } + assert!( + guard.value.is_some(), + "semaphore got closed, must be initialized" + ); + return Ok(Guard(guard)); + }; + + permit.forget(); + } + + let permit = InitPermit(sem); + let (value, _permit) = factory(permit).await?; + let guard = self.inner.lock().unwrap(); - if guard.value.is_some() { - return Ok(Guard(guard)); - } - guard.init_semaphore.clone() - }; - let permit = { - // increment the count for the duration of queued - let _guard = CountWaitingInitializers::start(self); - sem.acquire_owned().await - }; - - match permit { - Ok(permit) => { - let permit = InitPermit(permit); - let (value, _permit) = factory(permit).await?; - - let guard = self.inner.lock().unwrap(); - - Ok(Self::set0(value, guard)) - } - Err(_closed) => { - let guard = self.inner.lock().unwrap(); - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(Guard(guard)); - } + return Ok(Self::set0(value, guard)); } } @@ -197,27 +204,41 @@ impl<'a, T> Guard<'a, T> { /// [`OnceCell::get_or_init`] will wait on it to complete. pub fn take_and_deinit(&mut self) -> (T, InitPermit) { let mut swapped = Inner::default(); - let permit = swapped - .init_semaphore - .clone() - .try_acquire_owned() - .expect("we just created this"); + let sem = swapped.init_semaphore.clone(); + // acquire and forget right away, moving the control over to InitPermit + sem.try_acquire().expect("we just created this").forget(); std::mem::swap(&mut *self.0, &mut swapped); swapped .value - .map(|v| (v, InitPermit(permit))) + .map(|v| (v, InitPermit(sem))) .expect("guard is not created unless value has been initialized") } } /// Type held by OnceCell (de)initializing task. -pub struct InitPermit(tokio::sync::OwnedSemaphorePermit); +/// +/// On drop, this type will return the permit. +pub struct InitPermit(Arc); + +impl Drop for InitPermit { + fn drop(&mut self) { + assert_eq!( + self.0.available_permits(), + 0, + "InitPermit should only exist as the unique permit" + ); + self.0.add_permits(1); + } +} #[cfg(test)] mod tests { + use futures::Future; + use super::*; use std::{ convert::Infallible, + pin::{pin, Pin}, sync::atomic::{AtomicUsize, Ordering}, time::Duration, }; @@ -380,4 +401,85 @@ mod tests { .unwrap(); assert_eq!(*g, "now initialized"); } + + #[tokio::test(start_paused = true)] + async fn reproduce_init_take_deinit_race() { + init_take_deinit_scenario(|cell, factory| { + Box::pin(async { + cell.get_or_init(factory).await.unwrap(); + }) + }) + .await; + } + + type BoxedInitFuture = Pin>>>; + type BoxedInitFunction = Box BoxedInitFuture>; + + /// Reproduce an assertion failure. + /// + /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`. + /// We currently only have one, but the structure is kept. + async fn init_take_deinit_scenario(init_way: F) + where + F: for<'a> Fn( + &'a OnceCell<&'static str>, + BoxedInitFunction<&'static str, Infallible>, + ) -> Pin + 'a>>, + { + let cell = OnceCell::default(); + + // acquire the init_semaphore only permit to drive initializing tasks in order to waiting + // on the same semaphore. + let permit = cell + .inner + .lock() + .unwrap() + .init_semaphore + .clone() + .try_acquire_owned() + .unwrap(); + + let mut t1 = pin!(init_way( + &cell, + Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })), + )); + + let mut t2 = pin!(init_way( + &cell, + Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })), + )); + + // drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can + // no longer make progress + tokio::select! { + _ = &mut t2 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // followed by t1 in the init_semaphore + tokio::select! { + _ = &mut t1 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // now let t2 proceed and initialize + drop(permit); + t2.await; + + let (s, permit) = { cell.get().unwrap().take_and_deinit() }; + assert_eq!("t2", s); + + // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from + // the new one. + tokio::select! { + _ = &mut t1 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // only now we get to initialize it + drop(permit); + t1.await; + + assert_eq!("t1", *cell.get().unwrap()); + } } From c77411e9035ac38925652bf1f772b333acb0b9ac Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Mon, 12 Feb 2024 14:52:20 +0200 Subject: [PATCH 47/81] cleanup around `attach` (#6621) The smaller changes I found while looking around #6584. - rustfmt was not able to format handle_timeline_create - fix Generation::get_suffix always allocating - Generation was missing a `#[track_caller]` for panicky method - attach has a lot of issues, but even with this PR it cannot be formatted by rustfmt - moved the `preload` span to be on top of `attach` -- it is awaited inline - make disconnected panic! or unreachable! into expect, expect_err --- libs/utils/src/generation.rs | 41 ++++- pageserver/src/http/routes.rs | 76 +++++---- pageserver/src/tenant.rs | 199 +++++++++++------------ pageserver/src/tenant/delete.rs | 8 +- pageserver/src/tenant/timeline/delete.rs | 9 +- 5 files changed, 177 insertions(+), 156 deletions(-) diff --git a/libs/utils/src/generation.rs b/libs/utils/src/generation.rs index 46eadee1da..6f6c46cfeb 100644 --- a/libs/utils/src/generation.rs +++ b/libs/utils/src/generation.rs @@ -54,12 +54,10 @@ impl Generation { } #[track_caller] - pub fn get_suffix(&self) -> String { + pub fn get_suffix(&self) -> impl std::fmt::Display { match self { - Self::Valid(v) => { - format!("-{:08x}", v) - } - Self::None => "".into(), + Self::Valid(v) => GenerationFileSuffix(Some(*v)), + Self::None => GenerationFileSuffix(None), Self::Broken => { panic!("Tried to use a broken generation"); } @@ -90,6 +88,7 @@ impl Generation { } } + #[track_caller] pub fn next(&self) -> Generation { match self { Self::Valid(n) => Self::Valid(*n + 1), @@ -107,6 +106,18 @@ impl Generation { } } +struct GenerationFileSuffix(Option); + +impl std::fmt::Display for GenerationFileSuffix { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(g) = self.0 { + write!(f, "-{g:08x}") + } else { + Ok(()) + } + } +} + impl Serialize for Generation { fn serialize(&self, serializer: S) -> Result where @@ -164,4 +175,24 @@ mod test { assert!(Generation::none() < Generation::new(0)); assert!(Generation::none() < Generation::new(1)); } + + #[test] + fn suffix_is_stable() { + use std::fmt::Write as _; + + // the suffix must remain stable through-out the pageserver remote storage evolution and + // not be changed accidentially without thinking about migration + let examples = [ + (line!(), Generation::None, ""), + (line!(), Generation::Valid(0), "-00000000"), + (line!(), Generation::Valid(u32::MAX), "-ffffffff"), + ]; + + let mut s = String::new(); + for (line, gen, expected) in examples { + s.clear(); + write!(s, "{}", &gen.get_suffix()).expect("string grows"); + assert_eq!(s, expected, "example on {line}"); + } + } } diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index af9a3c7301..4be8ee9892 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -488,7 +488,9 @@ async fn timeline_create_handler( let state = get_state(&request); async { - let tenant = state.tenant_manager.get_attached_tenant_shard(tenant_shard_id, false)?; + let tenant = state + .tenant_manager + .get_attached_tenant_shard(tenant_shard_id, false)?; tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?; @@ -498,48 +500,62 @@ async fn timeline_create_handler( tracing::info!("bootstrapping"); } - match tenant.create_timeline( - new_timeline_id, - request_data.ancestor_timeline_id.map(TimelineId::from), - request_data.ancestor_start_lsn, - request_data.pg_version.unwrap_or(crate::DEFAULT_PG_VERSION), - request_data.existing_initdb_timeline_id, - state.broker_client.clone(), - &ctx, - ) - .await { + match tenant + .create_timeline( + new_timeline_id, + request_data.ancestor_timeline_id, + request_data.ancestor_start_lsn, + request_data.pg_version.unwrap_or(crate::DEFAULT_PG_VERSION), + request_data.existing_initdb_timeline_id, + state.broker_client.clone(), + &ctx, + ) + .await + { Ok(new_timeline) => { // Created. Construct a TimelineInfo for it. - let timeline_info = build_timeline_info_common(&new_timeline, &ctx, tenant::timeline::GetLogicalSizePriority::User) - .await - .map_err(ApiError::InternalServerError)?; + let timeline_info = build_timeline_info_common( + &new_timeline, + &ctx, + tenant::timeline::GetLogicalSizePriority::User, + ) + .await + .map_err(ApiError::InternalServerError)?; json_response(StatusCode::CREATED, timeline_info) } Err(_) if tenant.cancel.is_cancelled() => { // In case we get some ugly error type during shutdown, cast it into a clean 503. - json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg("Tenant shutting down".to_string())) - } - Err(tenant::CreateTimelineError::Conflict | tenant::CreateTimelineError::AlreadyCreating) => { - json_response(StatusCode::CONFLICT, ()) - } - Err(tenant::CreateTimelineError::AncestorLsn(err)) => { - json_response(StatusCode::NOT_ACCEPTABLE, HttpErrorBody::from_msg( - format!("{err:#}") - )) - } - Err(e @ tenant::CreateTimelineError::AncestorNotActive) => { - json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg(e.to_string())) - } - Err(tenant::CreateTimelineError::ShuttingDown) => { - json_response(StatusCode::SERVICE_UNAVAILABLE,HttpErrorBody::from_msg("tenant shutting down".to_string())) + json_response( + StatusCode::SERVICE_UNAVAILABLE, + HttpErrorBody::from_msg("Tenant shutting down".to_string()), + ) } + Err( + tenant::CreateTimelineError::Conflict + | tenant::CreateTimelineError::AlreadyCreating, + ) => json_response(StatusCode::CONFLICT, ()), + Err(tenant::CreateTimelineError::AncestorLsn(err)) => json_response( + StatusCode::NOT_ACCEPTABLE, + HttpErrorBody::from_msg(format!("{err:#}")), + ), + Err(e @ tenant::CreateTimelineError::AncestorNotActive) => json_response( + StatusCode::SERVICE_UNAVAILABLE, + HttpErrorBody::from_msg(e.to_string()), + ), + Err(tenant::CreateTimelineError::ShuttingDown) => json_response( + StatusCode::SERVICE_UNAVAILABLE, + HttpErrorBody::from_msg("tenant shutting down".to_string()), + ), Err(tenant::CreateTimelineError::Other(err)) => Err(ApiError::InternalServerError(err)), } } .instrument(info_span!("timeline_create", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), - timeline_id = %new_timeline_id, lsn=?request_data.ancestor_start_lsn, pg_version=?request_data.pg_version)) + timeline_id = %new_timeline_id, + lsn=?request_data.ancestor_start_lsn, + pg_version=?request_data.pg_version + )) .await } diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 4446c410b0..d946c57118 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -644,10 +644,10 @@ impl Tenant { // The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if // we shut down while attaching. - let Ok(attach_gate_guard) = tenant.gate.enter() else { - // We just created the Tenant: nothing else can have shut it down yet - unreachable!(); - }; + let attach_gate_guard = tenant + .gate + .enter() + .expect("We just created the Tenant: nothing else can have shut it down yet"); // Do all the hard work in the background let tenant_clone = Arc::clone(&tenant); @@ -755,36 +755,27 @@ impl Tenant { AttachType::Normal }; - let preload_timer = TENANT.preload.start_timer(); - let preload = match mode { - SpawnMode::Create => { - // Don't count the skipped preload into the histogram of preload durations - preload_timer.stop_and_discard(); + let preload = match (&mode, &remote_storage) { + (SpawnMode::Create, _) => { None }, - SpawnMode::Normal => { - match &remote_storage { - Some(remote_storage) => Some( - match tenant_clone - .preload(remote_storage, task_mgr::shutdown_token()) - .instrument( - tracing::info_span!(parent: None, "attach_preload", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()), - ) - .await { - Ok(p) => { - preload_timer.observe_duration(); - p - } - , - Err(e) => { - make_broken(&tenant_clone, anyhow::anyhow!(e)); - return Ok(()); - } - }, - ), - None => None, + (SpawnMode::Normal, Some(remote_storage)) => { + let _preload_timer = TENANT.preload.start_timer(); + let res = tenant_clone + .preload(remote_storage, task_mgr::shutdown_token()) + .await; + match res { + Ok(p) => Some(p), + Err(e) => { + make_broken(&tenant_clone, anyhow::anyhow!(e)); + return Ok(()); + } } } + (SpawnMode::Normal, None) => { + let _preload_timer = TENANT.preload.start_timer(); + None + } }; // Remote preload is complete. @@ -820,36 +811,37 @@ impl Tenant { info!("ready for backgound jobs barrier"); } - match DeleteTenantFlow::resume_from_attach( + let deleted = DeleteTenantFlow::resume_from_attach( deletion, &tenant_clone, preload, tenants, &ctx, ) - .await - { - Err(err) => { - make_broken(&tenant_clone, anyhow::anyhow!(err)); - return Ok(()); - } - Ok(()) => return Ok(()), + .await; + + if let Err(e) = deleted { + make_broken(&tenant_clone, anyhow::anyhow!(e)); } + + return Ok(()); } // We will time the duration of the attach phase unless this is a creation (attach will do no work) - let attach_timer = match mode { - SpawnMode::Create => None, - SpawnMode::Normal => {Some(TENANT.attach.start_timer())} + let attached = { + let _attach_timer = match mode { + SpawnMode::Create => None, + SpawnMode::Normal => {Some(TENANT.attach.start_timer())} + }; + tenant_clone.attach(preload, mode, &ctx).await }; - match tenant_clone.attach(preload, mode, &ctx).await { + + match attached { Ok(()) => { info!("attach finished, activating"); - if let Some(t)= attach_timer {t.observe_duration();} tenant_clone.activate(broker_client, None, &ctx); } Err(e) => { - if let Some(t)= attach_timer {t.observe_duration();} make_broken(&tenant_clone, anyhow::anyhow!(e)); } } @@ -862,34 +854,26 @@ impl Tenant { // logical size calculations: if logical size calculation semaphore is saturated, // then warmup will wait for that before proceeding to the next tenant. if let AttachType::Warmup(_permit) = attach_type { - let mut futs = FuturesUnordered::new(); - let timelines: Vec<_> = tenant_clone.timelines.lock().unwrap().values().cloned().collect(); - for t in timelines { - futs.push(t.await_initial_logical_size()) - } + let mut futs: FuturesUnordered<_> = tenant_clone.timelines.lock().unwrap().values().cloned().map(|t| t.await_initial_logical_size()).collect(); tracing::info!("Waiting for initial logical sizes while warming up..."); - while futs.next().await.is_some() { - - } + while futs.next().await.is_some() {} tracing::info!("Warm-up complete"); } Ok(()) } - .instrument({ - let span = tracing::info_span!(parent: None, "attach", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), gen=?generation); - span.follows_from(Span::current()); - span - }), + .instrument(tracing::info_span!(parent: None, "attach", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), gen=?generation)), ); Ok(tenant) } + #[instrument(skip_all)] pub(crate) async fn preload( self: &Arc, remote_storage: &GenericRemoteStorage, cancel: CancellationToken, ) -> anyhow::Result { + span::debug_assert_current_span_has_tenant_id(); // Get list of remote timelines // download index files for every tenant timeline info!("listing remote timelines"); @@ -3982,6 +3966,8 @@ pub(crate) mod harness { } } + #[cfg(test)] + #[derive(Debug)] enum LoadMode { Local, Remote, @@ -4064,7 +4050,7 @@ pub(crate) mod harness { info_span!("TenantHarness", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()) } - pub async fn load(&self) -> (Arc, RequestContext) { + pub(crate) async fn load(&self) -> (Arc, RequestContext) { let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); ( self.try_load(&ctx) @@ -4074,31 +4060,31 @@ pub(crate) mod harness { ) } - fn remote_empty(&self) -> bool { - let tenant_path = self.conf.tenant_path(&self.tenant_shard_id); - let remote_tenant_dir = self - .remote_fs_dir - .join(tenant_path.strip_prefix(&self.conf.workdir).unwrap()); - if std::fs::metadata(&remote_tenant_dir).is_err() { - return true; - } - - match std::fs::read_dir(remote_tenant_dir) - .unwrap() - .flatten() - .next() - { - Some(entry) => { - tracing::debug!( - "remote_empty: not empty, found file {}", - entry.file_name().to_string_lossy(), - ); - false - } - None => true, - } + /// For tests that specifically want to exercise the local load path, which does + /// not use remote storage. + pub(crate) async fn try_load_local( + &self, + ctx: &RequestContext, + ) -> anyhow::Result> { + self.do_try_load(ctx, LoadMode::Local).await } + /// The 'load' in this function is either a local load or a normal attachment, + pub(crate) async fn try_load(&self, ctx: &RequestContext) -> anyhow::Result> { + // If we have nothing in remote storage, must use load_local instead of attach: attach + // will error out if there are no timelines. + // + // See https://github.com/neondatabase/neon/issues/5456 for how we will eliminate + // this weird state of a Tenant which exists but doesn't have any timelines. + let mode = match self.remote_empty() { + true => LoadMode::Local, + false => LoadMode::Remote, + }; + + self.do_try_load(ctx, mode).await + } + + #[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), ?mode))] async fn do_try_load( &self, ctx: &RequestContext, @@ -4125,20 +4111,13 @@ pub(crate) mod harness { match mode { LoadMode::Local => { - tenant - .load_local(ctx) - .instrument(info_span!("try_load", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug())) - .await?; + tenant.load_local(ctx).await?; } LoadMode::Remote => { let preload = tenant .preload(&self.remote_storage, CancellationToken::new()) - .instrument(info_span!("try_load_preload", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug())) - .await?; - tenant - .attach(Some(preload), SpawnMode::Normal, ctx) - .instrument(info_span!("try_load", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug())) .await?; + tenant.attach(Some(preload), SpawnMode::Normal, ctx).await?; } } @@ -4149,25 +4128,29 @@ pub(crate) mod harness { Ok(tenant) } - /// For tests that specifically want to exercise the local load path, which does - /// not use remote storage. - pub async fn try_load_local(&self, ctx: &RequestContext) -> anyhow::Result> { - self.do_try_load(ctx, LoadMode::Local).await - } + fn remote_empty(&self) -> bool { + let tenant_path = self.conf.tenant_path(&self.tenant_shard_id); + let remote_tenant_dir = self + .remote_fs_dir + .join(tenant_path.strip_prefix(&self.conf.workdir).unwrap()); + if std::fs::metadata(&remote_tenant_dir).is_err() { + return true; + } - /// The 'load' in this function is either a local load or a normal attachment, - pub async fn try_load(&self, ctx: &RequestContext) -> anyhow::Result> { - // If we have nothing in remote storage, must use load_local instead of attach: attach - // will error out if there are no timelines. - // - // See https://github.com/neondatabase/neon/issues/5456 for how we will eliminate - // this weird state of a Tenant which exists but doesn't have any timelines. - let mode = match self.remote_empty() { - true => LoadMode::Local, - false => LoadMode::Remote, - }; - - self.do_try_load(ctx, mode).await + match std::fs::read_dir(remote_tenant_dir) + .unwrap() + .flatten() + .next() + { + Some(entry) => { + tracing::debug!( + "remote_empty: not empty, found file {}", + entry.file_name().to_string_lossy(), + ); + false + } + None => true, + } } pub fn timeline_path(&self, timeline_id: &TimelineId) -> Utf8PathBuf { diff --git a/pageserver/src/tenant/delete.rs b/pageserver/src/tenant/delete.rs index 7c35914b61..0e192b577c 100644 --- a/pageserver/src/tenant/delete.rs +++ b/pageserver/src/tenant/delete.rs @@ -6,7 +6,7 @@ use pageserver_api::{models::TenantState, shard::TenantShardId}; use remote_storage::{GenericRemoteStorage, RemotePath}; use tokio::sync::OwnedMutexGuard; use tokio_util::sync::CancellationToken; -use tracing::{error, instrument, Instrument, Span}; +use tracing::{error, instrument, Instrument}; use utils::{backoff, completion, crashsafe, fs_ext, id::TimelineId}; @@ -496,11 +496,7 @@ impl DeleteTenantFlow { }; Ok(()) } - .instrument({ - let span = tracing::info_span!(parent: None, "delete_tenant", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()); - span.follows_from(Span::current()); - span - }), + .instrument(tracing::info_span!(parent: None, "delete_tenant", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())), ); } diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index 88d7ce61dd..dc499197b0 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::Context; use pageserver_api::{models::TimelineState, shard::TenantShardId}; use tokio::sync::OwnedMutexGuard; -use tracing::{debug, error, info, instrument, warn, Instrument, Span}; +use tracing::{debug, error, info, instrument, warn, Instrument}; use utils::{crashsafe, fs_ext, id::TimelineId}; use crate::{ @@ -541,12 +541,7 @@ impl DeleteTimelineFlow { }; Ok(()) } - .instrument({ - let span = - tracing::info_span!(parent: None, "delete_timeline", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),timeline_id=%timeline_id); - span.follows_from(Span::current()); - span - }), + .instrument(tracing::info_span!(parent: None, "delete_timeline", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),timeline_id=%timeline_id)), ); } From 020e607637fe00ec869fd6eb71dfa732ae501b37 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Mon, 12 Feb 2024 14:04:46 +0100 Subject: [PATCH 48/81] Proxy: copy bidirectional fork (#6720) ## Problem `tokio::io::copy_bidirectional` doesn't close the connection once one of the sides closes it. It's not really suitable for the postgres protocol. ## Summary of changes Fork `copy_bidirectional` and initiate a shutdown for both connections. --------- Co-authored-by: Conrad Ludgate --- proxy/src/proxy.rs | 1 + proxy/src/proxy/copy_bidirectional.rs | 256 ++++++++++++++++++++++++++ proxy/src/proxy/passthrough.rs | 2 +- 3 files changed, 258 insertions(+), 1 deletion(-) create mode 100644 proxy/src/proxy/copy_bidirectional.rs diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 50e22ec72a..77aadb6f28 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -2,6 +2,7 @@ mod tests; pub mod connect_compute; +mod copy_bidirectional; pub mod handshake; pub mod passthrough; pub mod retry; diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs new file mode 100644 index 0000000000..2ecc1151da --- /dev/null +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -0,0 +1,256 @@ +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::future::poll_fn; +use std::io; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +#[derive(Debug)] +enum TransferState { + Running(CopyBuffer), + ShuttingDown(u64), + Done(u64), +} + +fn transfer_one_direction( + cx: &mut Context<'_>, + state: &mut TransferState, + r: &mut A, + w: &mut B, +) -> Poll> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut r = Pin::new(r); + let mut w = Pin::new(w); + loop { + match state { + TransferState::Running(buf) => { + let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + *state = TransferState::ShuttingDown(count); + } + TransferState::ShuttingDown(count) => { + ready!(w.as_mut().poll_shutdown(cx))?; + *state = TransferState::Done(*count); + } + TransferState::Done(count) => return Poll::Ready(Ok(*count)), + } + } +} + +pub(super) async fn copy_bidirectional( + a: &mut A, + b: &mut B, +) -> Result<(u64, u64), std::io::Error> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut a_to_b = TransferState::Running(CopyBuffer::new()); + let mut b_to_a = TransferState::Running(CopyBuffer::new()); + + poll_fn(|cx| { + let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?; + let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?; + + // Early termination checks + if let TransferState::Done(_) = a_to_b { + if let TransferState::Running(buf) = &b_to_a { + // Initiate shutdown + b_to_a = TransferState::ShuttingDown(buf.amt); + b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?; + } + } + if let TransferState::Done(_) = b_to_a { + if let TransferState::Running(buf) = &a_to_b { + // Initiate shutdown + a_to_b = TransferState::ShuttingDown(buf.amt); + a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?; + } + } + + // It is not a problem if ready! returns early ... (comment remains the same) + let a_to_b = ready!(a_to_b_result); + let b_to_a = ready!(b_to_a_result); + + Poll::Ready(Ok((a_to_b, b_to_a))) + }) + .await +} + +#[derive(Debug)] +pub(super) struct CopyBuffer { + read_done: bool, + need_flush: bool, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} +const DEFAULT_BUF_SIZE: usize = 8 * 1024; + +impl CopyBuffer { + pub(super) fn new() -> Self { + Self { + read_done: false, + need_flush: false, + pos: 0, + cap: 0, + amt: 0, + buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(), + } + } + + fn poll_fill_buf( + &mut self, + cx: &mut Context<'_>, + reader: Pin<&mut R>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + { + let me = &mut *self; + let mut buf = ReadBuf::new(&mut me.buf); + buf.set_filled(me.cap); + + let res = reader.poll_read(cx, &mut buf); + if let Poll::Ready(Ok(())) = res { + let filled_len = buf.filled().len(); + me.read_done = me.cap == filled_len; + me.cap = filled_len; + } + res + } + + fn poll_write_buf( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + let me = &mut *self; + match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { + Poll::Pending => { + // Top up the buffer towards full if we can read a bit more + // data - this should improve the chances of a large write + if !me.read_done && me.cap < me.buf.len() { + ready!(me.poll_fill_buf(cx, reader.as_mut()))?; + } + Poll::Pending + } + res => res, + } + } + + pub(super) fn poll_copy( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if self.pos == self.cap && !self.read_done { + self.pos = 0; + self.cap = 0; + + match self.poll_fill_buf(cx, reader.as_mut()) { + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + + return Poll::Pending; + } + } + } + + // If our buffer has some data, let's write it out! + while self.pos < self.cap { + let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + if i == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "write zero byte into writer", + ))); + } else { + self.pos += i; + self.amt += i as u64; + self.need_flush = true; + } + } + + // If pos larger than cap, this loop will never stop. + // In particular, user's wrong poll_write implementation returning + // incorrect written length may lead to thread blocking. + debug_assert!( + self.pos <= self.cap, + "writer returned length larger than input slice" + ); + + // If we've written all the data and we've seen EOF, flush out the + // data and finish the transfer. + if self.pos == self.cap && self.read_done { + ready!(writer.as_mut().poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncWriteExt; + + #[tokio::test] + async fn test_early_termination_a_to_d() { + let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream + let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream + + // Simulate 'a' finishing while there's still data for 'b' + a_mock.write_all(b"hello").await.unwrap(); + a_mock.shutdown().await.unwrap(); + d_mock.write_all(b"Neon Serverless Postgres").await.unwrap(); + + let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap(); + + // Assert correct transferred amounts + let (a_to_d_count, d_to_a_count) = result; + assert_eq!(a_to_d_count, 5); // 'hello' was transferred + assert!(d_to_a_count <= 8); // response only partially transferred or not at all + } + + #[tokio::test] + async fn test_early_termination_d_to_a() { + let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream + let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream + + // Simulate 'a' finishing while there's still data for 'b' + d_mock.write_all(b"hello").await.unwrap(); + d_mock.shutdown().await.unwrap(); + a_mock.write_all(b"Neon Serverless Postgres").await.unwrap(); + + let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap(); + + // Assert correct transferred amounts + let (a_to_d_count, d_to_a_count) = result; + assert_eq!(d_to_a_count, 5); // 'hello' was transferred + assert!(a_to_d_count <= 8); // response only partially transferred or not at all + } +} diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index b7018c6fb5..c98f68d8d1 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -45,7 +45,7 @@ pub async fn proxy_pass( // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?; + let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?; Ok(()) } From 98ec5c5c466158fcb10394303077132efa680690 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 12 Feb 2024 13:14:06 +0000 Subject: [PATCH 49/81] proxy: some more parquet data (#6711) ## Summary of changes add auth_method and database to the parquet logs --- proxy/src/auth/backend.rs | 8 ++-- proxy/src/auth/backend/classic.rs | 8 ++-- proxy/src/auth/backend/hacks.rs | 12 +++-- proxy/src/auth/backend/link.rs | 2 + proxy/src/auth/credentials.rs | 3 ++ proxy/src/auth/flow.rs | 17 ++++++- proxy/src/context.rs | 23 ++++++++- proxy/src/context/parquet.rs | 69 ++++++++++++++++----------- proxy/src/proxy/tests.rs | 2 +- proxy/src/serverless/sql_over_http.rs | 9 +++- 10 files changed, 104 insertions(+), 49 deletions(-) diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index fa2782bee3..c9f21f1cf5 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -194,8 +194,7 @@ async fn auth_quirks( // We now expect to see a very specific payload in the place of password. let (info, unauthenticated_password) = match user_info.try_into() { Err(info) => { - let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer) - .await?; + let res = hacks::password_hack_no_authentication(ctx, info, client).await?; ctx.set_endpoint_id(res.info.endpoint.clone()); tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint)); @@ -276,11 +275,12 @@ async fn authenticate_with_secret( // Perform cleartext auth if we're allowed to do that. // Currently, we use it for websocket connections (latency). if allow_cleartext { - return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await; + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); + return hacks::authenticate_cleartext(ctx, info, client, secret).await; } // Finally, proceed with the main auth flow (SCRAM-based). - classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await + classic::authenticate(ctx, info, client, config, secret).await } impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 745dd75107..e855843bc3 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -4,7 +4,7 @@ use crate::{ compute, config::AuthenticationConfig, console::AuthSecret, - metrics::LatencyTimer, + context::RequestMonitoring, sasl, stream::{PqStream, Stream}, }; @@ -12,10 +12,10 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; pub(super) async fn authenticate( + ctx: &mut RequestMonitoring, creds: ComputeUserInfo, client: &mut PqStream>, config: &'static AuthenticationConfig, - latency_timer: &mut LatencyTimer, secret: AuthSecret, ) -> auth::Result> { let flow = AuthFlow::new(client); @@ -27,13 +27,11 @@ pub(super) async fn authenticate( } AuthSecret::Scram(secret) => { info!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret); + let scram = auth::Scram(&secret, &mut *ctx); let auth_outcome = tokio::time::timeout( config.scram_protocol_timeout, async { - // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); flow.begin(scram).await.map_err(|error| { warn!(?error, "error sending scram acknowledgement"); diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index b6c1a92d3c..9f60b709d4 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -4,7 +4,7 @@ use super::{ use crate::{ auth::{self, AuthFlow}, console::AuthSecret, - metrics::LatencyTimer, + context::RequestMonitoring, sasl, stream::{self, Stream}, }; @@ -16,15 +16,16 @@ use tracing::{info, warn}; /// These properties are benefical for serverless JS workers, so we /// use this mechanism for websocket connections. pub async fn authenticate_cleartext( + ctx: &mut RequestMonitoring, info: ComputeUserInfo, client: &mut stream::PqStream>, - latency_timer: &mut LatencyTimer, secret: AuthSecret, ) -> auth::Result> { warn!("cleartext auth flow override is enabled, proceeding"); + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); + let _paused = ctx.latency_timer.pause(); let auth_outcome = AuthFlow::new(client) .begin(auth::CleartextPassword(secret)) @@ -47,14 +48,15 @@ pub async fn authenticate_cleartext( /// Similar to [`authenticate_cleartext`], but there's a specific password format, /// and passwords are not yet validated (we don't know how to validate them!) pub async fn password_hack_no_authentication( + ctx: &mut RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, - latency_timer: &mut LatencyTimer, ) -> auth::Result>> { warn!("project not specified, resorting to the password hack auth flow"); + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); + let _paused = ctx.latency_timer.pause(); let payload = AuthFlow::new(client) .begin(auth::PasswordHack) diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index c71637dd1a..bf9ebf4c18 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -61,6 +61,8 @@ pub(super) async fn authenticate( link_uri: &reqwest::Url, client: &mut PqStream, ) -> auth::Result { + ctx.set_auth_method(crate::context::AuthMethod::Web); + // registering waiter can fail if we get unlucky with rng. // just try again. let (psql_session_id, waiter) = loop { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index d32609e44c..d318b3be54 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -99,6 +99,9 @@ impl ComputeUserInfoMaybeEndpoint { // record the values if we have them ctx.set_application(params.get("application_name").map(SmolStr::from)); ctx.set_user(user.clone()); + if let Some(dbname) = params.get("database") { + ctx.set_dbname(dbname.into()); + } // Project name might be passed via PG's command-line options. let endpoint_option = params diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index c2783e236c..dce73138c6 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -4,9 +4,11 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload}; use crate::{ config::TlsServerEndPoint, console::AuthSecret, + context::RequestMonitoring, sasl, scram, stream::{PqStream, Stream}, }; +use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use std::io; use tokio::io::{AsyncRead, AsyncWrite}; @@ -23,7 +25,7 @@ pub trait AuthMethod { pub struct Begin; /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. -pub struct Scram<'a>(pub &'a scram::ServerSecret); +pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring); impl AuthMethod for Scram<'_> { #[inline(always)] @@ -138,6 +140,11 @@ impl AuthFlow<'_, S, CleartextPassword> { impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub async fn authenticate(self) -> super::Result> { + let Scram(secret, ctx) = self.state; + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer.pause(); + // Initial client message contains the chosen auth method's name. let msg = self.stream.read_password_message().await?; let sasl = sasl::FirstMessage::parse(&msg) @@ -148,9 +155,15 @@ impl AuthFlow<'_, S, Scram<'_>> { return Err(super::AuthError::bad_auth_method(sasl.method)); } + match sasl.method { + SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus) + } + _ => {} + } info!("client chooses {}", sasl.method); - let secret = self.state.0; let outcome = sasl::SaslStream::new(self.stream, sasl.message) .authenticate(scram::Exchange::new( secret, diff --git a/proxy/src/context.rs b/proxy/src/context.rs index d2bf3f68d3..0cea53ae63 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -11,7 +11,7 @@ use crate::{ console::messages::MetricsAuxInfo, error::ErrorKind, metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND}, - BranchId, EndpointId, ProjectId, RoleName, + BranchId, DbName, EndpointId, ProjectId, RoleName, }; pub mod parquet; @@ -34,9 +34,11 @@ pub struct RequestMonitoring { project: Option, branch: Option, endpoint_id: Option, + dbname: Option, user: Option, application: Option, error_kind: Option, + pub(crate) auth_method: Option, success: bool, // extra @@ -45,6 +47,15 @@ pub struct RequestMonitoring { pub latency_timer: LatencyTimer, } +#[derive(Clone, Debug)] +pub enum AuthMethod { + // aka link aka passwordless + Web, + ScramSha256, + ScramSha256Plus, + Cleartext, +} + impl RequestMonitoring { pub fn new( session_id: Uuid, @@ -62,9 +73,11 @@ impl RequestMonitoring { project: None, branch: None, endpoint_id: None, + dbname: None, user: None, application: None, error_kind: None, + auth_method: None, success: false, sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()), @@ -106,10 +119,18 @@ impl RequestMonitoring { self.application = app.or_else(|| self.application.clone()); } + pub fn set_dbname(&mut self, dbname: DbName) { + self.dbname = Some(dbname); + } + pub fn set_user(&mut self, user: RoleName) { self.user = Some(user); } + pub fn set_auth_method(&mut self, auth_method: AuthMethod) { + self.auth_method = Some(auth_method); + } + pub fn set_error_kind(&mut self, kind: ErrorKind) { ERROR_BY_KIND .with_label_values(&[kind.to_metric_label()]) diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 0fe46915bc..ad22829183 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -84,8 +84,10 @@ struct RequestData { username: Option, application_name: Option, endpoint_id: Option, + database: Option, project: Option, branch: Option, + auth_method: Option<&'static str>, error: Option<&'static str>, /// Success is counted if we form a HTTP response with sql rows inside /// Or if we make it to proxy_pass @@ -104,8 +106,15 @@ impl From for RequestData { username: value.user.as_deref().map(String::from), application_name: value.application.as_deref().map(String::from), endpoint_id: value.endpoint_id.as_deref().map(String::from), + database: value.dbname.as_deref().map(String::from), project: value.project.as_deref().map(String::from), branch: value.branch.as_deref().map(String::from), + auth_method: value.auth_method.as_ref().map(|x| match x { + super::AuthMethod::Web => "web", + super::AuthMethod::ScramSha256 => "scram_sha_256", + super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus", + super::AuthMethod::Cleartext => "cleartext", + }), protocol: value.protocol, region: value.region, error: value.error_kind.as_ref().map(|e| e.to_metric_label()), @@ -431,8 +440,10 @@ mod tests { application_name: Some("test".to_owned()), username: Some(hex::encode(rng.gen::<[u8; 4]>())), endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())), + database: Some(hex::encode(rng.gen::<[u8; 16]>())), project: Some(hex::encode(rng.gen::<[u8; 16]>())), branch: Some(hex::encode(rng.gen::<[u8; 16]>())), + auth_method: None, protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)], region: "us-east-1", error: None, @@ -505,15 +516,15 @@ mod tests { assert_eq!( file_stats, [ - (1087635, 3, 6000), - (1087288, 3, 6000), - (1087444, 3, 6000), - (1087572, 3, 6000), - (1087468, 3, 6000), - (1087500, 3, 6000), - (1087533, 3, 6000), - (1087566, 3, 6000), - (362671, 1, 2000) + (1313727, 3, 6000), + (1313720, 3, 6000), + (1313780, 3, 6000), + (1313737, 3, 6000), + (1313867, 3, 6000), + (1313709, 3, 6000), + (1313501, 3, 6000), + (1313737, 3, 6000), + (438118, 1, 2000) ], ); @@ -543,11 +554,11 @@ mod tests { assert_eq!( file_stats, [ - (1028637, 5, 10000), - (1031969, 5, 10000), - (1019900, 5, 10000), - (1020365, 5, 10000), - (1025010, 5, 10000) + (1219459, 5, 10000), + (1225609, 5, 10000), + (1227403, 5, 10000), + (1226765, 5, 10000), + (1218043, 5, 10000) ], ); @@ -579,11 +590,11 @@ mod tests { assert_eq!( file_stats, [ - (1210770, 6, 12000), - (1211036, 6, 12000), - (1210990, 6, 12000), - (1210861, 6, 12000), - (202073, 1, 2000) + (1205106, 5, 10000), + (1204837, 5, 10000), + (1205130, 5, 10000), + (1205118, 5, 10000), + (1205373, 5, 10000) ], ); @@ -608,15 +619,15 @@ mod tests { assert_eq!( file_stats, [ - (1087635, 3, 6000), - (1087288, 3, 6000), - (1087444, 3, 6000), - (1087572, 3, 6000), - (1087468, 3, 6000), - (1087500, 3, 6000), - (1087533, 3, 6000), - (1087566, 3, 6000), - (362671, 1, 2000) + (1313727, 3, 6000), + (1313720, 3, 6000), + (1313780, 3, 6000), + (1313737, 3, 6000), + (1313867, 3, 6000), + (1313709, 3, 6000), + (1313501, 3, 6000), + (1313737, 3, 6000), + (438118, 1, 2000) ], ); @@ -653,7 +664,7 @@ mod tests { // files are smaller than the size threshold, but they took too long to fill so were flushed early assert_eq!( file_stats, - [(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)], + [(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)], ); tmpdir.close().unwrap(); diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 3e961afb41..5bb43c0375 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -144,7 +144,7 @@ impl TestAuth for Scram { stream: &mut PqStream>, ) -> anyhow::Result<()> { let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0)) + .begin(auth::Scram(&self.0, &mut RequestMonitoring::test())) .await? .authenticate() .await?; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 54424360c4..e9f868d51e 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -36,6 +36,7 @@ use crate::error::ReportableError; use crate::metrics::HTTP_CONTENT_LENGTH; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; +use crate::DbName; use crate::RoleName; use super::backend::PoolingBackend; @@ -117,6 +118,9 @@ fn get_conn_info( headers: &HeaderMap, tls: &TlsConfig, ) -> Result { + // HTTP only uses cleartext (for now and likely always) + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); + let connection_string = headers .get("Neon-Connection-String") .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))? @@ -134,7 +138,8 @@ fn get_conn_info( .path_segments() .ok_or(ConnInfoError::MissingDbName)?; - let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?; + let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into(); + ctx.set_dbname(dbname.clone()); let username = RoleName::from(urlencoding::decode(connection_url.username())?); if username.is_empty() { @@ -174,7 +179,7 @@ fn get_conn_info( Ok(ConnInfo { user_info, - dbname: dbname.into(), + dbname, password: match password { std::borrow::Cow::Borrowed(b) => b.into(), std::borrow::Cow::Owned(b) => b.into(), From 242dd8398c8d6728270c8d8c2a0b45dae480cb97 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Mon, 12 Feb 2024 15:58:55 +0100 Subject: [PATCH 50/81] refactor(blob_io): use owned buffers (#6660) This PR refactors the `blob_io` code away from using slices towards taking owned buffers and return them after use. Using owned buffers will eventually allow us to use io_uring for writes. part of https://github.com/neondatabase/neon/issues/6663 Depends on https://github.com/neondatabase/tokio-epoll-uring/pull/43 The high level scheme is as follows: - call writing functions with the `BoundedBuf` - return the underlying `BoundedBuf::Buf` for potential reuse in the caller NB: Invoking `BoundedBuf::slice(..)` will return a slice that _includes the uninitialized portion of `BoundedBuf`_. I.e., the portion between `bytes_init()` and `bytes_total()`. It's a safe API that actually permits access to uninitialized memory. Not great. Another wrinkle is that it panics if the range has length 0. However, I don't want to switch away from the `BoundedBuf` API, since it's what tokio-uring uses. We can always weed this out later by replacing `BoundedBuf` with our own type. Created an issue so we don't forget: https://github.com/neondatabase/tokio-epoll-uring/issues/46 --- Cargo.lock | 5 +- pageserver/src/tenant/blob_io.rs | 121 +++++++++++++----- .../src/tenant/storage_layer/delta_layer.rs | 26 ++-- .../src/tenant/storage_layer/image_layer.rs | 8 +- .../tenant/storage_layer/inmemory_layer.rs | 8 +- pageserver/src/tenant/timeline.rs | 2 +- 6 files changed, 115 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 83afdaf66f..520163e41b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5740,7 +5740,7 @@ dependencies = [ [[package]] name = "tokio-epoll-uring" version = "0.1.0" -source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc" +source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#868d2c42b5d54ca82fead6e8f2f233b69a540d3e" dependencies = [ "futures", "nix 0.26.4", @@ -6265,8 +6265,9 @@ dependencies = [ [[package]] name = "uring-common" version = "0.1.0" -source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc" +source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#868d2c42b5d54ca82fead6e8f2f233b69a540d3e" dependencies = [ + "bytes", "io-uring", "libc", ] diff --git a/pageserver/src/tenant/blob_io.rs b/pageserver/src/tenant/blob_io.rs index 6de2e95055..e2ff12665a 100644 --- a/pageserver/src/tenant/blob_io.rs +++ b/pageserver/src/tenant/blob_io.rs @@ -11,6 +11,9 @@ //! len < 128: 0XXXXXXX //! len >= 128: 1XXXXXXX XXXXXXXX XXXXXXXX XXXXXXXX //! +use bytes::{BufMut, BytesMut}; +use tokio_epoll_uring::{BoundedBuf, Slice}; + use crate::context::RequestContext; use crate::page_cache::PAGE_SZ; use crate::tenant::block_io::BlockCursor; @@ -100,6 +103,8 @@ pub struct BlobWriter { offset: u64, /// A buffer to save on write calls, only used if BUFFERED=true buf: Vec, + /// We do tiny writes for the length headers; they need to be in an owned buffer; + io_buf: Option, } impl BlobWriter { @@ -108,6 +113,7 @@ impl BlobWriter { inner, offset: start_offset, buf: Vec::with_capacity(Self::CAPACITY), + io_buf: Some(BytesMut::new()), } } @@ -117,14 +123,28 @@ impl BlobWriter { const CAPACITY: usize = if BUFFERED { PAGE_SZ } else { 0 }; - #[inline(always)] /// Writes the given buffer directly to the underlying `VirtualFile`. /// You need to make sure that the internal buffer is empty, otherwise /// data will be written in wrong order. - async fn write_all_unbuffered(&mut self, src_buf: &[u8]) -> Result<(), Error> { - self.inner.write_all(src_buf).await?; - self.offset += src_buf.len() as u64; - Ok(()) + #[inline(always)] + async fn write_all_unbuffered( + &mut self, + src_buf: B, + ) -> (B::Buf, Result<(), Error>) { + let src_buf_len = src_buf.bytes_init(); + let (src_buf, res) = if src_buf_len > 0 { + let src_buf = src_buf.slice(0..src_buf_len); + let res = self.inner.write_all(&src_buf).await; + let src_buf = Slice::into_inner(src_buf); + (src_buf, res) + } else { + let res = self.inner.write_all(&[]).await; + (Slice::into_inner(src_buf.slice_full()), res) + }; + if let Ok(()) = &res { + self.offset += src_buf_len as u64; + } + (src_buf, res) } #[inline(always)] @@ -146,62 +166,91 @@ impl BlobWriter { } /// Internal, possibly buffered, write function - async fn write_all(&mut self, mut src_buf: &[u8]) -> Result<(), Error> { + async fn write_all(&mut self, src_buf: B) -> (B::Buf, Result<(), Error>) { if !BUFFERED { assert!(self.buf.is_empty()); - self.write_all_unbuffered(src_buf).await?; - return Ok(()); + return self.write_all_unbuffered(src_buf).await; } let remaining = Self::CAPACITY - self.buf.len(); + let src_buf_len = src_buf.bytes_init(); + if src_buf_len == 0 { + return (Slice::into_inner(src_buf.slice_full()), Ok(())); + } + let mut src_buf = src_buf.slice(0..src_buf_len); // First try to copy as much as we can into the buffer if remaining > 0 { - let copied = self.write_into_buffer(src_buf); - src_buf = &src_buf[copied..]; + let copied = self.write_into_buffer(&src_buf); + src_buf = src_buf.slice(copied..); } // Then, if the buffer is full, flush it out if self.buf.len() == Self::CAPACITY { - self.flush_buffer().await?; + if let Err(e) = self.flush_buffer().await { + return (Slice::into_inner(src_buf), Err(e)); + } } // Finally, write the tail of src_buf: // If it wholly fits into the buffer without // completely filling it, then put it there. // If not, write it out directly. - if !src_buf.is_empty() { + let src_buf = if !src_buf.is_empty() { assert_eq!(self.buf.len(), 0); if src_buf.len() < Self::CAPACITY { - let copied = self.write_into_buffer(src_buf); + let copied = self.write_into_buffer(&src_buf); // We just verified above that src_buf fits into our internal buffer. assert_eq!(copied, src_buf.len()); + Slice::into_inner(src_buf) } else { - self.write_all_unbuffered(src_buf).await?; + let (src_buf, res) = self.write_all_unbuffered(src_buf).await; + if let Err(e) = res { + return (src_buf, Err(e)); + } + src_buf } - } - Ok(()) + } else { + Slice::into_inner(src_buf) + }; + (src_buf, Ok(())) } /// Write a blob of data. Returns the offset that it was written to, /// which can be used to retrieve the data later. - pub async fn write_blob(&mut self, srcbuf: &[u8]) -> Result { + pub async fn write_blob(&mut self, srcbuf: B) -> (B::Buf, Result) { let offset = self.offset; - if srcbuf.len() < 128 { - // Short blob. Write a 1-byte length header - let len_buf = srcbuf.len() as u8; - self.write_all(&[len_buf]).await?; - } else { - // Write a 4-byte length header - if srcbuf.len() > 0x7fff_ffff { - return Err(Error::new( - ErrorKind::Other, - format!("blob too large ({} bytes)", srcbuf.len()), - )); + let len = srcbuf.bytes_init(); + + let mut io_buf = self.io_buf.take().expect("we always put it back below"); + io_buf.clear(); + let (io_buf, hdr_res) = async { + if len < 128 { + // Short blob. Write a 1-byte length header + io_buf.put_u8(len as u8); + self.write_all(io_buf).await + } else { + // Write a 4-byte length header + if len > 0x7fff_ffff { + return ( + io_buf, + Err(Error::new( + ErrorKind::Other, + format!("blob too large ({} bytes)", len), + )), + ); + } + let mut len_buf = (len as u32).to_be_bytes(); + len_buf[0] |= 0x80; + io_buf.extend_from_slice(&len_buf[..]); + self.write_all(io_buf).await } - let mut len_buf = ((srcbuf.len()) as u32).to_be_bytes(); - len_buf[0] |= 0x80; - self.write_all(&len_buf).await?; } - self.write_all(srcbuf).await?; - Ok(offset) + .await; + self.io_buf = Some(io_buf); + match hdr_res { + Ok(_) => (), + Err(e) => return (Slice::into_inner(srcbuf.slice(..)), Err(e)), + } + let (srcbuf, res) = self.write_all(srcbuf).await; + (srcbuf, res.map(|_| offset)) } } @@ -248,12 +297,14 @@ mod tests { let file = VirtualFile::create(pathbuf.as_path()).await?; let mut wtr = BlobWriter::::new(file, 0); for blob in blobs.iter() { - let offs = wtr.write_blob(blob).await?; + let (_, res) = wtr.write_blob(blob.clone()).await; + let offs = res?; offsets.push(offs); } // Write out one page worth of zeros so that we can // read again with read_blk - let offs = wtr.write_blob(&vec![0; PAGE_SZ]).await?; + let (_, res) = wtr.write_blob(vec![0; PAGE_SZ]).await; + let offs = res?; println!("Writing final blob at offs={offs}"); wtr.flush_buffer().await?; } diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index 2a51884c0b..7a5dc7a59f 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -416,27 +416,31 @@ impl DeltaLayerWriterInner { /// The values must be appended in key, lsn order. /// async fn put_value(&mut self, key: Key, lsn: Lsn, val: Value) -> anyhow::Result<()> { - self.put_value_bytes(key, lsn, &Value::ser(&val)?, val.will_init()) - .await + let (_, res) = self + .put_value_bytes(key, lsn, Value::ser(&val)?, val.will_init()) + .await; + res } async fn put_value_bytes( &mut self, key: Key, lsn: Lsn, - val: &[u8], + val: Vec, will_init: bool, - ) -> anyhow::Result<()> { + ) -> (Vec, anyhow::Result<()>) { assert!(self.lsn_range.start <= lsn); - - let off = self.blob_writer.write_blob(val).await?; + let (val, res) = self.blob_writer.write_blob(val).await; + let off = match res { + Ok(off) => off, + Err(e) => return (val, Err(anyhow::anyhow!(e))), + }; let blob_ref = BlobRef::new(off, will_init); let delta_key = DeltaKey::from_key_lsn(&key, lsn); - self.tree.append(&delta_key.0, blob_ref.0)?; - - Ok(()) + let res = self.tree.append(&delta_key.0, blob_ref.0); + (val, res.map_err(|e| anyhow::anyhow!(e))) } fn size(&self) -> u64 { @@ -587,9 +591,9 @@ impl DeltaLayerWriter { &mut self, key: Key, lsn: Lsn, - val: &[u8], + val: Vec, will_init: bool, - ) -> anyhow::Result<()> { + ) -> (Vec, anyhow::Result<()>) { self.inner .as_mut() .unwrap() diff --git a/pageserver/src/tenant/storage_layer/image_layer.rs b/pageserver/src/tenant/storage_layer/image_layer.rs index c62e6aed51..1ad195032d 100644 --- a/pageserver/src/tenant/storage_layer/image_layer.rs +++ b/pageserver/src/tenant/storage_layer/image_layer.rs @@ -528,9 +528,11 @@ impl ImageLayerWriterInner { /// /// The page versions must be appended in blknum order. /// - async fn put_image(&mut self, key: Key, img: &[u8]) -> anyhow::Result<()> { + async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> { ensure!(self.key_range.contains(&key)); - let off = self.blob_writer.write_blob(img).await?; + let (_img, res) = self.blob_writer.write_blob(img).await; + // TODO: re-use the buffer for `img` further upstack + let off = res?; let mut keybuf: [u8; KEY_SIZE] = [0u8; KEY_SIZE]; key.write_to_byte_slice(&mut keybuf); @@ -659,7 +661,7 @@ impl ImageLayerWriter { /// /// The page versions must be appended in blknum order. /// - pub async fn put_image(&mut self, key: Key, img: &[u8]) -> anyhow::Result<()> { + pub async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> { self.inner.as_mut().unwrap().put_image(key, img).await } diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer.rs b/pageserver/src/tenant/storage_layer/inmemory_layer.rs index 7c9103eea8..c597b15533 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer.rs @@ -383,9 +383,11 @@ impl InMemoryLayer { for (lsn, pos) in vec_map.as_slice() { cursor.read_blob_into_buf(*pos, &mut buf, &ctx).await?; let will_init = Value::des(&buf)?.will_init(); - delta_layer_writer - .put_value_bytes(key, *lsn, &buf, will_init) - .await?; + let res; + (buf, res) = delta_layer_writer + .put_value_bytes(key, *lsn, buf, will_init) + .await; + res?; } } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index f96679ca69..74676277d5 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -3328,7 +3328,7 @@ impl Timeline { } }; - image_layer_writer.put_image(img_key, &img).await?; + image_layer_writer.put_image(img_key, img).await?; } } From 789a71c4ee6722f26ae4929a10e1316568e2006f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 12 Feb 2024 15:03:45 +0000 Subject: [PATCH 51/81] proxy: add more http logging (#6726) ## Problem hard to see where time is taken during HTTP flow. ## Summary of changes add a lot more for query state. add a conn_id field to the sql-over-http span --- proxy/src/metrics.rs | 5 ++-- proxy/src/serverless/backend.rs | 8 +++---- proxy/src/serverless/conn_pool.rs | 22 +++++------------- proxy/src/serverless/sql_over_http.rs | 33 +++++++++++++++++++++++---- 4 files changed, 41 insertions(+), 27 deletions(-) diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index ccf89f9b05..f7f162a075 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -200,8 +200,9 @@ impl LatencyTimer { pub fn success(&mut self) { // stop the stopwatch and record the time that we have accumulated - let start = self.start.take().expect("latency timer should be started"); - self.accumulated += start.elapsed(); + if let Some(start) = self.start.take() { + self.accumulated += start.elapsed(); + } // success self.outcome = "success"; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 8285da68d7..156002006d 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,7 +1,7 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use tracing::info; +use tracing::{field::display, info}; use crate::{ auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, @@ -15,7 +15,7 @@ use crate::{ proxy::connect_compute::ConnectMechanism, }; -use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME}; +use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; pub struct PoolingBackend { pub pool: Arc>, @@ -81,8 +81,8 @@ impl PoolingBackend { return Ok(client); } let conn_id = uuid::Uuid::new_v4(); - info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - ctx.set_application(Some(APP_NAME)); + tracing::Span::current().record("conn_id", display(conn_id)); + info!("pool: opening a new connection '{conn_info}'"); let backend = self .config .auth_backend diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index f4e5b145c5..53e7c1c2ee 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -4,7 +4,6 @@ use metrics::IntCounterPairGuard; use parking_lot::RwLock; use rand::Rng; use smallvec::SmallVec; -use smol_str::SmolStr; use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; use std::{ fmt, @@ -31,8 +30,6 @@ use tracing::{info, info_span, Instrument}; use super::backend::HttpConnError; -pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http"); - #[derive(Debug, Clone)] pub struct ConnInfo { pub user_info: ComputeUserInfo, @@ -379,12 +376,13 @@ impl GlobalConnPool { info!("pool: cached connection '{conn_info}' is closed, opening a new one"); return Ok(None); } else { - info!("pool: reusing connection '{conn_info}'"); - client.session.send(ctx.session_id)?; + tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); tracing::Span::current().record( "pid", &tracing::field::display(client.inner.get_process_id()), ); + info!("pool: reusing connection '{conn_info}'"); + client.session.send(ctx.session_id)?; ctx.latency_timer.pool_hit(); ctx.latency_timer.success(); return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); @@ -577,7 +575,6 @@ pub struct Client { } pub struct Discard<'a, C: ClientInnerExt> { - conn_id: uuid::Uuid, conn_info: &'a ConnInfo, pool: &'a mut Weak>>, } @@ -603,14 +600,7 @@ impl Client { span: _, } = self; let inner = inner.as_mut().expect("client inner should not be removed"); - ( - &mut inner.inner, - Discard { - pool, - conn_info, - conn_id: inner.conn_id, - }, - ) + (&mut inner.inner, Discard { pool, conn_info }) } pub fn check_idle(&mut self, status: ReadyForQueryStatus) { @@ -625,13 +615,13 @@ impl Discard<'_, C> { pub fn check_idle(&mut self, status: ReadyForQueryStatus) { let conn_info = &self.conn_info; if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle") + info!("pool: throwing away connection '{conn_info}' because connection is not idle") } } pub fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { - info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state") + info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state") } } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index e9f868d51e..ecb72abe73 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -36,6 +36,7 @@ use crate::error::ReportableError; use crate::metrics::HTTP_CONTENT_LENGTH; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; +use crate::serverless::backend::HttpConnError; use crate::DbName; use crate::RoleName; @@ -305,7 +306,14 @@ pub async fn handle( Ok(response) } -#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)] +#[instrument( + name = "sql-over-http", + skip_all, + fields( + pid = tracing::field::Empty, + conn_id = tracing::field::Empty + ) +)] async fn handle_inner( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, @@ -359,12 +367,10 @@ async fn handle_inner( let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE); let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE); - let paused = ctx.latency_timer.pause(); let request_content_length = match request.body().size_hint().upper() { Some(v) => v, None => MAX_REQUEST_SIZE + 1, }; - drop(paused); info!(request_content_length, "request size in bytes"); HTTP_CONTENT_LENGTH.observe(request_content_length as f64); @@ -380,15 +386,20 @@ async fn handle_inner( let body = hyper::body::to_bytes(request.into_body()) .await .map_err(anyhow::Error::from)?; + info!(length = body.len(), "request payload read"); let payload: Payload = serde_json::from_slice(&body)?; Ok::(payload) // Adjust error type accordingly }; let authenticate_and_connect = async { let keys = backend.authenticate(ctx, &conn_info).await?; - backend + let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) - .await + .await?; + // not strictly necessary to mark success here, + // but it's just insurance for if we forget it somewhere else + ctx.latency_timer.success(); + Ok::<_, HttpConnError>(client) }; // Run both operations in parallel @@ -420,6 +431,7 @@ async fn handle_inner( results } Payload::Batch(statements) => { + info!("starting transaction"); let (inner, mut discard) = client.inner(); let mut builder = inner.build_transaction(); if let Some(isolation_level) = txn_isolation_level { @@ -449,6 +461,7 @@ async fn handle_inner( .await { Ok(results) => { + info!("commit"); let status = transaction.commit().await.map_err(|e| { // if we cannot commit - for now don't return connection to pool // TODO: get a query status from the error @@ -459,6 +472,7 @@ async fn handle_inner( results } Err(err) => { + info!("rollback"); let status = transaction.rollback().await.map_err(|e| { // if we cannot rollback - for now don't return connection to pool // TODO: get a query status from the error @@ -533,8 +547,10 @@ async fn query_to_json( raw_output: bool, default_array_mode: bool, ) -> anyhow::Result<(ReadyForQueryStatus, Value)> { + info!("executing query"); let query_params = data.params; let row_stream = client.query_raw_txt(&data.query, query_params).await?; + info!("finished executing query"); // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. Also check that the response is not too @@ -569,6 +585,13 @@ async fn query_to_json( } .and_then(|s| s.parse::().ok()); + info!( + rows = rows.len(), + ?ready, + command_tag, + "finished reading rows" + ); + let mut fields = vec![]; let mut columns = vec![]; From 7ea593db2292324e136d3325cd96217c9d652395 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Mon, 12 Feb 2024 17:13:35 +0200 Subject: [PATCH 52/81] refactor(LayerManager): resident layers query (#6634) Refactor out layer accesses so that we can have easy access to resident layers, which are needed for number of cases instead of layers for eviction. Simplifies the heatmap building by only using Layers, not RemoteTimelineClient. Cc: #5331 --- .../src/tenant/remote_timeline_client.rs | 17 ---- pageserver/src/tenant/storage_layer.rs | 8 +- pageserver/src/tenant/storage_layer/layer.rs | 4 - pageserver/src/tenant/timeline.rs | 97 ++++++------------- .../src/tenant/timeline/eviction_task.rs | 7 +- .../src/tenant/timeline/layer_manager.rs | 45 ++++++--- 6 files changed, 74 insertions(+), 104 deletions(-) diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index e17dea01a8..483f53d5c8 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -1700,23 +1700,6 @@ impl RemoteTimelineClient { } } } - - pub(crate) fn get_layers_metadata( - &self, - layers: Vec, - ) -> anyhow::Result>> { - let q = self.upload_queue.lock().unwrap(); - let q = match &*q { - UploadQueue::Stopped(_) | UploadQueue::Uninitialized => { - anyhow::bail!("queue is in state {}", q.as_str()) - } - UploadQueue::Initialized(inner) => inner, - }; - - let decorated = layers.into_iter().map(|l| q.latest_files.get(&l).cloned()); - - Ok(decorated.collect()) - } } pub fn remote_timelines_path(tenant_shard_id: &TenantShardId) -> RemotePath { diff --git a/pageserver/src/tenant/storage_layer.rs b/pageserver/src/tenant/storage_layer.rs index 6e9a4932d8..2d92baccbe 100644 --- a/pageserver/src/tenant/storage_layer.rs +++ b/pageserver/src/tenant/storage_layer.rs @@ -257,6 +257,12 @@ impl LayerAccessStats { ret } + /// Get the latest access timestamp, falling back to latest residence event, further falling + /// back to `SystemTime::now` for a usable timestamp for eviction. + pub(crate) fn latest_activity_or_now(&self) -> SystemTime { + self.latest_activity().unwrap_or_else(SystemTime::now) + } + /// Get the latest access timestamp, falling back to latest residence event. /// /// This function can only return `None` if there has not yet been a call to the @@ -271,7 +277,7 @@ impl LayerAccessStats { /// that that type can only be produced by inserting into the layer map. /// /// [`record_residence_event`]: Self::record_residence_event - pub(crate) fn latest_activity(&self) -> Option { + fn latest_activity(&self) -> Option { let locked = self.0.lock().unwrap(); let inner = &locked.for_eviction_policy; match inner.last_accesses.recent() { diff --git a/pageserver/src/tenant/storage_layer/layer.rs b/pageserver/src/tenant/storage_layer/layer.rs index dd9de99477..bfcc031863 100644 --- a/pageserver/src/tenant/storage_layer/layer.rs +++ b/pageserver/src/tenant/storage_layer/layer.rs @@ -1413,10 +1413,6 @@ impl ResidentLayer { &self.owner.0.path } - pub(crate) fn access_stats(&self) -> &LayerAccessStats { - self.owner.access_stats() - } - pub(crate) fn metadata(&self) -> LayerFileMetadata { self.owner.metadata() } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 74676277d5..625be7a644 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -12,6 +12,7 @@ use bytes::Bytes; use camino::{Utf8Path, Utf8PathBuf}; use enumset::EnumSet; use fail::fail_point; +use futures::stream::StreamExt; use itertools::Itertools; use pageserver_api::{ keyspace::{key_range_size, KeySpaceAccum}, @@ -105,7 +106,7 @@ use self::logical_size::LogicalSize; use self::walreceiver::{WalReceiver, WalReceiverConf}; use super::config::TenantConf; -use super::remote_timeline_client::index::{IndexLayerMetadata, IndexPart}; +use super::remote_timeline_client::index::IndexPart; use super::remote_timeline_client::RemoteTimelineClient; use super::secondary::heatmap::{HeatMapLayer, HeatMapTimeline}; use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf}; @@ -1458,7 +1459,7 @@ impl Timeline { generation, shard_identity, pg_version, - layers: Arc::new(tokio::sync::RwLock::new(LayerManager::create())), + layers: Default::default(), wanted_image_layers: Mutex::new(None), walredo_mgr, @@ -2283,45 +2284,28 @@ impl Timeline { /// should treat this as a cue to simply skip doing any heatmap uploading /// for this timeline. pub(crate) async fn generate_heatmap(&self) -> Option { - let eviction_info = self.get_local_layers_for_disk_usage_eviction().await; + // no point in heatmaps without remote client + let _remote_client = self.remote_client.as_ref()?; - let remote_client = match &self.remote_client { - Some(c) => c, - None => return None, - }; + if !self.is_active() { + return None; + } - let layer_file_names = eviction_info - .resident_layers - .iter() - .map(|l| l.layer.get_name()) - .collect::>(); + let guard = self.layers.read().await; - let decorated = match remote_client.get_layers_metadata(layer_file_names) { - Ok(d) => d, - Err(_) => { - // Getting metadata only fails on Timeline in bad state. - return None; - } - }; + let resident = guard.resident_layers().map(|layer| { + let last_activity_ts = layer.access_stats().latest_activity_or_now(); - let heatmap_layers = std::iter::zip( - eviction_info.resident_layers.into_iter(), - decorated.into_iter(), - ) - .filter_map(|(layer, remote_info)| { - remote_info.map(|remote_info| { - HeatMapLayer::new( - layer.layer.get_name(), - IndexLayerMetadata::from(remote_info), - layer.last_activity_ts, - ) - }) + HeatMapLayer::new( + layer.layer_desc().filename(), + layer.metadata().into(), + last_activity_ts, + ) }); - Some(HeatMapTimeline::new( - self.timeline_id, - heatmap_layers.collect(), - )) + let layers = resident.collect().await; + + Some(HeatMapTimeline::new(self.timeline_id, layers)) } } @@ -4662,41 +4646,24 @@ impl Timeline { /// Returns non-remote layers for eviction. pub(crate) async fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo { let guard = self.layers.read().await; - let layers = guard.layer_map(); - let mut max_layer_size: Option = None; - let mut resident_layers = Vec::new(); - for l in layers.iter_historic_layers() { - let file_size = l.file_size(); - max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size))); + let resident_layers = guard + .resident_layers() + .map(|layer| { + let file_size = layer.layer_desc().file_size; + max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size))); - let l = guard.get_from_desc(&l); + let last_activity_ts = layer.access_stats().latest_activity_or_now(); - let l = match l.keep_resident().await { - Ok(Some(l)) => l, - Ok(None) => continue, - Err(e) => { - // these should not happen, but we cannot make them statically impossible right - // now. - tracing::warn!(layer=%l, "failed to keep the layer resident: {e:#}"); - continue; + EvictionCandidate { + layer: layer.into(), + last_activity_ts, + relative_last_activity: finite_f32::FiniteF32::ZERO, } - }; - - let last_activity_ts = l.access_stats().latest_activity().unwrap_or_else(|| { - // We only use this fallback if there's an implementation error. - // `latest_activity` already does rate-limited warn!() log. - debug!(layer=%l, "last_activity returns None, using SystemTime::now"); - SystemTime::now() - }); - - resident_layers.push(EvictionCandidate { - layer: l.drop_eviction_guard().into(), - last_activity_ts, - relative_last_activity: finite_f32::FiniteF32::ZERO, - }); - } + }) + .collect() + .await; DiskUsageEvictionInfo { max_layer_size, diff --git a/pageserver/src/tenant/timeline/eviction_task.rs b/pageserver/src/tenant/timeline/eviction_task.rs index 9bdd52e809..d87f78e35f 100644 --- a/pageserver/src/tenant/timeline/eviction_task.rs +++ b/pageserver/src/tenant/timeline/eviction_task.rs @@ -239,12 +239,7 @@ impl Timeline { } }; - let last_activity_ts = hist_layer.access_stats().latest_activity().unwrap_or_else(|| { - // We only use this fallback if there's an implementation error. - // `latest_activity` already does rate-limited warn!() log. - debug!(layer=%hist_layer, "last_activity returns None, using SystemTime::now"); - SystemTime::now() - }); + let last_activity_ts = hist_layer.access_stats().latest_activity_or_now(); let no_activity_for = match now.duration_since(last_activity_ts) { Ok(d) => d, diff --git a/pageserver/src/tenant/timeline/layer_manager.rs b/pageserver/src/tenant/timeline/layer_manager.rs index e38f5be209..ebcdcfdb4d 100644 --- a/pageserver/src/tenant/timeline/layer_manager.rs +++ b/pageserver/src/tenant/timeline/layer_manager.rs @@ -1,4 +1,5 @@ use anyhow::{bail, ensure, Context, Result}; +use futures::StreamExt; use pageserver_api::shard::TenantShardId; use std::{collections::HashMap, sync::Arc}; use tracing::trace; @@ -20,19 +21,13 @@ use crate::{ }; /// Provides semantic APIs to manipulate the layer map. +#[derive(Default)] pub(crate) struct LayerManager { layer_map: LayerMap, layer_fmgr: LayerFileManager, } impl LayerManager { - pub(crate) fn create() -> Self { - Self { - layer_map: LayerMap::default(), - layer_fmgr: LayerFileManager::new(), - } - } - pub(crate) fn get_from_desc(&self, desc: &PersistentLayerDesc) -> Layer { self.layer_fmgr.get_from_desc(desc) } @@ -246,6 +241,32 @@ impl LayerManager { layer.delete_on_drop(); } + pub(crate) fn resident_layers(&self) -> impl futures::stream::Stream + '_ { + // for small layer maps, we most likely have all resident, but for larger more are likely + // to be evicted assuming lots of layers correlated with longer lifespan. + + let layers = self + .layer_map() + .iter_historic_layers() + .map(|desc| self.get_from_desc(&desc)); + + let layers = futures::stream::iter(layers); + + layers.filter_map(|layer| async move { + // TODO(#6028): this query does not really need to see the ResidentLayer + match layer.keep_resident().await { + Ok(Some(layer)) => Some(layer.drop_eviction_guard()), + Ok(None) => None, + Err(e) => { + // these should not happen, but we cannot make them statically impossible right + // now. + tracing::warn!(%layer, "failed to keep the layer resident: {e:#}"); + None + } + } + }) + } + pub(crate) fn contains(&self, layer: &Layer) -> bool { self.layer_fmgr.contains(layer) } @@ -253,6 +274,12 @@ impl LayerManager { pub(crate) struct LayerFileManager(HashMap); +impl Default for LayerFileManager { + fn default() -> Self { + Self(HashMap::default()) + } +} + impl LayerFileManager { fn get_from_desc(&self, desc: &PersistentLayerDesc) -> T { // The assumption for the `expect()` is that all code maintains the following invariant: @@ -275,10 +302,6 @@ impl LayerFileManager { self.0.contains_key(&layer.layer_desc().key()) } - pub(crate) fn new() -> Self { - Self(HashMap::new()) - } - pub(crate) fn remove(&mut self, layer: &T) { let present = self.0.remove(&layer.layer_desc().key()); if present.is_none() && cfg!(debug_assertions) { From 8b8ff88e4b0e1a1b1c14f0edbe50e0c6236afa93 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Mon, 12 Feb 2024 16:25:33 +0100 Subject: [PATCH 53/81] GH actions: label to disable CI runs completely (#6677) I don't want my very-early-draft PRs to trigger any CI runs. So, add a label `run-no-ci`, and piggy-back on the `check-permissions` job. --- .github/workflows/actionlint.yml | 1 + .github/workflows/build_and_test.yml | 2 +- .github/workflows/neon_extra_builds.yml | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml index 584828c1d0..c290ff88e2 100644 --- a/.github/workflows/actionlint.yml +++ b/.github/workflows/actionlint.yml @@ -17,6 +17,7 @@ concurrency: jobs: actionlint: + if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }} runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 078916e1ea..6e4020a1b8 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -26,8 +26,8 @@ env: jobs: check-permissions: + if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }} runs-on: ubuntu-latest - steps: - name: Disallow PRs from forks if: | diff --git a/.github/workflows/neon_extra_builds.yml b/.github/workflows/neon_extra_builds.yml index c90ef60074..ff2a3a040a 100644 --- a/.github/workflows/neon_extra_builds.yml +++ b/.github/workflows/neon_extra_builds.yml @@ -117,6 +117,7 @@ jobs: check-linux-arm-build: timeout-minutes: 90 + if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }} runs-on: [ self-hosted, dev, arm64 ] env: @@ -237,6 +238,7 @@ jobs: check-codestyle-rust-arm: timeout-minutes: 90 + if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }} runs-on: [ self-hosted, dev, arm64 ] container: From a1f37cba1c790e5b89958fb7df13cde39429add8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Mon, 12 Feb 2024 19:15:21 +0100 Subject: [PATCH 54/81] Add test that runs the S3 scrubber (#6641) In #6079 it was found that there is no test that executes the scrubber. We now add such a test, which does the following things: * create a tenant, write some data * run the scrubber * remove the tenant * run the scrubber again Each time, the scrubber runs the scan-metadata command. Before #6079 we would have errored, now we don't. Fixes #6080 --- test_runner/fixtures/neon_fixtures.py | 8 ++-- .../regress/test_pageserver_generations.py | 4 +- .../regress/test_pageserver_secondary.py | 2 +- test_runner/regress/test_tenant_delete.py | 40 ++++++++++++++++++- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index faa8effe10..26f2b999a6 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -899,7 +899,7 @@ class NeonEnvBuilder: if self.scrub_on_exit: try: - S3Scrubber(self.test_output_dir, self).scan_metadata() + S3Scrubber(self).scan_metadata() except Exception as e: log.error(f"Error during remote storage scrub: {e}") cleanup_error = e @@ -3659,9 +3659,9 @@ class SafekeeperHttpClient(requests.Session): class S3Scrubber: - def __init__(self, log_dir: Path, env: NeonEnvBuilder): + def __init__(self, env: NeonEnvBuilder, log_dir: Optional[Path] = None): self.env = env - self.log_dir = log_dir + self.log_dir = log_dir or env.test_output_dir def scrubber_cli(self, args: list[str], timeout) -> str: assert isinstance(self.env.pageserver_remote_storage, S3Storage) @@ -3682,7 +3682,7 @@ class S3Scrubber: args = base_args + args (output_path, stdout, status_code) = subprocess_capture( - self.log_dir, + self.env.test_output_dir, args, echo_stderr=True, echo_stdout=True, diff --git a/test_runner/regress/test_pageserver_generations.py b/test_runner/regress/test_pageserver_generations.py index 725ed63d1c..de9f3b6945 100644 --- a/test_runner/regress/test_pageserver_generations.py +++ b/test_runner/regress/test_pageserver_generations.py @@ -265,9 +265,7 @@ def test_generations_upgrade(neon_env_builder: NeonEnvBuilder): # Having written a mixture of generation-aware and legacy index_part.json, # ensure the scrubber handles the situation as expected. - metadata_summary = S3Scrubber( - neon_env_builder.test_output_dir, neon_env_builder - ).scan_metadata() + metadata_summary = S3Scrubber(neon_env_builder).scan_metadata() assert metadata_summary["tenant_count"] == 1 # Scrubber should have seen our timeline assert metadata_summary["timeline_count"] == 1 assert metadata_summary["timeline_shard_count"] == 1 diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index 293152dd62..aec989252c 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -498,7 +498,7 @@ def test_secondary_downloads(neon_env_builder: NeonEnvBuilder): # Scrub the remote storage # ======================== # This confirms that the scrubber isn't upset by the presence of the heatmap - S3Scrubber(neon_env_builder.test_output_dir, neon_env_builder).scan_metadata() + S3Scrubber(neon_env_builder).scan_metadata() # Detach secondary and delete tenant # =================================== diff --git a/test_runner/regress/test_tenant_delete.py b/test_runner/regress/test_tenant_delete.py index b4e5a550f3..e928ea8bb1 100644 --- a/test_runner/regress/test_tenant_delete.py +++ b/test_runner/regress/test_tenant_delete.py @@ -9,6 +9,7 @@ from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, PgBin, + S3Scrubber, last_flush_lsn_upload, wait_for_last_flush_lsn, ) @@ -19,12 +20,13 @@ from fixtures.pageserver.utils import ( assert_prefix_not_empty, poll_for_remote_storage_iterations, tenant_delete_wait_completed, + wait_for_upload, wait_tenant_status_404, wait_until_tenant_active, wait_until_tenant_state, ) from fixtures.remote_storage import RemoteStorageKind, available_s3_storages, s3_storage -from fixtures.types import TenantId, TimelineId +from fixtures.types import Lsn, TenantId, TimelineId from fixtures.utils import run_pg_bench_small, wait_until from requests.exceptions import ReadTimeout @@ -669,3 +671,39 @@ def test_tenant_delete_races_timeline_creation( # Zero tenants remain (we deleted the default tenant) assert ps_http.get_metric_value("pageserver_tenant_manager_slots") == 0 + + +def test_tenant_delete_scrubber(pg_bin: PgBin, neon_env_builder: NeonEnvBuilder): + """ + Validate that creating and then deleting the tenant both survives the scrubber, + and that one can run the scrubber without problems. + """ + + remote_storage_kind = RemoteStorageKind.MOCK_S3 + neon_env_builder.enable_pageserver_remote_storage(remote_storage_kind) + scrubber = S3Scrubber(neon_env_builder) + env = neon_env_builder.init_start(initial_tenant_conf=MANY_SMALL_LAYERS_TENANT_CONFIG) + + ps_http = env.pageserver.http_client() + # create a tenant separate from the main tenant so that we have one remaining + # after we deleted it, as the scrubber treats empty buckets as an error. + (tenant_id, timeline_id) = env.neon_cli.create_tenant() + + with env.endpoints.create_start("main", tenant_id=tenant_id) as endpoint: + run_pg_bench_small(pg_bin, endpoint.connstr()) + last_flush_lsn = Lsn(endpoint.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0]) + ps_http.timeline_checkpoint(tenant_id, timeline_id) + wait_for_upload(ps_http, tenant_id, timeline_id, last_flush_lsn) + env.stop() + + result = scrubber.scan_metadata() + assert result["with_warnings"] == [] + + env.start() + ps_http = env.pageserver.http_client() + iterations = poll_for_remote_storage_iterations(remote_storage_kind) + tenant_delete_wait_completed(ps_http, tenant_id, iterations) + env.stop() + + scrubber.scan_metadata() + assert result["with_warnings"] == [] From fac50a6264fb8ee59778d0720ba799a24c46695a Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Mon, 12 Feb 2024 19:41:02 +0100 Subject: [PATCH 55/81] Proxy refactor auth+connect (#6708) ## Problem Not really a problem, just refactoring. ## Summary of changes Separate authenticate from wake compute. Do not call wake compute second time if we managed to connect to postgres or if we got it not from cache. --- proxy/src/auth.rs | 5 - proxy/src/auth/backend.rs | 146 ++++++++++++++++------------- proxy/src/auth/backend/classic.rs | 2 +- proxy/src/auth/backend/hacks.rs | 6 +- proxy/src/bin/proxy.rs | 2 +- proxy/src/compute.rs | 8 +- proxy/src/config.rs | 2 +- proxy/src/console/provider.rs | 33 ++++++- proxy/src/console/provider/mock.rs | 4 +- proxy/src/error.rs | 12 ++- proxy/src/proxy.rs | 13 +-- proxy/src/proxy/connect_compute.rs | 67 ++++++++----- proxy/src/proxy/tests.rs | 142 +++++++++++++++++++++------- proxy/src/proxy/wake_compute.rs | 16 +--- proxy/src/serverless/backend.rs | 40 +++----- 15 files changed, 307 insertions(+), 191 deletions(-) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 48de4e2353..c8028d1bf0 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -36,9 +36,6 @@ pub enum AuthErrorImpl { #[error(transparent)] GetAuthInfo(#[from] console::errors::GetAuthInfoError), - #[error(transparent)] - WakeCompute(#[from] console::errors::WakeComputeError), - /// SASL protocol errors (includes [SCRAM](crate::scram)). #[error(transparent)] Sasl(#[from] crate::sasl::Error), @@ -119,7 +116,6 @@ impl UserFacingError for AuthError { match self.0.as_ref() { Link(e) => e.to_string_client(), GetAuthInfo(e) => e.to_string_client(), - WakeCompute(e) => e.to_string_client(), Sasl(e) => e.to_string_client(), AuthFailed(_) => self.to_string(), BadAuthMethod(_) => self.to_string(), @@ -139,7 +135,6 @@ impl ReportableError for AuthError { match self.0.as_ref() { Link(e) => e.get_error_kind(), GetAuthInfo(e) => e.get_error_kind(), - WakeCompute(e) => e.get_error_kind(), Sasl(e) => e.get_error_kind(), AuthFailed(_) => crate::error::ErrorKind::User, BadAuthMethod(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index c9f21f1cf5..47c1dc4e92 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange; use crate::cache::Cached; use crate::console::errors::GetAuthInfoError; use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; -use crate::console::AuthSecret; +use crate::console::{AuthSecret, NodeInfo}; use crate::context::RequestMonitoring; -use crate::proxy::wake_compute::wake_compute; +use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; use crate::stream::Stream; use crate::{ @@ -26,7 +26,6 @@ use crate::{ stream, url, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; -use futures::TryFutureExt; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -56,11 +55,11 @@ impl std::ops::Deref for MaybeOwned<'_, T> { /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. -pub enum BackendType<'a, T> { +pub enum BackendType<'a, T, D> { /// Cloud API (V2). Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. - Link(MaybeOwned<'a, url::ApiUrl>), + Link(MaybeOwned<'a, url::ApiUrl>, D), } pub trait TestBackend: Send + Sync + 'static { @@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static { fn get_role_secret(&self) -> Result; } -impl std::fmt::Display for BackendType<'_, ()> { +impl std::fmt::Display for BackendType<'_, (), ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use BackendType::*; match self { @@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> { #[cfg(test)] ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, - Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), + Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), } } } -impl BackendType<'_, T> { +impl BackendType<'_, T, D> { /// Very similar to [`std::option::Option::as_ref`]. /// This helps us pass structured config to async tasks. - pub fn as_ref(&self) -> BackendType<'_, &T> { + pub fn as_ref(&self) -> BackendType<'_, &T, &D> { use BackendType::*; match self { Console(c, x) => Console(MaybeOwned::Borrowed(c), x), - Link(c) => Link(MaybeOwned::Borrowed(c)), + Link(c, x) => Link(MaybeOwned::Borrowed(c), x), } } } -impl<'a, T> BackendType<'a, T> { +impl<'a, T, D> BackendType<'a, T, D> { /// Very similar to [`std::option::Option::map`]. /// Maps [`BackendType`] to [`BackendType`] by applying /// a function to a contained value. - pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> { + pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> { use BackendType::*; match self { Console(c, x) => Console(c, f(x)), - Link(c) => Link(c), + Link(c, x) => Link(c, x), } } } - -impl<'a, T, E> BackendType<'a, Result> { +impl<'a, T, D, E> BackendType<'a, Result, D> { /// Very similar to [`std::option::Option::transpose`]. /// This is most useful for error handling. - pub fn transpose(self) -> Result, E> { + pub fn transpose(self) -> Result, E> { use BackendType::*; match self { Console(c, x) => x.map(|x| Console(c, x)), - Link(c) => Ok(Link(c)), + Link(c, x) => Ok(Link(c, x)), } } } -pub struct ComputeCredentials { +pub struct ComputeCredentials { pub info: ComputeUserInfo, - pub keys: T, + pub keys: ComputeCredentialKeys, } #[derive(Debug, Clone)] @@ -153,7 +151,6 @@ impl ComputeUserInfo { } pub enum ComputeCredentialKeys { - #[cfg(any(test, feature = "testing"))] Password(Vec), AuthKeys(AuthKeys), } @@ -188,7 +185,7 @@ async fn auth_quirks( client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, -) -> auth::Result> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. @@ -198,8 +195,11 @@ async fn auth_quirks( ctx.set_endpoint_id(res.info.endpoint.clone()); tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint)); - - (res.info, Some(res.keys)) + let password = match res.keys { + ComputeCredentialKeys::Password(p) => p, + _ => unreachable!("password hack should return a password"), + }; + (res.info, Some(password)) } Ok(info) => (info, None), }; @@ -253,7 +253,7 @@ async fn authenticate_with_secret( unauthenticated_password: Option>, allow_cleartext: bool, config: &'static AuthenticationConfig, -) -> auth::Result> { +) -> auth::Result { if let Some(password) = unauthenticated_password { let auth_outcome = validate_password_and_exchange(&password, secret)?; let keys = match auth_outcome { @@ -283,14 +283,14 @@ async fn authenticate_with_secret( classic::authenticate(ctx, info, client, config, secret).await } -impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { +impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { /// Get compute endpoint name from the credentials. pub fn get_endpoint(&self) -> Option { use BackendType::*; match self { Console(_, user_info) => user_info.endpoint_id.clone(), - Link(_) => Some("link".into()), + Link(_, _) => Some("link".into()), } } @@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => &user_info.user, - Link(_) => "link", + Link(_, _) => "link", } } @@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, - ) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> { + ) -> auth::Result> { use BackendType::*; let res = match self { @@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let compute_credentials = + let credentials = auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?; - - let mut num_retries = 0; - let mut node = - wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?; - - ctx.set_project(node.aux.clone()); - - match compute_credentials.keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => node.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), - }; - - (node, BackendType::Console(api, compute_credentials.info)) + BackendType::Console(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Link(url) => { + Link(url, _) => { info!("performing link authentication"); - let node_info = link::authenticate(ctx, &url, client).await?; + let info = link::authenticate(ctx, &url, client).await?; - ( - CachedNodeInfo::new_uncached(node_info), - BackendType::Link(url), - ) + BackendType::Link(url, info) } }; @@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { } } -impl BackendType<'_, ComputeUserInfo> { +impl BackendType<'_, ComputeUserInfo, &()> { pub async fn get_role_secret( &self, ctx: &mut RequestMonitoring, @@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_role_secret(ctx, user_info).await, - Link(_) => Ok(Cached::new_uncached(None)), + Link(_, _) => Ok(Cached::new_uncached(None)), } } @@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, - Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), - } - } - - /// When applicable, wake the compute node, gaining its connection info in the process. - /// The link auth flow doesn't support this, so we return [`None`] in that case. - pub async fn wake_compute( - &self, - ctx: &mut RequestMonitoring, - ) -> Result, console::errors::WakeComputeError> { - use BackendType::*; - - match self { - Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, - Link(_) => Ok(None), + Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Link(_, info) => Ok(Cached::new_uncached(info.clone())), + } + } + + fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + match self { + BackendType::Console(_, creds) => Some(&creds.keys), + BackendType::Link(_, _) => None, + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"), + } + } + + fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + match self { + BackendType::Console(_, creds) => Some(&creds.keys), + BackendType::Link(_, _) => None, } } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index e855843bc3..d075331846 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,7 +17,7 @@ pub(super) async fn authenticate( client: &mut PqStream>, config: &'static AuthenticationConfig, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result { let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 9f60b709d4..26cf7a01f2 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -20,7 +20,7 @@ pub async fn authenticate_cleartext( info: ComputeUserInfo, client: &mut stream::PqStream>, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result { warn!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); @@ -51,7 +51,7 @@ pub async fn password_hack_no_authentication( ctx: &mut RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, -) -> auth::Result>> { +) -> auth::Result { warn!("project not specified, resorting to the password hack auth flow"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); @@ -73,6 +73,6 @@ pub async fn password_hack_no_authentication( options: info.options, endpoint: payload.endpoint, }, - keys: payload.password, + keys: ComputeCredentialKeys::Password(payload.password), }) } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 8fbcb56758..00a229c135 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -383,7 +383,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } AuthBackend::Link => { let url = args.uri.parse()?; - auth::BackendType::Link(MaybeOwned::Owned(url)) + auth::BackendType::Link(MaybeOwned::Owned(url), ()) } }; let http_config = HttpConfig { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 83940d80ec..b61c1fb9ef 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,7 +1,7 @@ use crate::{ auth::parse_endpoint_param, cancellation::CancelClosure, - console::errors::WakeComputeError, + console::{errors::WakeComputeError, messages::MetricsAuxInfo}, context::RequestMonitoring, error::{ReportableError, UserFacingError}, metrics::NUM_DB_CONNECTIONS_GAUGE, @@ -93,7 +93,7 @@ impl ConnCfg { } /// Reuse password or auth keys from the other config. - pub fn reuse_password(&mut self, other: &Self) { + pub fn reuse_password(&mut self, other: Self) { if let Some(password) = other.get_password() { self.password(password); } @@ -253,6 +253,8 @@ pub struct PostgresConnection { pub params: std::collections::HashMap, /// Query cancellation token. pub cancel_closure: CancelClosure, + /// Labels for proxy's metrics. + pub aux: MetricsAuxInfo, _guage: IntCounterPairGuard, } @@ -263,6 +265,7 @@ impl ConnCfg { &self, ctx: &mut RequestMonitoring, allow_self_signed_compute: bool, + aux: MetricsAuxInfo, timeout: Duration, ) -> Result { let (socket_addr, stream, host) = self.connect_raw(timeout).await?; @@ -297,6 +300,7 @@ impl ConnCfg { stream, params, cancel_closure, + aux, _guage: NUM_DB_CONNECTIONS_GAUGE .with_label_values(&[ctx.protocol]) .guard(), diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 31c9228b35..5fcb537834 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -13,7 +13,7 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::BackendType<'static, ()>, + pub auth_backend: auth::BackendType<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index e5cad42753..640444d14e 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -4,7 +4,10 @@ pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ - auth::{backend::ComputeUserInfo, IpPattern}, + auth::{ + backend::{ComputeCredentialKeys, ComputeUserInfo}, + IpPattern, + }, cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru}, compute, config::{CacheOptions, ProjectInfoCacheOptions}, @@ -261,6 +264,34 @@ pub struct NodeInfo { pub allow_self_signed_compute: bool, } +impl NodeInfo { + pub async fn connect( + &self, + ctx: &mut RequestMonitoring, + timeout: Duration, + ) -> Result { + self.config + .connect( + ctx, + self.allow_self_signed_compute, + self.aux.clone(), + timeout, + ) + .await + } + pub fn reuse_settings(&mut self, other: Self) { + self.allow_self_signed_compute = other.allow_self_signed_compute; + self.config.reuse_password(other.config); + } + + pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) { + match keys { + ComputeCredentialKeys::Password(password) => self.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), + }; + } +} + pub type NodeInfoCache = TimedLru; pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 79a04f255d..0579ef6fc4 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -176,9 +176,7 @@ impl super::Api for Api { _ctx: &mut RequestMonitoring, _user_info: &ComputeUserInfo, ) -> Result { - self.do_wake_compute() - .map_ok(CachedNodeInfo::new_uncached) - .await + self.do_wake_compute().map_ok(Cached::new_uncached).await } } diff --git a/proxy/src/error.rs b/proxy/src/error.rs index eafe92bf48..69fe1ebc12 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError { } } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ErrorKind { /// Wrong password, unknown endpoint, protocol violation, etc... User, @@ -90,3 +90,13 @@ impl ReportableError for tokio::time::error::Elapsed { ErrorKind::RateLimit } } + +impl ReportableError for tokio_postgres::error::Error { + fn get_error_kind(&self) -> ErrorKind { + if self.as_db_error().is_some() { + ErrorKind::Postgres + } else { + ErrorKind::Compute + } + } +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 77aadb6f28..5f65de4c98 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -163,14 +163,14 @@ pub enum ClientMode { /// Abstracts the logic of handling TCP vs WS clients impl ClientMode { - fn allow_cleartext(&self) -> bool { + pub fn allow_cleartext(&self) -> bool { match self { ClientMode::Tcp => false, ClientMode::Websockets { .. } => true, } } - fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { + pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { match self { ClientMode::Tcp => config.allow_self_signed_compute, ClientMode::Websockets { .. } => false, @@ -287,7 +287,7 @@ pub async fn handle_client( } let user = user_info.get_user().to_owned(); - let (mut node_info, user_info) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, @@ -306,14 +306,11 @@ pub async fn handle_client( } }; - node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config); - - let aux = node_info.aux.clone(); let mut node = connect_to_compute( ctx, &TcpMechanism { params: ¶ms }, - node_info, &user_info, + mode.allow_self_signed_compute(config), ) .or_else(|e| stream.throw_error(e)) .await?; @@ -330,8 +327,8 @@ pub async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, + aux: node.aux.clone(), compute: node, - aux, req: _request_gauge, conn: _client_gauge, })) diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index b9346aa743..6e57caf998 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,8 +1,9 @@ use crate::{ - auth, + auth::backend::ComputeCredentialKeys, compute::{self, PostgresConnection}, - console::{self, errors::WakeComputeError}, + console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo}, context::RequestMonitoring, + error::ReportableError, metrics::NUM_CONNECTION_FAILURES, proxy::{ retry::{retry_after, ShouldRetry}, @@ -20,7 +21,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. #[tracing::instrument(name = "invalidate_cache", skip_all)] -pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg { +pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -31,13 +32,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg }; NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc(); - node_info.invalidate().config + node_info.invalidate() } #[async_trait] pub trait ConnectMechanism { type Connection; - type ConnectError; + type ConnectError: ReportableError; type Error: From; async fn connect_once( &self, @@ -49,6 +50,16 @@ pub trait ConnectMechanism { fn update_connect_config(&self, conf: &mut compute::ConnCfg); } +#[async_trait] +pub trait ComputeConnectBackend { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result; + + fn get_keys(&self) -> Option<&ComputeCredentialKeys>; +} + pub struct TcpMechanism<'a> { /// KV-dictionary with PostgreSQL connection params. pub params: &'a StartupMessageParams, @@ -67,11 +78,7 @@ impl ConnectMechanism for TcpMechanism<'_> { node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result { - let allow_self_signed_compute = node_info.allow_self_signed_compute; - node_info - .config - .connect(ctx, allow_self_signed_compute, timeout) - .await + node_info.connect(ctx, timeout).await } fn update_connect_config(&self, config: &mut compute::ConnCfg) { @@ -82,16 +89,23 @@ impl ConnectMechanism for TcpMechanism<'_> { /// Try to connect to the compute node, retrying if necessary. /// This function might update `node_info`, so we take it by `&mut`. #[tracing::instrument(skip_all)] -pub async fn connect_to_compute( +pub async fn connect_to_compute( ctx: &mut RequestMonitoring, mechanism: &M, - mut node_info: console::CachedNodeInfo, - user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>, + user_info: &B, + allow_self_signed_compute: bool, ) -> Result where M::ConnectError: ShouldRetry + std::fmt::Debug, M::Error: From, { + let mut num_retries = 0; + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + if let Some(keys) = user_info.get_keys() { + node_info.set_keys(keys); + } + node_info.allow_self_signed_compute = allow_self_signed_compute; + // let mut node_info = credentials.get_node_info(ctx, user_info).await?; mechanism.update_connect_config(&mut node_info.config); // try once @@ -108,28 +122,31 @@ where error!(error = ?err, "could not connect to compute node"); - let mut num_retries = 1; - - match user_info { - auth::BackendType::Console(api, info) => { + let node_info = + if err.get_error_kind() == crate::error::ErrorKind::Postgres || !node_info.cached() { + // If the error is Postgres, that means that we managed to connect to the compute node, but there was an error. + // Do not need to retrieve a new node_info, just return the old one. + if !err.should_retry(num_retries) { + return Err(err.into()); + } + node_info + } else { // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node info!("compute node's state has likely changed; requesting a wake-up"); - ctx.latency_timer.cache_miss(); - let config = invalidate_cache(node_info); - node_info = wake_compute(&mut num_retries, ctx, api, info).await?; + let old_node_info = invalidate_cache(node_info); + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + node_info.reuse_settings(old_node_info); - node_info.config.reuse_password(&config); mechanism.update_connect_config(&mut node_info.config); - } - // nothing to do? - auth::BackendType::Link(_) => {} - }; + node_info + }; // now that we have a new node, try connect to it repeatedly. // this can error for a few reasons, for instance: // * DNS connection settings haven't quite propagated yet info!("wake_compute success. attempting to connect"); + num_retries = 1; loop { match mechanism .connect_once(ctx, &node_info, CONNECT_TIMEOUT) diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 5bb43c0375..efbd661bbf 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -2,13 +2,19 @@ mod mitm; +use std::time::Duration; + use super::connect_compute::ConnectMechanism; use super::retry::ShouldRetry; use super::*; -use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend}; +use crate::auth::backend::{ + ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend, +}; use crate::config::CertResolver; +use crate::console::caches::NodeInfoCache; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; +use crate::error::ErrorKind; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; @@ -369,12 +375,15 @@ enum ConnectAction { Connect, Retry, Fail, + RetryPg, + FailPg, } #[derive(Clone)] struct TestConnectMechanism { counter: Arc>, sequence: Vec, + cache: &'static NodeInfoCache, } impl TestConnectMechanism { @@ -393,6 +402,12 @@ impl TestConnectMechanism { Self { counter: Arc::new(std::sync::Mutex::new(0)), sequence, + cache: Box::leak(Box::new(NodeInfoCache::new( + "test", + 1, + Duration::from_secs(100), + false, + ))), } } } @@ -403,6 +418,13 @@ struct TestConnection; #[derive(Debug)] struct TestConnectError { retryable: bool, + kind: crate::error::ErrorKind, +} + +impl ReportableError for TestConnectError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + self.kind + } } impl std::fmt::Display for TestConnectError { @@ -436,8 +458,22 @@ impl ConnectMechanism for TestConnectMechanism { *counter += 1; match action { ConnectAction::Connect => Ok(TestConnection), - ConnectAction::Retry => Err(TestConnectError { retryable: true }), - ConnectAction::Fail => Err(TestConnectError { retryable: false }), + ConnectAction::Retry => Err(TestConnectError { + retryable: true, + kind: ErrorKind::Compute, + }), + ConnectAction::Fail => Err(TestConnectError { + retryable: false, + kind: ErrorKind::Compute, + }), + ConnectAction::FailPg => Err(TestConnectError { + retryable: false, + kind: ErrorKind::Postgres, + }), + ConnectAction::RetryPg => Err(TestConnectError { + retryable: true, + kind: ErrorKind::Postgres, + }), x => panic!("expecting action {:?}, connect is called instead", x), } } @@ -451,7 +487,7 @@ impl TestBackend for TestConnectMechanism { let action = self.sequence[*counter]; *counter += 1; match action { - ConnectAction::Wake => Ok(helper_create_cached_node_info()), + ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)), ConnectAction::WakeFail => { let err = console::errors::ApiError::Console { status: http::StatusCode::FORBIDDEN, @@ -483,37 +519,41 @@ impl TestBackend for TestConnectMechanism { } } -fn helper_create_cached_node_info() -> CachedNodeInfo { +fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { config: compute::ConnCfg::new(), aux: Default::default(), allow_self_signed_compute: false, }; - CachedNodeInfo::new_uncached(node) + let (_, node) = cache.insert("key".into(), node); + node } fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) { - let cache = helper_create_cached_node_info(); +) -> auth::BackendType<'static, ComputeCredentials, &()> { let user_info = auth::BackendType::Console( MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), - ComputeUserInfo { - endpoint: "endpoint".into(), - user: "user".into(), - options: NeonOptions::parse_options_raw(""), + ComputeCredentials { + info: ComputeUserInfo { + endpoint: "endpoint".into(), + user: "user".into(), + options: NeonOptions::parse_options_raw(""), + }, + keys: ComputeCredentialKeys::Password("password".into()), }, ); - (cache, user_info) + user_info } #[tokio::test] async fn connect_to_compute_success() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -521,24 +561,52 @@ async fn connect_to_compute_success() { #[tokio::test] async fn connect_to_compute_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); } +#[tokio::test] +async fn connect_to_compute_retry_pg() { + let _ = env_logger::try_init(); + use ConnectAction::*; + let mut ctx = RequestMonitoring::test(); + let mechanism = TestConnectMechanism::new(vec![Wake, RetryPg, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) + .await + .unwrap(); + mechanism.verify(); +} + +#[tokio::test] +async fn connect_to_compute_fail_pg() { + let _ = env_logger::try_init(); + use ConnectAction::*; + let mut ctx = RequestMonitoring::test(); + let mechanism = TestConnectMechanism::new(vec![Wake, FailPg]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) + .await + .unwrap_err(); + mechanism.verify(); +} + /// Test that we don't retry if the error is not retryable. #[tokio::test] async fn connect_to_compute_non_retry_1() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); @@ -547,11 +615,12 @@ async fn connect_to_compute_non_retry_1() { /// Even for non-retryable errors, we should retry at least once. #[tokio::test] async fn connect_to_compute_non_retry_2() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -560,15 +629,16 @@ async fn connect_to_compute_non_retry_2() { /// Retry for at most `NUM_RETRIES_CONNECT` times. #[tokio::test] async fn connect_to_compute_non_retry_3() { + let _ = env_logger::try_init(); assert_eq!(NUM_RETRIES_CONNECT, 16); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![ - Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, - Retry, Retry, Retry, Retry, /* the 17th time */ Retry, + Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, + Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry, ]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); @@ -577,11 +647,12 @@ async fn connect_to_compute_non_retry_3() { /// Should retry wake compute. #[tokio::test] async fn wake_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -590,11 +661,12 @@ async fn wake_retry() { /// Wake failed with a non-retryable error. #[tokio::test] async fn wake_non_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 925727bdab..2c593451b4 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,4 @@ -use crate::auth::backend::ComputeUserInfo; -use crate::console::{ - errors::WakeComputeError, - provider::{CachedNodeInfo, ConsoleBackend}, - Api, -}; +use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo}; use crate::context::RequestMonitoring; use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES}; use crate::proxy::retry::retry_after; @@ -11,17 +6,16 @@ use hyper::StatusCode; use std::ops::ControlFlow; use tracing::{error, warn}; +use super::connect_compute::ComputeConnectBackend; use super::retry::ShouldRetry; -/// wake a compute (or retrieve an existing compute session from cache) -pub async fn wake_compute( +pub async fn wake_compute( num_retries: &mut u32, ctx: &mut RequestMonitoring, - api: &ConsoleBackend, - info: &ComputeUserInfo, + api: &B, ) -> Result { loop { - let wake_res = api.wake_compute(ctx, info).await; + let wake_res = api.wake_compute(ctx).await; match handle_try_wake(wake_res, *num_retries) { Err(e) => { error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 156002006d..6f93f86d5f 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use tracing::{field::display, info}; use crate::{ - auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, + auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, compute, config::ProxyConfig, console::{ @@ -27,7 +27,7 @@ impl PoolingBackend { &self, ctx: &mut RequestMonitoring, conn_info: &ConnInfo, - ) -> Result { + ) -> Result { let user_info = conn_info.user_info.clone(); let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; @@ -49,13 +49,17 @@ impl PoolingBackend { }; let auth_outcome = crate::auth::validate_password_and_exchange(&conn_info.password, secret)?; - match auth_outcome { + let res = match auth_outcome { crate::sasl::Outcome::Success(key) => Ok(key), crate::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); Err(AuthError::auth_failed(&*conn_info.user_info.user)) } - } + }; + res.map(|key| ComputeCredentials { + info: user_info, + keys: key, + }) } // Wake up the destination if needed. Code here is a bit involved because @@ -66,7 +70,7 @@ impl PoolingBackend { &self, ctx: &mut RequestMonitoring, conn_info: ConnInfo, - keys: ComputeCredentialKeys, + keys: ComputeCredentials, force_new: bool, ) -> Result, HttpConnError> { let maybe_client = if !force_new { @@ -82,26 +86,8 @@ impl PoolingBackend { } let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); - info!("pool: opening a new connection '{conn_info}'"); - let backend = self - .config - .auth_backend - .as_ref() - .map(|_| conn_info.user_info.clone()); - - let mut node_info = backend - .wake_compute(ctx) - .await? - .ok_or(HttpConnError::NoComputeInfo)?; - - match keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => node_info.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys), - }; - - ctx.set_project(node_info.aux.clone()); - + info!(%conn_id, "pool: opening a new connection '{conn_info}'"); + let backend = self.config.auth_backend.as_ref().map(|_| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -109,8 +95,8 @@ impl PoolingBackend { conn_info, pool: self.pool.clone(), }, - node_info, &backend, + false, // do not allow self signed compute for http flow ) .await } @@ -129,8 +115,6 @@ pub enum HttpConnError { AuthError(#[from] AuthError), #[error("wake_compute returned error")] WakeCompute(#[from] WakeComputeError), - #[error("wake_compute returned nothing")] - NoComputeInfo, } struct TokioMechanism { From 4be2223a4cd80fdc40c37aab2206bb6f505dc008 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Mon, 12 Feb 2024 20:29:57 +0000 Subject: [PATCH 56/81] Discrete event simulation for safekeepers (#5804) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR contains the first version of a [FoundationDB-like](https://www.youtube.com/watch?v=4fFDFbi3toc) simulation testing for safekeeper and walproposer. ### desim This is a core "framework" for running determenistic simulation. It operates on threads, allowing to test syncronous code (like walproposer). `libs/desim/src/executor.rs` contains implementation of a determenistic thread execution. This is achieved by blocking all threads, and each time allowing only a single thread to make an execution step. All executor's threads are blocked using `yield_me(after_ms)` function. This function is called when a thread wants to sleep or wait for an external notification (like blocking on a channel until it has a ready message). `libs/desim/src/chan.rs` contains implementation of a channel (basic sync primitive). It has unlimited capacity and any thread can push or read messages to/from it. `libs/desim/src/network.rs` has a very naive implementation of a network (only reliable TCP-like connections are supported for now), that can have arbitrary delays for each package and failure injections for breaking connections with some probability. `libs/desim/src/world.rs` ties everything together, to have a concept of virtual nodes that can have network connections between them. ### walproposer_sim Has everything to run walproposer and safekeepers in a simulation. `safekeeper.rs` reimplements all necesary stuff from `receive_wal.rs`, `send_wal.rs` and `timelines_global_map.rs`. `walproposer_api.rs` implements all walproposer callback to use simulation library. `simulation.rs` defines a schedule – a set of events like `restart ` or `write_wal` that should happen at time ``. It also has code to spawn walproposer/safekeeper threads and provide config to them. ### tests `simple_test.rs` has tests that just start walproposer and 3 safekeepers together in a simulation, and tests that they are not crashing right away. `misc_test.rs` has tests checking more advanced simulation cases, like crashing or restarting threads, testing memory deallocation, etc. `random_test.rs` is the main test, it checks thousands of random seeds (schedules) for correctness. It roughly corresponds to running a real python integration test in an environment with very unstable network and cpu, but in a determenistic way (each seed results in the same execution log) and much much faster. Closes #547 --------- Co-authored-by: Arseny Sher --- Cargo.lock | 20 + Cargo.toml | 2 + libs/desim/Cargo.toml | 18 + libs/desim/README.md | 7 + libs/desim/src/chan.rs | 108 +++ libs/desim/src/executor.rs | 483 +++++++++++++ libs/desim/src/lib.rs | 8 + libs/desim/src/network.rs | 451 ++++++++++++ libs/desim/src/node_os.rs | 54 ++ libs/desim/src/options.rs | 50 ++ libs/desim/src/proto.rs | 63 ++ libs/desim/src/time.rs | 129 ++++ libs/desim/src/world.rs | 180 +++++ libs/desim/tests/reliable_copy_test.rs | 244 +++++++ libs/postgres_ffi/src/xlog_utils.rs | 10 +- libs/walproposer/build.rs | 4 + libs/walproposer/src/api_bindings.rs | 20 +- libs/walproposer/src/walproposer.rs | 45 +- pageserver/src/walingest.rs | 2 +- pgxn/neon/walproposer.c | 15 +- pgxn/neon/walproposer.h | 9 + safekeeper/Cargo.toml | 7 + safekeeper/tests/misc_test.rs | 155 ++++ safekeeper/tests/random_test.rs | 56 ++ safekeeper/tests/simple_test.rs | 45 ++ .../tests/walproposer_sim/block_storage.rs | 57 ++ safekeeper/tests/walproposer_sim/log.rs | 77 ++ safekeeper/tests/walproposer_sim/mod.rs | 8 + .../tests/walproposer_sim/safekeeper.rs | 410 +++++++++++ .../tests/walproposer_sim/safekeeper_disk.rs | 278 +++++++ .../tests/walproposer_sim/simulation.rs | 436 +++++++++++ .../tests/walproposer_sim/simulation_logs.rs | 187 +++++ .../tests/walproposer_sim/walproposer_api.rs | 676 ++++++++++++++++++ .../tests/walproposer_sim/walproposer_disk.rs | 314 ++++++++ 34 files changed, 4603 insertions(+), 25 deletions(-) create mode 100644 libs/desim/Cargo.toml create mode 100644 libs/desim/README.md create mode 100644 libs/desim/src/chan.rs create mode 100644 libs/desim/src/executor.rs create mode 100644 libs/desim/src/lib.rs create mode 100644 libs/desim/src/network.rs create mode 100644 libs/desim/src/node_os.rs create mode 100644 libs/desim/src/options.rs create mode 100644 libs/desim/src/proto.rs create mode 100644 libs/desim/src/time.rs create mode 100644 libs/desim/src/world.rs create mode 100644 libs/desim/tests/reliable_copy_test.rs create mode 100644 safekeeper/tests/misc_test.rs create mode 100644 safekeeper/tests/random_test.rs create mode 100644 safekeeper/tests/simple_test.rs create mode 100644 safekeeper/tests/walproposer_sim/block_storage.rs create mode 100644 safekeeper/tests/walproposer_sim/log.rs create mode 100644 safekeeper/tests/walproposer_sim/mod.rs create mode 100644 safekeeper/tests/walproposer_sim/safekeeper.rs create mode 100644 safekeeper/tests/walproposer_sim/safekeeper_disk.rs create mode 100644 safekeeper/tests/walproposer_sim/simulation.rs create mode 100644 safekeeper/tests/walproposer_sim/simulation_logs.rs create mode 100644 safekeeper/tests/walproposer_sim/walproposer_api.rs create mode 100644 safekeeper/tests/walproposer_sim/walproposer_disk.rs diff --git a/Cargo.lock b/Cargo.lock index 520163e41b..f11c774016 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1639,6 +1639,22 @@ dependencies = [ "rusticata-macros", ] +[[package]] +name = "desim" +version = "0.1.0" +dependencies = [ + "anyhow", + "bytes", + "hex", + "parking_lot 0.12.1", + "rand 0.8.5", + "scopeguard", + "smallvec", + "tracing", + "utils", + "workspace_hack", +] + [[package]] name = "diesel" version = "2.1.4" @@ -4827,6 +4843,7 @@ dependencies = [ "clap", "const_format", "crc32c", + "desim", "fail", "fs2", "futures", @@ -4842,6 +4859,7 @@ dependencies = [ "postgres_backend", "postgres_ffi", "pq_proto", + "rand 0.8.5", "regex", "remote_storage", "reqwest", @@ -4862,8 +4880,10 @@ dependencies = [ "tokio-util", "toml_edit", "tracing", + "tracing-subscriber", "url", "utils", + "walproposer", "workspace_hack", ] diff --git a/Cargo.toml b/Cargo.toml index ebc3dfa7b1..8df9ca9988 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ members = [ "libs/pageserver_api", "libs/postgres_ffi", "libs/safekeeper_api", + "libs/desim", "libs/utils", "libs/consumption_metrics", "libs/postgres_backend", @@ -203,6 +204,7 @@ postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } remote_storage = { version = "0.1", path = "./libs/remote_storage/" } safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" } +desim = { version = "0.1", path = "./libs/desim" } storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy. tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" } tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" } diff --git a/libs/desim/Cargo.toml b/libs/desim/Cargo.toml new file mode 100644 index 0000000000..6f442d8243 --- /dev/null +++ b/libs/desim/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "desim" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow.workspace = true +rand.workspace = true +tracing.workspace = true +bytes.workspace = true +utils.workspace = true +parking_lot.workspace = true +hex.workspace = true +scopeguard.workspace = true +smallvec = { workspace = true, features = ["write"] } + +workspace_hack.workspace = true diff --git a/libs/desim/README.md b/libs/desim/README.md new file mode 100644 index 0000000000..80568ebb1b --- /dev/null +++ b/libs/desim/README.md @@ -0,0 +1,7 @@ +# Discrete Event SIMulator + +This is a library for running simulations of distributed systems. The main idea is borrowed from [FoundationDB](https://www.youtube.com/watch?v=4fFDFbi3toc). + +Each node runs as a separate thread. This library was not optimized for speed yet, but it's already much faster than running usual intergration tests in real time, because it uses virtual simulation time and can fast-forward time to skip intervals where all nodes are doing nothing but sleeping or waiting for something. + +The original purpose for this library is to test walproposer and safekeeper implementation working together, in a scenarios close to the real world environment. This simulator is determenistic and can inject failures in networking without waiting minutes of wall-time to trigger timeout, which makes it easier to find bugs in our consensus implementation compared to using integration tests. diff --git a/libs/desim/src/chan.rs b/libs/desim/src/chan.rs new file mode 100644 index 0000000000..6661d59871 --- /dev/null +++ b/libs/desim/src/chan.rs @@ -0,0 +1,108 @@ +use std::{collections::VecDeque, sync::Arc}; + +use parking_lot::{Mutex, MutexGuard}; + +use crate::executor::{self, PollSome, Waker}; + +/// FIFO channel with blocking send and receive. Can be cloned and shared between threads. +/// Blocking functions should be used only from threads that are managed by the executor. +pub struct Chan { + shared: Arc>, +} + +impl Clone for Chan { + fn clone(&self) -> Self { + Chan { + shared: self.shared.clone(), + } + } +} + +impl Default for Chan { + fn default() -> Self { + Self::new() + } +} + +impl Chan { + pub fn new() -> Chan { + Chan { + shared: Arc::new(State { + queue: Mutex::new(VecDeque::new()), + waker: Waker::new(), + }), + } + } + + /// Get a message from the front of the queue, block if the queue is empty. + /// If not called from the executor thread, it can block forever. + pub fn recv(&self) -> T { + self.shared.recv() + } + + /// Panic if the queue is empty. + pub fn must_recv(&self) -> T { + self.shared + .try_recv() + .expect("message should've been ready") + } + + /// Get a message from the front of the queue, return None if the queue is empty. + /// Never blocks. + pub fn try_recv(&self) -> Option { + self.shared.try_recv() + } + + /// Send a message to the back of the queue. + pub fn send(&self, t: T) { + self.shared.send(t); + } +} + +struct State { + queue: Mutex>, + waker: Waker, +} + +impl State { + fn send(&self, t: T) { + self.queue.lock().push_back(t); + self.waker.wake_all(); + } + + fn try_recv(&self) -> Option { + let mut q = self.queue.lock(); + q.pop_front() + } + + fn recv(&self) -> T { + // interrupt the receiver to prevent consuming everything at once + executor::yield_me(0); + + let mut queue = self.queue.lock(); + if let Some(t) = queue.pop_front() { + return t; + } + loop { + self.waker.wake_me_later(); + if let Some(t) = queue.pop_front() { + return t; + } + MutexGuard::unlocked(&mut queue, || { + executor::yield_me(-1); + }); + } + } +} + +impl PollSome for Chan { + /// Schedules a wakeup for the current thread. + fn wake_me(&self) { + self.shared.waker.wake_me_later(); + } + + /// Checks if chan has any pending messages. + fn has_some(&self) -> bool { + !self.shared.queue.lock().is_empty() + } +} diff --git a/libs/desim/src/executor.rs b/libs/desim/src/executor.rs new file mode 100644 index 0000000000..9d44bd7741 --- /dev/null +++ b/libs/desim/src/executor.rs @@ -0,0 +1,483 @@ +use std::{ + panic::AssertUnwindSafe, + sync::{ + atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering}, + mpsc, Arc, OnceLock, + }, + thread::JoinHandle, +}; + +use tracing::{debug, error, trace}; + +use crate::time::Timing; + +/// Stores status of the running threads. Threads are registered in the runtime upon creation +/// and deregistered upon termination. +pub struct Runtime { + // stores handles to all threads that are currently running + threads: Vec, + // stores current time and pending wakeups + clock: Arc, + // thread counter + thread_counter: AtomicU32, + // Thread step counter -- how many times all threads has been actually + // stepped (note that all world/time/executor/thread have slightly different + // meaning of steps). For observability. + pub step_counter: u64, +} + +impl Runtime { + /// Init new runtime, no running threads. + pub fn new(clock: Arc) -> Self { + Self { + threads: Vec::new(), + clock, + thread_counter: AtomicU32::new(0), + step_counter: 0, + } + } + + /// Spawn a new thread and register it in the runtime. + pub fn spawn(&mut self, f: F) -> ExternalHandle + where + F: FnOnce() + Send + 'static, + { + let (tx, rx) = mpsc::channel(); + + let clock = self.clock.clone(); + let tid = self.thread_counter.fetch_add(1, Ordering::SeqCst); + debug!("spawning thread-{}", tid); + + let join = std::thread::spawn(move || { + let _guard = tracing::info_span!("", tid).entered(); + + let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + with_thread_context(|ctx| { + assert!(ctx.clock.set(clock).is_ok()); + ctx.id.store(tid, Ordering::SeqCst); + tx.send(ctx.clone()).expect("failed to send thread context"); + // suspend thread to put it to `threads` in sleeping state + ctx.yield_me(0); + }); + + // start user-provided function + f(); + })); + debug!("thread finished"); + + if let Err(e) = res { + with_thread_context(|ctx| { + if !ctx.allow_panic.load(std::sync::atomic::Ordering::SeqCst) { + error!("thread panicked, terminating the process: {:?}", e); + std::process::exit(1); + } + + debug!("thread panicked: {:?}", e); + let mut result = ctx.result.lock(); + if result.0 == -1 { + *result = (256, format!("thread panicked: {:?}", e)); + } + }); + } + + with_thread_context(|ctx| { + ctx.finish_me(); + }); + }); + + let ctx = rx.recv().expect("failed to receive thread context"); + let handle = ThreadHandle::new(ctx.clone(), join); + + self.threads.push(handle); + + ExternalHandle { ctx } + } + + /// Returns true if there are any unfinished activity, such as running thread or pending events. + /// Otherwise returns false, which means all threads are blocked forever. + pub fn step(&mut self) -> bool { + trace!("runtime step"); + + // have we run any thread? + let mut ran = false; + + self.threads.retain(|thread: &ThreadHandle| { + let res = thread.ctx.wakeup.compare_exchange( + PENDING_WAKEUP, + NO_WAKEUP, + Ordering::SeqCst, + Ordering::SeqCst, + ); + if res.is_err() { + // thread has no pending wakeups, leaving as is + return true; + } + ran = true; + + trace!("entering thread-{}", thread.ctx.tid()); + let status = thread.step(); + self.step_counter += 1; + trace!( + "out of thread-{} with status {:?}", + thread.ctx.tid(), + status + ); + + if status == Status::Sleep { + true + } else { + trace!("thread has finished"); + // removing the thread from the list + false + } + }); + + if !ran { + trace!("no threads were run, stepping clock"); + if let Some(ctx_to_wake) = self.clock.step() { + trace!("waking up thread-{}", ctx_to_wake.tid()); + ctx_to_wake.inc_wake(); + } else { + return false; + } + } + + true + } + + /// Kill all threads. This is done by setting a flag in each thread context and waking it up. + pub fn crash_all_threads(&mut self) { + for thread in self.threads.iter() { + thread.ctx.crash_stop(); + } + + // all threads should be finished after a few steps + while !self.threads.is_empty() { + self.step(); + } + } +} + +impl Drop for Runtime { + fn drop(&mut self) { + debug!("dropping the runtime"); + self.crash_all_threads(); + } +} + +#[derive(Clone)] +pub struct ExternalHandle { + ctx: Arc, +} + +impl ExternalHandle { + /// Returns true if thread has finished execution. + pub fn is_finished(&self) -> bool { + let status = self.ctx.mutex.lock(); + *status == Status::Finished + } + + /// Returns exitcode and message, which is available after thread has finished execution. + pub fn result(&self) -> (i32, String) { + let result = self.ctx.result.lock(); + result.clone() + } + + /// Returns thread id. + pub fn id(&self) -> u32 { + self.ctx.id.load(Ordering::SeqCst) + } + + /// Sets a flag to crash thread on the next wakeup. + pub fn crash_stop(&self) { + self.ctx.crash_stop(); + } +} + +struct ThreadHandle { + ctx: Arc, + _join: JoinHandle<()>, +} + +impl ThreadHandle { + /// Create a new [`ThreadHandle`] and wait until thread will enter [`Status::Sleep`] state. + fn new(ctx: Arc, join: JoinHandle<()>) -> Self { + let mut status = ctx.mutex.lock(); + // wait until thread will go into the first yield + while *status != Status::Sleep { + ctx.condvar.wait(&mut status); + } + drop(status); + + Self { ctx, _join: join } + } + + /// Allows thread to execute one step of its execution. + /// Returns [`Status`] of the thread after the step. + fn step(&self) -> Status { + let mut status = self.ctx.mutex.lock(); + assert!(matches!(*status, Status::Sleep)); + + *status = Status::Running; + self.ctx.condvar.notify_all(); + + while *status == Status::Running { + self.ctx.condvar.wait(&mut status); + } + + *status + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Status { + /// Thread is running. + Running, + /// Waiting for event to complete, will be resumed by the executor step, once wakeup flag is set. + Sleep, + /// Thread finished execution. + Finished, +} + +const NO_WAKEUP: u8 = 0; +const PENDING_WAKEUP: u8 = 1; + +pub struct ThreadContext { + id: AtomicU32, + // used to block thread until it is woken up + mutex: parking_lot::Mutex, + condvar: parking_lot::Condvar, + // used as a flag to indicate runtime that thread is ready to be woken up + wakeup: AtomicU8, + clock: OnceLock>, + // execution result, set by exit() call + result: parking_lot::Mutex<(i32, String)>, + // determines if process should be killed on receiving panic + allow_panic: AtomicBool, + // acts as a signal that thread should crash itself on the next wakeup + crash_request: AtomicBool, +} + +impl ThreadContext { + pub(crate) fn new() -> Self { + Self { + id: AtomicU32::new(0), + mutex: parking_lot::Mutex::new(Status::Running), + condvar: parking_lot::Condvar::new(), + wakeup: AtomicU8::new(NO_WAKEUP), + clock: OnceLock::new(), + result: parking_lot::Mutex::new((-1, String::new())), + allow_panic: AtomicBool::new(false), + crash_request: AtomicBool::new(false), + } + } +} + +// Functions for executor to control thread execution. +impl ThreadContext { + /// Set atomic flag to indicate that thread is ready to be woken up. + fn inc_wake(&self) { + self.wakeup.store(PENDING_WAKEUP, Ordering::SeqCst); + } + + /// Internal function used for event queues. + pub(crate) fn schedule_wakeup(self: &Arc, after_ms: u64) { + self.clock + .get() + .unwrap() + .schedule_wakeup(after_ms, self.clone()); + } + + fn tid(&self) -> u32 { + self.id.load(Ordering::SeqCst) + } + + fn crash_stop(&self) { + let status = self.mutex.lock(); + if *status == Status::Finished { + debug!( + "trying to crash thread-{}, which is already finished", + self.tid() + ); + return; + } + assert!(matches!(*status, Status::Sleep)); + drop(status); + + self.allow_panic.store(true, Ordering::SeqCst); + self.crash_request.store(true, Ordering::SeqCst); + // set a wakeup + self.inc_wake(); + // it will panic on the next wakeup + } +} + +// Internal functions. +impl ThreadContext { + /// Blocks thread until it's woken up by the executor. If `after_ms` is 0, is will be + /// woken on the next step. If `after_ms` > 0, wakeup is scheduled after that time. + /// Otherwise wakeup is not scheduled inside `yield_me`, and should be arranged before + /// calling this function. + fn yield_me(self: &Arc, after_ms: i64) { + let mut status = self.mutex.lock(); + assert!(matches!(*status, Status::Running)); + + match after_ms.cmp(&0) { + std::cmp::Ordering::Less => { + // block until something wakes us up + } + std::cmp::Ordering::Equal => { + // tell executor that we are ready to be woken up + self.inc_wake(); + } + std::cmp::Ordering::Greater => { + // schedule wakeup + self.clock + .get() + .unwrap() + .schedule_wakeup(after_ms as u64, self.clone()); + } + } + + *status = Status::Sleep; + self.condvar.notify_all(); + + // wait until executor wakes us up + while *status != Status::Running { + self.condvar.wait(&mut status); + } + + if self.crash_request.load(Ordering::SeqCst) { + panic!("crashed by request"); + } + } + + /// Called only once, exactly before thread finishes execution. + fn finish_me(&self) { + let mut status = self.mutex.lock(); + assert!(matches!(*status, Status::Running)); + + *status = Status::Finished; + { + let mut result = self.result.lock(); + if result.0 == -1 { + *result = (0, "finished normally".to_owned()); + } + } + self.condvar.notify_all(); + } +} + +/// Invokes the given closure with a reference to the current thread [`ThreadContext`]. +#[inline(always)] +fn with_thread_context(f: impl FnOnce(&Arc) -> T) -> T { + thread_local!(static THREAD_DATA: Arc = Arc::new(ThreadContext::new())); + THREAD_DATA.with(f) +} + +/// Waker is used to wake up threads that are blocked on condition. +/// It keeps track of contexts [`Arc`] and can increment the counter +/// of several contexts to send a notification. +pub struct Waker { + // contexts that are waiting for a notification + contexts: parking_lot::Mutex; 8]>>, +} + +impl Default for Waker { + fn default() -> Self { + Self::new() + } +} + +impl Waker { + pub fn new() -> Self { + Self { + contexts: parking_lot::Mutex::new(smallvec::SmallVec::new()), + } + } + + /// Subscribe current thread to receive a wake notification later. + pub fn wake_me_later(&self) { + with_thread_context(|ctx| { + self.contexts.lock().push(ctx.clone()); + }); + } + + /// Wake up all threads that are waiting for a notification and clear the list. + pub fn wake_all(&self) { + let mut v = self.contexts.lock(); + for ctx in v.iter() { + ctx.inc_wake(); + } + v.clear(); + } +} + +/// See [`ThreadContext::yield_me`]. +pub fn yield_me(after_ms: i64) { + with_thread_context(|ctx| ctx.yield_me(after_ms)) +} + +/// Get current time. +pub fn now() -> u64 { + with_thread_context(|ctx| ctx.clock.get().unwrap().now()) +} + +pub fn exit(code: i32, msg: String) { + with_thread_context(|ctx| { + ctx.allow_panic.store(true, Ordering::SeqCst); + let mut result = ctx.result.lock(); + *result = (code, msg); + panic!("exit"); + }); +} + +pub(crate) fn get_thread_ctx() -> Arc { + with_thread_context(|ctx| ctx.clone()) +} + +/// Trait for polling channels until they have something. +pub trait PollSome { + /// Schedule wakeup for message arrival. + fn wake_me(&self); + + /// Check if channel has a ready message. + fn has_some(&self) -> bool; +} + +/// Blocks current thread until one of the channels has a ready message. Returns +/// index of the channel that has a message. If timeout is reached, returns None. +/// +/// Negative timeout means block forever. Zero timeout means check channels and return +/// immediately. Positive timeout means block until timeout is reached. +pub fn epoll_chans(chans: &[Box], timeout: i64) -> Option { + let deadline = if timeout < 0 { + 0 + } else { + now() + timeout as u64 + }; + + loop { + for chan in chans { + chan.wake_me() + } + + for (i, chan) in chans.iter().enumerate() { + if chan.has_some() { + return Some(i); + } + } + + if timeout < 0 { + // block until wakeup + yield_me(-1); + } else { + let current_time = now(); + if current_time >= deadline { + return None; + } + + yield_me((deadline - current_time) as i64); + } + } +} diff --git a/libs/desim/src/lib.rs b/libs/desim/src/lib.rs new file mode 100644 index 0000000000..14f5a885c5 --- /dev/null +++ b/libs/desim/src/lib.rs @@ -0,0 +1,8 @@ +pub mod chan; +pub mod executor; +pub mod network; +pub mod node_os; +pub mod options; +pub mod proto; +pub mod time; +pub mod world; diff --git a/libs/desim/src/network.rs b/libs/desim/src/network.rs new file mode 100644 index 0000000000..e15a714daa --- /dev/null +++ b/libs/desim/src/network.rs @@ -0,0 +1,451 @@ +use std::{ + cmp::Ordering, + collections::{BinaryHeap, VecDeque}, + fmt::{self, Debug}, + ops::DerefMut, + sync::{mpsc, Arc}, +}; + +use parking_lot::{ + lock_api::{MappedMutexGuard, MutexGuard}, + Mutex, RawMutex, +}; +use rand::rngs::StdRng; +use tracing::debug; + +use crate::{ + executor::{self, ThreadContext}, + options::NetworkOptions, + proto::NetEvent, + proto::NodeEvent, +}; + +use super::{chan::Chan, proto::AnyMessage}; + +pub struct NetworkTask { + options: Arc, + connections: Mutex>, + /// min-heap of connections having something to deliver. + events: Mutex>, + task_context: Arc, +} + +impl NetworkTask { + pub fn start_new(options: Arc, tx: mpsc::Sender>) { + let ctx = executor::get_thread_ctx(); + let task = Arc::new(Self { + options, + connections: Mutex::new(Vec::new()), + events: Mutex::new(BinaryHeap::new()), + task_context: ctx, + }); + + // send the task upstream + tx.send(task.clone()).unwrap(); + + // start the task + task.start(); + } + + pub fn start_new_connection(self: &Arc, rng: StdRng, dst_accept: Chan) -> TCP { + let now = executor::now(); + let connection_id = self.connections.lock().len(); + + let vc = VirtualConnection { + connection_id, + dst_accept, + dst_sockets: [Chan::new(), Chan::new()], + state: Mutex::new(ConnectionState { + buffers: [NetworkBuffer::new(None), NetworkBuffer::new(Some(now))], + rng, + }), + }; + vc.schedule_timeout(self); + vc.send_connect(self); + + let recv_chan = vc.dst_sockets[0].clone(); + self.connections.lock().push(vc); + + TCP { + net: self.clone(), + conn_id: connection_id, + dir: 0, + recv_chan, + } + } +} + +// private functions +impl NetworkTask { + /// Schedule to wakeup network task (self) `after_ms` later to deliver + /// messages of connection `id`. + fn schedule(&self, id: usize, after_ms: u64) { + self.events.lock().push(Event { + time: executor::now() + after_ms, + conn_id: id, + }); + self.task_context.schedule_wakeup(after_ms); + } + + /// Get locked connection `id`. + fn get(&self, id: usize) -> MappedMutexGuard<'_, RawMutex, VirtualConnection> { + MutexGuard::map(self.connections.lock(), |connections| { + connections.get_mut(id).unwrap() + }) + } + + fn collect_pending_events(&self, now: u64, vec: &mut Vec) { + vec.clear(); + let mut events = self.events.lock(); + while let Some(event) = events.peek() { + if event.time > now { + break; + } + let event = events.pop().unwrap(); + vec.push(event); + } + } + + fn start(self: &Arc) { + debug!("started network task"); + + let mut events = Vec::new(); + loop { + let now = executor::now(); + self.collect_pending_events(now, &mut events); + + for event in events.drain(..) { + let conn = self.get(event.conn_id); + conn.process(self); + } + + // block until wakeup + executor::yield_me(-1); + } + } +} + +// 0 - from node(0) to node(1) +// 1 - from node(1) to node(0) +type MessageDirection = u8; + +fn sender_str(dir: MessageDirection) -> &'static str { + match dir { + 0 => "client", + 1 => "server", + _ => unreachable!(), + } +} + +fn receiver_str(dir: MessageDirection) -> &'static str { + match dir { + 0 => "server", + 1 => "client", + _ => unreachable!(), + } +} + +/// Virtual connection between two nodes. +/// Node 0 is the creator of the connection (client), +/// and node 1 is the acceptor (server). +struct VirtualConnection { + connection_id: usize, + /// one-off chan, used to deliver Accept message to dst + dst_accept: Chan, + /// message sinks + dst_sockets: [Chan; 2], + state: Mutex, +} + +struct ConnectionState { + buffers: [NetworkBuffer; 2], + rng: StdRng, +} + +impl VirtualConnection { + /// Notify the future about the possible timeout. + fn schedule_timeout(&self, net: &NetworkTask) { + if let Some(timeout) = net.options.keepalive_timeout { + net.schedule(self.connection_id, timeout); + } + } + + /// Send the handshake (Accept) to the server. + fn send_connect(&self, net: &NetworkTask) { + let now = executor::now(); + let mut state = self.state.lock(); + let delay = net.options.connect_delay.delay(&mut state.rng); + let buffer = &mut state.buffers[0]; + assert!(buffer.buf.is_empty()); + assert!(!buffer.recv_closed); + assert!(!buffer.send_closed); + assert!(buffer.last_recv.is_none()); + + let delay = if let Some(ms) = delay { + ms + } else { + debug!("NET: TCP #{} dropped connect", self.connection_id); + buffer.send_closed = true; + return; + }; + + // Send a message into the future. + buffer + .buf + .push_back((now + delay, AnyMessage::InternalConnect)); + net.schedule(self.connection_id, delay); + } + + /// Transmit some of the messages from the buffer to the nodes. + fn process(&self, net: &Arc) { + let now = executor::now(); + + let mut state = self.state.lock(); + + for direction in 0..2 { + self.process_direction( + net, + state.deref_mut(), + now, + direction as MessageDirection, + &self.dst_sockets[direction ^ 1], + ); + } + + // Close the one side of the connection by timeout if the node + // has not received any messages for a long time. + if let Some(timeout) = net.options.keepalive_timeout { + let mut to_close = [false, false]; + for direction in 0..2 { + let buffer = &mut state.buffers[direction]; + if buffer.recv_closed { + continue; + } + if let Some(last_recv) = buffer.last_recv { + if now - last_recv >= timeout { + debug!( + "NET: connection {} timed out at {}", + self.connection_id, + receiver_str(direction as MessageDirection) + ); + let node_idx = direction ^ 1; + to_close[node_idx] = true; + } + } + } + drop(state); + + for (node_idx, should_close) in to_close.iter().enumerate() { + if *should_close { + self.close(node_idx); + } + } + } + } + + /// Process messages in the buffer in the given direction. + fn process_direction( + &self, + net: &Arc, + state: &mut ConnectionState, + now: u64, + direction: MessageDirection, + to_socket: &Chan, + ) { + let buffer = &mut state.buffers[direction as usize]; + if buffer.recv_closed { + assert!(buffer.buf.is_empty()); + } + + while !buffer.buf.is_empty() && buffer.buf.front().unwrap().0 <= now { + let msg = buffer.buf.pop_front().unwrap().1; + + buffer.last_recv = Some(now); + self.schedule_timeout(net); + + if let AnyMessage::InternalConnect = msg { + // TODO: assert to_socket is the server + let server_to_client = TCP { + net: net.clone(), + conn_id: self.connection_id, + dir: direction ^ 1, + recv_chan: to_socket.clone(), + }; + // special case, we need to deliver new connection to a separate channel + self.dst_accept.send(NodeEvent::Accept(server_to_client)); + } else { + to_socket.send(NetEvent::Message(msg)); + } + } + } + + /// Try to send a message to the buffer, optionally dropping it and + /// determining delivery timestamp. + fn send(&self, net: &NetworkTask, direction: MessageDirection, msg: AnyMessage) { + let now = executor::now(); + let mut state = self.state.lock(); + + let (delay, close) = if let Some(ms) = net.options.send_delay.delay(&mut state.rng) { + (ms, false) + } else { + (0, true) + }; + + let buffer = &mut state.buffers[direction as usize]; + if buffer.send_closed { + debug!( + "NET: TCP #{} dropped message {:?} (broken pipe)", + self.connection_id, msg + ); + return; + } + + if close { + debug!( + "NET: TCP #{} dropped message {:?} (pipe just broke)", + self.connection_id, msg + ); + buffer.send_closed = true; + return; + } + + if buffer.recv_closed { + debug!( + "NET: TCP #{} dropped message {:?} (recv closed)", + self.connection_id, msg + ); + return; + } + + // Send a message into the future. + buffer.buf.push_back((now + delay, msg)); + net.schedule(self.connection_id, delay); + } + + /// Close the connection. Only one side of the connection will be closed, + /// and no further messages will be delivered. The other side will not be notified. + fn close(&self, node_idx: usize) { + let mut state = self.state.lock(); + let recv_buffer = &mut state.buffers[1 ^ node_idx]; + if recv_buffer.recv_closed { + debug!( + "NET: TCP #{} closed twice at {}", + self.connection_id, + sender_str(node_idx as MessageDirection), + ); + return; + } + + debug!( + "NET: TCP #{} closed at {}", + self.connection_id, + sender_str(node_idx as MessageDirection), + ); + recv_buffer.recv_closed = true; + for msg in recv_buffer.buf.drain(..) { + debug!( + "NET: TCP #{} dropped message {:?} (closed)", + self.connection_id, msg + ); + } + + let send_buffer = &mut state.buffers[node_idx]; + send_buffer.send_closed = true; + drop(state); + + // TODO: notify the other side? + + self.dst_sockets[node_idx].send(NetEvent::Closed); + } +} + +struct NetworkBuffer { + /// Messages paired with time of delivery + buf: VecDeque<(u64, AnyMessage)>, + /// True if the connection is closed on the receiving side, + /// i.e. no more messages from the buffer will be delivered. + recv_closed: bool, + /// True if the connection is closed on the sending side, + /// i.e. no more messages will be added to the buffer. + send_closed: bool, + /// Last time a message was delivered from the buffer. + /// If None, it means that the server is the receiver and + /// it has not yet aware of this connection (i.e. has not + /// received the Accept). + last_recv: Option, +} + +impl NetworkBuffer { + fn new(last_recv: Option) -> Self { + Self { + buf: VecDeque::new(), + recv_closed: false, + send_closed: false, + last_recv, + } + } +} + +/// Single end of a bidirectional network stream without reordering (TCP-like). +/// Reads are implemented using channels, writes go to the buffer inside VirtualConnection. +pub struct TCP { + net: Arc, + conn_id: usize, + dir: MessageDirection, + recv_chan: Chan, +} + +impl Debug for TCP { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TCP #{} ({})", self.conn_id, sender_str(self.dir),) + } +} + +impl TCP { + /// Send a message to the other side. It's guaranteed that it will not arrive + /// before the arrival of all messages sent earlier. + pub fn send(&self, msg: AnyMessage) { + let conn = self.net.get(self.conn_id); + conn.send(&self.net, self.dir, msg); + } + + /// Get a channel to receive incoming messages. + pub fn recv_chan(&self) -> Chan { + self.recv_chan.clone() + } + + pub fn connection_id(&self) -> usize { + self.conn_id + } + + pub fn close(&self) { + let conn = self.net.get(self.conn_id); + conn.close(self.dir as usize); + } +} +struct Event { + time: u64, + conn_id: usize, +} + +// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here +// to get that. +impl PartialOrd for Event { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Event { + fn cmp(&self, other: &Self) -> Ordering { + (other.time, other.conn_id).cmp(&(self.time, self.conn_id)) + } +} + +impl PartialEq for Event { + fn eq(&self, other: &Self) -> bool { + (other.time, other.conn_id) == (self.time, self.conn_id) + } +} + +impl Eq for Event {} diff --git a/libs/desim/src/node_os.rs b/libs/desim/src/node_os.rs new file mode 100644 index 0000000000..7744a9f5e1 --- /dev/null +++ b/libs/desim/src/node_os.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +use rand::Rng; + +use crate::proto::NodeEvent; + +use super::{ + chan::Chan, + network::TCP, + world::{Node, NodeId, World}, +}; + +/// Abstraction with all functions (aka syscalls) available to the node. +#[derive(Clone)] +pub struct NodeOs { + world: Arc, + internal: Arc, +} + +impl NodeOs { + pub fn new(world: Arc, internal: Arc) -> NodeOs { + NodeOs { world, internal } + } + + /// Get the node id. + pub fn id(&self) -> NodeId { + self.internal.id + } + + /// Opens a bidirectional connection with the other node. Always successful. + pub fn open_tcp(&self, dst: NodeId) -> TCP { + self.world.open_tcp(dst) + } + + /// Returns a channel to receive node events (socket Accept and internal messages). + pub fn node_events(&self) -> Chan { + self.internal.node_events() + } + + /// Get current time. + pub fn now(&self) -> u64 { + self.world.now() + } + + /// Generate a random number in range [0, max). + pub fn random(&self, max: u64) -> u64 { + self.internal.rng.lock().gen_range(0..max) + } + + /// Append a new event to the world event log. + pub fn log_event(&self, data: String) { + self.internal.log_event(data) + } +} diff --git a/libs/desim/src/options.rs b/libs/desim/src/options.rs new file mode 100644 index 0000000000..5da7c2c482 --- /dev/null +++ b/libs/desim/src/options.rs @@ -0,0 +1,50 @@ +use rand::{rngs::StdRng, Rng}; + +/// Describes random delays and failures. Delay will be uniformly distributed in [min, max]. +/// Connection failure will occur with the probablity fail_prob. +#[derive(Clone, Debug)] +pub struct Delay { + pub min: u64, + pub max: u64, + pub fail_prob: f64, // [0; 1] +} + +impl Delay { + /// Create a struct with no delay, no failures. + pub fn empty() -> Delay { + Delay { + min: 0, + max: 0, + fail_prob: 0.0, + } + } + + /// Create a struct with a fixed delay. + pub fn fixed(ms: u64) -> Delay { + Delay { + min: ms, + max: ms, + fail_prob: 0.0, + } + } + + /// Generate a random delay in range [min, max]. Return None if the + /// message should be dropped. + pub fn delay(&self, rng: &mut StdRng) -> Option { + if rng.gen_bool(self.fail_prob) { + return None; + } + Some(rng.gen_range(self.min..=self.max)) + } +} + +/// Describes network settings. All network packets will be subjected to the same delays and failures. +#[derive(Clone, Debug)] +pub struct NetworkOptions { + /// Connection will be automatically closed after this timeout if no data is received. + pub keepalive_timeout: Option, + /// New connections will be delayed by this amount of time. + pub connect_delay: Delay, + /// Each message will be delayed by this amount of time. + pub send_delay: Delay, +} diff --git a/libs/desim/src/proto.rs b/libs/desim/src/proto.rs new file mode 100644 index 0000000000..92a7e8a27d --- /dev/null +++ b/libs/desim/src/proto.rs @@ -0,0 +1,63 @@ +use std::fmt::Debug; + +use bytes::Bytes; +use utils::lsn::Lsn; + +use crate::{network::TCP, world::NodeId}; + +/// Internal node events. +#[derive(Debug)] +pub enum NodeEvent { + Accept(TCP), + Internal(AnyMessage), +} + +/// Events that are coming from a network socket. +#[derive(Clone, Debug)] +pub enum NetEvent { + Message(AnyMessage), + Closed, +} + +/// Custom events generated throughout the simulation. Can be used by the test to verify the correctness. +#[derive(Debug)] +pub struct SimEvent { + pub time: u64, + pub node: NodeId, + pub data: String, +} + +/// Umbrella type for all possible flavours of messages. These events can be sent over network +/// or to an internal node events channel. +#[derive(Clone)] +pub enum AnyMessage { + /// Not used, empty placeholder. + None, + /// Used internally for notifying node about new incoming connection. + InternalConnect, + Just32(u32), + ReplCell(ReplCell), + Bytes(Bytes), + LSN(u64), +} + +impl Debug for AnyMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AnyMessage::None => write!(f, "None"), + AnyMessage::InternalConnect => write!(f, "InternalConnect"), + AnyMessage::Just32(v) => write!(f, "Just32({})", v), + AnyMessage::ReplCell(v) => write!(f, "ReplCell({:?})", v), + AnyMessage::Bytes(v) => write!(f, "Bytes({})", hex::encode(v)), + AnyMessage::LSN(v) => write!(f, "LSN({})", Lsn(*v)), + } + } +} + +/// Used in reliable_copy_test.rs +#[derive(Clone, Debug)] +pub struct ReplCell { + pub value: u32, + pub client_id: u32, + pub seqno: u32, +} diff --git a/libs/desim/src/time.rs b/libs/desim/src/time.rs new file mode 100644 index 0000000000..7bb71db95c --- /dev/null +++ b/libs/desim/src/time.rs @@ -0,0 +1,129 @@ +use std::{ + cmp::Ordering, + collections::BinaryHeap, + ops::DerefMut, + sync::{ + atomic::{AtomicU32, AtomicU64}, + Arc, + }, +}; + +use parking_lot::Mutex; +use tracing::trace; + +use crate::executor::ThreadContext; + +/// Holds current time and all pending wakeup events. +pub struct Timing { + /// Current world's time. + current_time: AtomicU64, + /// Pending timers. + queue: Mutex>, + /// Global nonce. Makes picking events from binary heap queue deterministic + /// by appending a number to events with the same timestamp. + nonce: AtomicU32, + /// Used to schedule fake events. + fake_context: Arc, +} + +impl Default for Timing { + fn default() -> Self { + Self::new() + } +} + +impl Timing { + /// Create a new empty clock with time set to 0. + pub fn new() -> Timing { + Timing { + current_time: AtomicU64::new(0), + queue: Mutex::new(BinaryHeap::new()), + nonce: AtomicU32::new(0), + fake_context: Arc::new(ThreadContext::new()), + } + } + + /// Return the current world's time. + pub fn now(&self) -> u64 { + self.current_time.load(std::sync::atomic::Ordering::SeqCst) + } + + /// Tick-tock the global clock. Return the event ready to be processed + /// or move the clock forward and then return the event. + pub(crate) fn step(&self) -> Option> { + let mut queue = self.queue.lock(); + + if queue.is_empty() { + // no future events + return None; + } + + if !self.is_event_ready(queue.deref_mut()) { + let next_time = queue.peek().unwrap().time; + self.current_time + .store(next_time, std::sync::atomic::Ordering::SeqCst); + trace!("rewind time to {}", next_time); + assert!(self.is_event_ready(queue.deref_mut())); + } + + Some(queue.pop().unwrap().wake_context) + } + + /// Append an event to the queue, to wakeup the thread in `ms` milliseconds. + pub(crate) fn schedule_wakeup(&self, ms: u64, wake_context: Arc) { + self.nonce.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let nonce = self.nonce.load(std::sync::atomic::Ordering::SeqCst); + self.queue.lock().push(Pending { + time: self.now() + ms, + nonce, + wake_context, + }) + } + + /// Append a fake event to the queue, to prevent clocks from skipping this time. + pub fn schedule_fake(&self, ms: u64) { + self.queue.lock().push(Pending { + time: self.now() + ms, + nonce: 0, + wake_context: self.fake_context.clone(), + }); + } + + /// Return true if there is a ready event. + fn is_event_ready(&self, queue: &mut BinaryHeap) -> bool { + queue.peek().map_or(false, |x| x.time <= self.now()) + } + + /// Clear all pending events. + pub(crate) fn clear(&self) { + self.queue.lock().clear(); + } +} + +struct Pending { + time: u64, + nonce: u32, + wake_context: Arc, +} + +// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here +// to get that. +impl PartialOrd for Pending { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Pending { + fn cmp(&self, other: &Self) -> Ordering { + (other.time, other.nonce).cmp(&(self.time, self.nonce)) + } +} + +impl PartialEq for Pending { + fn eq(&self, other: &Self) -> bool { + (other.time, other.nonce) == (self.time, self.nonce) + } +} + +impl Eq for Pending {} diff --git a/libs/desim/src/world.rs b/libs/desim/src/world.rs new file mode 100644 index 0000000000..7d60be04b5 --- /dev/null +++ b/libs/desim/src/world.rs @@ -0,0 +1,180 @@ +use parking_lot::Mutex; +use rand::{rngs::StdRng, SeedableRng}; +use std::{ + ops::DerefMut, + sync::{mpsc, Arc}, +}; + +use crate::{ + executor::{ExternalHandle, Runtime}, + network::NetworkTask, + options::NetworkOptions, + proto::{NodeEvent, SimEvent}, + time::Timing, +}; + +use super::{chan::Chan, network::TCP, node_os::NodeOs}; + +pub type NodeId = u32; + +/// World contains simulation state. +pub struct World { + nodes: Mutex>>, + /// Random number generator. + rng: Mutex, + /// Internal event log. + events: Mutex>, + /// Separate task that processes all network messages. + network_task: Arc, + /// Runtime for running threads and moving time. + runtime: Mutex, + /// To get current time. + timing: Arc, +} + +impl World { + pub fn new(seed: u64, options: Arc) -> World { + let timing = Arc::new(Timing::new()); + let mut runtime = Runtime::new(timing.clone()); + + let (tx, rx) = mpsc::channel(); + + runtime.spawn(move || { + // create and start network background thread, and send it back via the channel + NetworkTask::start_new(options, tx) + }); + + // wait for the network task to start + while runtime.step() {} + + let network_task = rx.recv().unwrap(); + + World { + nodes: Mutex::new(Vec::new()), + rng: Mutex::new(StdRng::seed_from_u64(seed)), + events: Mutex::new(Vec::new()), + network_task, + runtime: Mutex::new(runtime), + timing, + } + } + + pub fn step(&self) -> bool { + self.runtime.lock().step() + } + + pub fn get_thread_step_count(&self) -> u64 { + self.runtime.lock().step_counter + } + + /// Create a new random number generator. + pub fn new_rng(&self) -> StdRng { + let mut rng = self.rng.lock(); + StdRng::from_rng(rng.deref_mut()).unwrap() + } + + /// Create a new node. + pub fn new_node(self: &Arc) -> Arc { + let mut nodes = self.nodes.lock(); + let id = nodes.len() as NodeId; + let node = Arc::new(Node::new(id, self.clone(), self.new_rng())); + nodes.push(node.clone()); + node + } + + /// Get an internal node state by id. + fn get_node(&self, id: NodeId) -> Option> { + let nodes = self.nodes.lock(); + let num = id as usize; + if num < nodes.len() { + Some(nodes[num].clone()) + } else { + None + } + } + + pub fn stop_all(&self) { + self.runtime.lock().crash_all_threads(); + } + + /// Returns a writable end of a TCP connection, to send src->dst messages. + pub fn open_tcp(self: &Arc, dst: NodeId) -> TCP { + // TODO: replace unwrap() with /dev/null socket. + let dst = self.get_node(dst).unwrap(); + let dst_accept = dst.node_events.lock().clone(); + + let rng = self.new_rng(); + self.network_task.start_new_connection(rng, dst_accept) + } + + /// Get current time. + pub fn now(&self) -> u64 { + self.timing.now() + } + + /// Get a copy of the internal clock. + pub fn clock(&self) -> Arc { + self.timing.clone() + } + + pub fn add_event(&self, node: NodeId, data: String) { + let time = self.now(); + self.events.lock().push(SimEvent { time, node, data }); + } + + pub fn take_events(&self) -> Vec { + let mut events = self.events.lock(); + let mut res = Vec::new(); + std::mem::swap(&mut res, &mut events); + res + } + + pub fn deallocate(&self) { + self.stop_all(); + self.timing.clear(); + self.nodes.lock().clear(); + } +} + +/// Internal node state. +pub struct Node { + pub id: NodeId, + node_events: Mutex>, + world: Arc, + pub(crate) rng: Mutex, +} + +impl Node { + pub fn new(id: NodeId, world: Arc, rng: StdRng) -> Node { + Node { + id, + node_events: Mutex::new(Chan::new()), + world, + rng: Mutex::new(rng), + } + } + + /// Spawn a new thread with this node context. + pub fn launch(self: &Arc, f: impl FnOnce(NodeOs) + Send + 'static) -> ExternalHandle { + let node = self.clone(); + let world = self.world.clone(); + self.world.runtime.lock().spawn(move || { + f(NodeOs::new(world, node.clone())); + }) + } + + /// Returns a channel to receive Accepts and internal messages. + pub fn node_events(&self) -> Chan { + self.node_events.lock().clone() + } + + /// This will drop all in-flight Accept messages. + pub fn replug_node_events(&self, chan: Chan) { + *self.node_events.lock() = chan; + } + + /// Append event to the world's log. + pub fn log_event(&self, data: String) { + self.world.add_event(self.id, data) + } +} diff --git a/libs/desim/tests/reliable_copy_test.rs b/libs/desim/tests/reliable_copy_test.rs new file mode 100644 index 0000000000..cf7bff8f5a --- /dev/null +++ b/libs/desim/tests/reliable_copy_test.rs @@ -0,0 +1,244 @@ +//! Simple test to verify that simulator is working. +#[cfg(test)] +mod reliable_copy_test { + use anyhow::Result; + use desim::executor::{self, PollSome}; + use desim::options::{Delay, NetworkOptions}; + use desim::proto::{NetEvent, NodeEvent, ReplCell}; + use desim::world::{NodeId, World}; + use desim::{node_os::NodeOs, proto::AnyMessage}; + use parking_lot::Mutex; + use std::sync::Arc; + use tracing::info; + + /// Disk storage trait and implementation. + pub trait Storage { + fn flush_pos(&self) -> u32; + fn flush(&mut self) -> Result<()>; + fn write(&mut self, t: T); + } + + #[derive(Clone)] + pub struct SharedStorage { + pub state: Arc>>, + } + + impl SharedStorage { + pub fn new() -> Self { + Self { + state: Arc::new(Mutex::new(InMemoryStorage::new())), + } + } + } + + impl Storage for SharedStorage { + fn flush_pos(&self) -> u32 { + self.state.lock().flush_pos + } + + fn flush(&mut self) -> Result<()> { + executor::yield_me(0); + self.state.lock().flush() + } + + fn write(&mut self, t: T) { + executor::yield_me(0); + self.state.lock().write(t); + } + } + + pub struct InMemoryStorage { + pub data: Vec, + pub flush_pos: u32, + } + + impl InMemoryStorage { + pub fn new() -> Self { + Self { + data: Vec::new(), + flush_pos: 0, + } + } + + pub fn flush(&mut self) -> Result<()> { + self.flush_pos = self.data.len() as u32; + Ok(()) + } + + pub fn write(&mut self, t: T) { + self.data.push(t); + } + } + + /// Server implementation. + pub fn run_server(os: NodeOs, mut storage: Box>) { + info!("started server"); + + let node_events = os.node_events(); + let mut epoll_vec: Vec> = vec![Box::new(node_events.clone())]; + let mut sockets = vec![]; + + loop { + let index = executor::epoll_chans(&epoll_vec, -1).unwrap(); + + if index == 0 { + let node_event = node_events.must_recv(); + info!("got node event: {:?}", node_event); + if let NodeEvent::Accept(tcp) = node_event { + tcp.send(AnyMessage::Just32(storage.flush_pos())); + epoll_vec.push(Box::new(tcp.recv_chan())); + sockets.push(tcp); + } + continue; + } + + let recv_chan = sockets[index - 1].recv_chan(); + let socket = &sockets[index - 1]; + + let event = recv_chan.must_recv(); + info!("got event: {:?}", event); + if let NetEvent::Message(AnyMessage::ReplCell(cell)) = event { + if cell.seqno != storage.flush_pos() { + info!("got out of order data: {:?}", cell); + continue; + } + storage.write(cell.value); + storage.flush().unwrap(); + socket.send(AnyMessage::Just32(storage.flush_pos())); + } + } + } + + /// Client copies all data from array to the remote node. + pub fn run_client(os: NodeOs, data: &[ReplCell], dst: NodeId) { + info!("started client"); + + let mut delivered = 0; + + let mut sock = os.open_tcp(dst); + let mut recv_chan = sock.recv_chan(); + + while delivered < data.len() { + let num = &data[delivered]; + info!("sending data: {:?}", num.clone()); + sock.send(AnyMessage::ReplCell(num.clone())); + + // loop { + let event = recv_chan.recv(); + match event { + NetEvent::Message(AnyMessage::Just32(flush_pos)) => { + if flush_pos == 1 + delivered as u32 { + delivered += 1; + } + } + NetEvent::Closed => { + info!("connection closed, reestablishing"); + sock = os.open_tcp(dst); + recv_chan = sock.recv_chan(); + } + _ => {} + } + + // } + } + + let sock = os.open_tcp(dst); + for num in data { + info!("sending data: {:?}", num.clone()); + sock.send(AnyMessage::ReplCell(num.clone())); + } + + info!("sent all data and finished client"); + } + + /// Run test simulations. + #[test] + fn sim_example_reliable_copy() { + utils::logging::init( + utils::logging::LogFormat::Test, + utils::logging::TracingErrorLayerEnablement::Disabled, + utils::logging::Output::Stdout, + ) + .expect("logging init failed"); + + let delay = Delay { + min: 1, + max: 60, + fail_prob: 0.4, + }; + + let network = NetworkOptions { + keepalive_timeout: Some(50), + connect_delay: delay.clone(), + send_delay: delay.clone(), + }; + + for seed in 0..20 { + let u32_data: [u32; 5] = [1, 2, 3, 4, 5]; + let data = u32_to_cells(&u32_data, 1); + let world = Arc::new(World::new(seed, Arc::new(network.clone()))); + + start_simulation(Options { + world, + time_limit: 1_000_000, + client_fn: Box::new(move |os, server_id| run_client(os, &data, server_id)), + u32_data, + }); + } + } + + pub struct Options { + pub world: Arc, + pub time_limit: u64, + pub u32_data: [u32; 5], + pub client_fn: Box, + } + + pub fn start_simulation(options: Options) { + let world = options.world; + + let client_node = world.new_node(); + let server_node = world.new_node(); + let server_id = server_node.id; + + // start the client thread + client_node.launch(move |os| { + let client_fn = options.client_fn; + client_fn(os, server_id); + }); + + // start the server thread + let shared_storage = SharedStorage::new(); + let server_storage = shared_storage.clone(); + server_node.launch(move |os| run_server(os, Box::new(server_storage))); + + while world.step() && world.now() < options.time_limit {} + + let disk_data = shared_storage.state.lock().data.clone(); + assert!(verify_data(&disk_data, &options.u32_data[..])); + } + + pub fn u32_to_cells(data: &[u32], client_id: u32) -> Vec { + let mut res = Vec::new(); + for (i, _) in data.iter().enumerate() { + res.push(ReplCell { + client_id, + seqno: i as u32, + value: data[i], + }); + } + res + } + + fn verify_data(disk_data: &[u32], data: &[u32]) -> bool { + if disk_data.len() != data.len() { + return false; + } + for i in 0..data.len() { + if disk_data[i] != data[i] { + return false; + } + } + true + } +} diff --git a/libs/postgres_ffi/src/xlog_utils.rs b/libs/postgres_ffi/src/xlog_utils.rs index a863fad269..977653848d 100644 --- a/libs/postgres_ffi/src/xlog_utils.rs +++ b/libs/postgres_ffi/src/xlog_utils.rs @@ -431,11 +431,11 @@ pub fn generate_wal_segment(segno: u64, system_id: u64, lsn: Lsn) -> Result anyhow::Result<()> { println!("cargo:rustc-link-lib=static=walproposer"); println!("cargo:rustc-link-search={walproposer_lib_search_str}"); + // Rebuild crate when libwalproposer.a changes + println!("cargo:rerun-if-changed={walproposer_lib_search_str}/libwalproposer.a"); + let pg_config_bin = pg_install_abs.join("v16").join("bin").join("pg_config"); let inc_server_path: String = if pg_config_bin.exists() { let output = Command::new(pg_config_bin) @@ -79,6 +82,7 @@ fn main() -> anyhow::Result<()> { .allowlist_function("WalProposerBroadcast") .allowlist_function("WalProposerPoll") .allowlist_function("WalProposerFree") + .allowlist_function("SafekeeperStateDesiredEvents") .allowlist_var("DEBUG5") .allowlist_var("DEBUG4") .allowlist_var("DEBUG3") diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index 1f7bf952dc..8317e2fa03 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -22,6 +22,7 @@ use crate::bindings::WalProposerExecStatusType; use crate::bindings::WalproposerShmemState; use crate::bindings::XLogRecPtr; use crate::walproposer::ApiImpl; +use crate::walproposer::StreamingCallback; use crate::walproposer::WaitResult; extern "C" fn get_shmem_state(wp: *mut WalProposer) -> *mut WalproposerShmemState { @@ -36,7 +37,8 @@ extern "C" fn start_streaming(wp: *mut WalProposer, startpos: XLogRecPtr) { unsafe { let callback_data = (*(*wp).config).callback_data; let api = callback_data as *mut Box; - (*api).start_streaming(startpos) + let callback = StreamingCallback::new(wp); + (*api).start_streaming(startpos, &callback); } } @@ -134,19 +136,18 @@ extern "C" fn conn_async_read( unsafe { let callback_data = (*(*(*sk).wp).config).callback_data; let api = callback_data as *mut Box; - let (res, result) = (*api).conn_async_read(&mut (*sk)); // This function has guarantee that returned buf will be valid until // the next call. So we can store a Vec in each Safekeeper and reuse // it on the next call. let mut inbuf = take_vec_u8(&mut (*sk).inbuf).unwrap_or_default(); - inbuf.clear(); - inbuf.extend_from_slice(res); + + let result = (*api).conn_async_read(&mut (*sk), &mut inbuf); // Put a Vec back to sk->inbuf and return data ptr. + *amount = inbuf.len() as i32; *buf = store_vec_u8(&mut (*sk).inbuf, inbuf); - *amount = res.len() as i32; result } @@ -182,6 +183,10 @@ extern "C" fn recovery_download(wp: *mut WalProposer, sk: *mut Safekeeper) -> bo unsafe { let callback_data = (*(*(*sk).wp).config).callback_data; let api = callback_data as *mut Box; + + // currently `recovery_download` is always called right after election + (*api).after_election(&mut (*wp)); + (*api).recovery_download(&mut (*wp), &mut (*sk)) } } @@ -277,7 +282,8 @@ extern "C" fn wait_event_set( } WaitResult::Timeout => { *event_sk = std::ptr::null_mut(); - *events = crate::bindings::WL_TIMEOUT; + // WaitEventSetWait returns 0 for timeout. + *events = 0; 0 } WaitResult::Network(sk, event_mask) => { @@ -340,7 +346,7 @@ extern "C" fn log_internal( } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Level { Debug5, Debug4, diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index 8ab8fb1a07..13fade220c 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -1,13 +1,13 @@ use std::ffi::CString; use postgres_ffi::WAL_SEGMENT_SIZE; -use utils::id::TenantTimelineId; +use utils::{id::TenantTimelineId, lsn::Lsn}; use crate::{ api_bindings::{create_api, take_vec_u8, Level}, bindings::{ - NeonWALReadResult, Safekeeper, WalProposer, WalProposerConfig, WalProposerCreate, - WalProposerFree, WalProposerStart, + NeonWALReadResult, Safekeeper, WalProposer, WalProposerBroadcast, WalProposerConfig, + WalProposerCreate, WalProposerFree, WalProposerPoll, WalProposerStart, }, }; @@ -16,11 +16,11 @@ use crate::{ /// /// Refer to `pgxn/neon/walproposer.h` for documentation. pub trait ApiImpl { - fn get_shmem_state(&self) -> &mut crate::bindings::WalproposerShmemState { + fn get_shmem_state(&self) -> *mut crate::bindings::WalproposerShmemState { todo!() } - fn start_streaming(&self, _startpos: u64) { + fn start_streaming(&self, _startpos: u64, _callback: &StreamingCallback) { todo!() } @@ -70,7 +70,11 @@ pub trait ApiImpl { todo!() } - fn conn_async_read(&self, _sk: &mut Safekeeper) -> (&[u8], crate::bindings::PGAsyncReadResult) { + fn conn_async_read( + &self, + _sk: &mut Safekeeper, + _vec: &mut Vec, + ) -> crate::bindings::PGAsyncReadResult { todo!() } @@ -151,12 +155,14 @@ pub trait ApiImpl { } } +#[derive(Debug)] pub enum WaitResult { Latch, Timeout, Network(*mut Safekeeper, u32), } +#[derive(Clone)] pub struct Config { /// Tenant and timeline id pub ttid: TenantTimelineId, @@ -242,6 +248,24 @@ impl Drop for Wrapper { } } +pub struct StreamingCallback { + wp: *mut WalProposer, +} + +impl StreamingCallback { + pub fn new(wp: *mut WalProposer) -> StreamingCallback { + StreamingCallback { wp } + } + + pub fn broadcast(&self, startpos: Lsn, endpos: Lsn) { + unsafe { WalProposerBroadcast(self.wp, startpos.0, endpos.0) } + } + + pub fn poll(&self) { + unsafe { WalProposerPoll(self.wp) } + } +} + #[cfg(test)] mod tests { use core::panic; @@ -344,14 +368,13 @@ mod tests { fn conn_async_read( &self, _: &mut crate::bindings::Safekeeper, - ) -> (&[u8], crate::bindings::PGAsyncReadResult) { + vec: &mut Vec, + ) -> crate::bindings::PGAsyncReadResult { println!("conn_async_read"); let reply = self.next_safekeeper_reply(); println!("conn_async_read result: {:?}", reply); - ( - reply, - crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS, - ) + vec.extend_from_slice(reply); + crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS } fn conn_blocking_write(&self, _: &mut crate::bindings::Safekeeper, buf: &[u8]) -> bool { diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index 93d1dcab35..12ceac0191 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -346,7 +346,7 @@ impl WalIngest { let info = decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK; if info == pg_constants::XLOG_LOGICAL_MESSAGE { - let xlrec = XlLogicalMessage::decode(&mut buf); + let xlrec = crate::walrecord::XlLogicalMessage::decode(&mut buf); let prefix = std::str::from_utf8(&buf[0..xlrec.prefix_size - 1])?; let message = &buf[xlrec.prefix_size..xlrec.prefix_size + xlrec.message_size]; if prefix == "neon-test" { diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index 171af7d2aa..0d5007ef73 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -688,7 +688,7 @@ RecvAcceptorGreeting(Safekeeper *sk) if (!AsyncReadMessage(sk, (AcceptorProposerMessage *) &sk->greetResponse)) return; - wp_log(LOG, "received AcceptorGreeting from safekeeper %s:%s", sk->host, sk->port); + wp_log(LOG, "received AcceptorGreeting from safekeeper %s:%s, term=" INT64_FORMAT, sk->host, sk->port, sk->greetResponse.term); /* Protocol is all good, move to voting. */ sk->state = SS_VOTING; @@ -922,6 +922,7 @@ static void DetermineEpochStartLsn(WalProposer *wp) { TermHistory *dth; + int n_ready = 0; wp->propEpochStartLsn = InvalidXLogRecPtr; wp->donorEpoch = 0; @@ -932,6 +933,8 @@ DetermineEpochStartLsn(WalProposer *wp) { if (wp->safekeeper[i].state == SS_IDLE) { + n_ready++; + if (GetEpoch(&wp->safekeeper[i]) > wp->donorEpoch || (GetEpoch(&wp->safekeeper[i]) == wp->donorEpoch && wp->safekeeper[i].voteResponse.flushLsn > wp->propEpochStartLsn)) @@ -958,6 +961,16 @@ DetermineEpochStartLsn(WalProposer *wp) } } + if (n_ready < wp->quorum) + { + /* + * This is a rare case that can be triggered if safekeeper has voted and disconnected. + * In this case, its state will not be SS_IDLE and its vote cannot be used, because + * we clean up `voteResponse` in `ShutdownConnection`. + */ + wp_log(FATAL, "missing majority of votes, collected %d, expected %d, got %d", wp->n_votes, wp->quorum, n_ready); + } + /* * If propEpochStartLsn is 0, it means flushLsn is 0 everywhere, we are bootstrapping * and nothing was committed yet. Start streaming then from the basebackup LSN. diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 688d8e6e52..53820f6e1b 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -486,6 +486,8 @@ typedef struct walproposer_api * * On success, the data is placed in *buf. It is valid until the next call * to this function. + * + * Returns PG_ASYNC_READ_FAIL on closed connection. */ PGAsyncReadResult (*conn_async_read) (Safekeeper *sk, char **buf, int *amount); @@ -532,6 +534,13 @@ typedef struct walproposer_api * Returns 0 if timeout is reached, 1 if some event happened. Updates * events mask to indicate events and sets sk to the safekeeper which has * an event. + * + * On timeout, events is set to WL_NO_EVENTS. On socket event, events is + * set to WL_SOCKET_READABLE and/or WL_SOCKET_WRITEABLE. When socket is + * closed, events is set to WL_SOCKET_READABLE. + * + * WL_SOCKET_WRITEABLE is usually set only when we need to flush the buffer. + * It can be returned only if caller asked for this event in the last *_event_set call. */ int (*wait_event_set) (WalProposer *wp, long timeout, Safekeeper **sk, uint32 *events); diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 364cad7892..cb4a1def1f 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -61,3 +61,10 @@ tokio-stream.workspace = true utils.workspace = true workspace_hack.workspace = true + +[dev-dependencies] +walproposer.workspace = true +rand.workspace = true +desim.workspace = true +tracing.workspace = true +tracing-subscriber = { workspace = true, features = ["json"] } diff --git a/safekeeper/tests/misc_test.rs b/safekeeper/tests/misc_test.rs new file mode 100644 index 0000000000..8e5b17a143 --- /dev/null +++ b/safekeeper/tests/misc_test.rs @@ -0,0 +1,155 @@ +use std::sync::Arc; + +use tracing::{info, warn}; +use utils::lsn::Lsn; + +use crate::walproposer_sim::{ + log::{init_logger, init_tracing_logger}, + simulation::{generate_network_opts, generate_schedule, Schedule, TestAction, TestConfig}, +}; + +pub mod walproposer_sim; + +// Test that simulation supports restarting (crashing) safekeepers. +#[test] +fn crash_safekeeper() { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced empty safekeepers at 0/0"); + + let mut wp = test.launch_walproposer(lsn); + + // Write some WAL and crash safekeeper 0 without waiting for replication. + test.poll_for_duration(30); + wp.write_tx(3); + test.servers[0].restart(); + + // Wait some time, so that walproposer can reconnect. + test.poll_for_duration(2000); +} + +// Test that walproposer can be crashed (stopped). +#[test] +fn test_simple_restart() { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced empty safekeepers at 0/0"); + + let mut wp = test.launch_walproposer(lsn); + + test.poll_for_duration(30); + wp.write_tx(3); + test.poll_for_duration(100); + + wp.stop(); + drop(wp); + + let lsn = test.sync_safekeepers().unwrap(); + info!("Sucessfully synced safekeepers at {}", lsn); +} + +// Test runnning a simple schedule, restarting everything a several times. +#[test] +fn test_simple_schedule() -> anyhow::Result<()> { + let clock = init_logger(); + let mut config = TestConfig::new(Some(clock)); + config.network.keepalive_timeout = Some(100); + let test = config.start(1337); + + let schedule: Schedule = vec![ + (0, TestAction::RestartWalProposer), + (50, TestAction::WriteTx(5)), + (100, TestAction::RestartSafekeeper(0)), + (100, TestAction::WriteTx(5)), + (110, TestAction::RestartSafekeeper(1)), + (110, TestAction::WriteTx(5)), + (120, TestAction::RestartSafekeeper(2)), + (120, TestAction::WriteTx(5)), + (201, TestAction::RestartWalProposer), + (251, TestAction::RestartSafekeeper(0)), + (251, TestAction::RestartSafekeeper(1)), + (251, TestAction::RestartSafekeeper(2)), + (251, TestAction::WriteTx(5)), + (255, TestAction::WriteTx(5)), + (1000, TestAction::WriteTx(5)), + ]; + + test.run_schedule(&schedule)?; + info!("Test finished, stopping all threads"); + test.world.deallocate(); + + Ok(()) +} + +// Test that simulation can process 10^4 transactions. +#[test] +fn test_many_tx() -> anyhow::Result<()> { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let mut schedule: Schedule = vec![]; + for i in 0..100 { + schedule.push((i * 10, TestAction::WriteTx(100))); + } + + test.run_schedule(&schedule)?; + info!("Test finished, stopping all threads"); + test.world.stop_all(); + + let events = test.world.take_events(); + info!("Events: {:?}", events); + let last_commit_lsn = events + .iter() + .filter_map(|event| { + if event.data.starts_with("commit_lsn;") { + let lsn: u64 = event.data.split(';').nth(1).unwrap().parse().unwrap(); + return Some(lsn); + } + None + }) + .last() + .unwrap(); + + let initdb_lsn = 21623024; + let diff = last_commit_lsn - initdb_lsn; + info!("Last commit lsn: {}, diff: {}", last_commit_lsn, diff); + // each tx is at least 8 bytes, it's written a 100 times for in a loop for 100 times + assert!(diff > 100 * 100 * 8); + Ok(()) +} + +// Checks that we don't have nasty circular dependencies, preventing Arc from deallocating. +// This test doesn't really assert anything, you need to run it manually to check if there +// is any issue. +#[test] +fn test_res_dealloc() -> anyhow::Result<()> { + let clock = init_tracing_logger(true); + let mut config = TestConfig::new(Some(clock)); + + let seed = 123456; + config.network = generate_network_opts(seed); + let test = config.start(seed); + warn!("Running test with seed {}", seed); + + let schedule = generate_schedule(seed); + info!("schedule: {:?}", schedule); + test.run_schedule(&schedule).unwrap(); + test.world.stop_all(); + + let world = test.world.clone(); + drop(test); + info!("world strong count: {}", Arc::strong_count(&world)); + world.deallocate(); + info!("world strong count: {}", Arc::strong_count(&world)); + + Ok(()) +} diff --git a/safekeeper/tests/random_test.rs b/safekeeper/tests/random_test.rs new file mode 100644 index 0000000000..6c6f6a8c96 --- /dev/null +++ b/safekeeper/tests/random_test.rs @@ -0,0 +1,56 @@ +use rand::Rng; +use tracing::{info, warn}; + +use crate::walproposer_sim::{ + log::{init_logger, init_tracing_logger}, + simulation::{generate_network_opts, generate_schedule, TestConfig}, + simulation_logs::validate_events, +}; + +pub mod walproposer_sim; + +// Generates 2000 random seeds and runs a schedule for each of them. +// If you seed this test fail, please report the last seed to the +// @safekeeper team. +#[test] +fn test_random_schedules() -> anyhow::Result<()> { + let clock = init_logger(); + let mut config = TestConfig::new(Some(clock)); + + for _ in 0..2000 { + let seed: u64 = rand::thread_rng().gen(); + config.network = generate_network_opts(seed); + + let test = config.start(seed); + warn!("Running test with seed {}", seed); + + let schedule = generate_schedule(seed); + test.run_schedule(&schedule).unwrap(); + validate_events(test.world.take_events()); + test.world.deallocate(); + } + + Ok(()) +} + +// After you found a seed that fails, you can insert this seed here +// and run the test to see the full debug output. +#[test] +fn test_one_schedule() -> anyhow::Result<()> { + let clock = init_tracing_logger(true); + let mut config = TestConfig::new(Some(clock)); + + let seed = 11047466935058776390; + config.network = generate_network_opts(seed); + info!("network: {:?}", config.network); + let test = config.start(seed); + warn!("Running test with seed {}", seed); + + let schedule = generate_schedule(seed); + info!("schedule: {:?}", schedule); + test.run_schedule(&schedule).unwrap(); + validate_events(test.world.take_events()); + test.world.deallocate(); + + Ok(()) +} diff --git a/safekeeper/tests/simple_test.rs b/safekeeper/tests/simple_test.rs new file mode 100644 index 0000000000..0be9d0deef --- /dev/null +++ b/safekeeper/tests/simple_test.rs @@ -0,0 +1,45 @@ +use tracing::info; +use utils::lsn::Lsn; + +use crate::walproposer_sim::{log::init_logger, simulation::TestConfig}; + +pub mod walproposer_sim; + +// Check that first start of sync_safekeepers() returns 0/0 on empty safekeepers. +#[test] +fn sync_empty_safekeepers() { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced empty safekeepers at 0/0"); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced (again) empty safekeepers at 0/0"); +} + +// Check that there are no panics when we are writing and streaming WAL to safekeepers. +#[test] +fn run_walproposer_generate_wal() { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced empty safekeepers at 0/0"); + + let mut wp = test.launch_walproposer(lsn); + + // wait for walproposer to start + test.poll_for_duration(30); + + // just write some WAL + for _ in 0..100 { + wp.write_tx(1); + test.poll_for_duration(5); + } +} diff --git a/safekeeper/tests/walproposer_sim/block_storage.rs b/safekeeper/tests/walproposer_sim/block_storage.rs new file mode 100644 index 0000000000..468c02ad2f --- /dev/null +++ b/safekeeper/tests/walproposer_sim/block_storage.rs @@ -0,0 +1,57 @@ +use std::collections::HashMap; + +const BLOCK_SIZE: usize = 8192; + +/// A simple in-memory implementation of a block storage. Can be used to implement external +/// storage in tests. +pub struct BlockStorage { + blocks: HashMap, +} + +impl Default for BlockStorage { + fn default() -> Self { + Self::new() + } +} + +impl BlockStorage { + pub fn new() -> Self { + BlockStorage { + blocks: HashMap::new(), + } + } + + pub fn read(&self, pos: u64, buf: &mut [u8]) { + let mut buf_offset = 0; + let mut storage_pos = pos; + while buf_offset < buf.len() { + let block_id = storage_pos / BLOCK_SIZE as u64; + let block = self.blocks.get(&block_id).unwrap_or(&[0; BLOCK_SIZE]); + let block_offset = storage_pos % BLOCK_SIZE as u64; + let block_len = BLOCK_SIZE as u64 - block_offset; + let buf_len = buf.len() - buf_offset; + let copy_len = std::cmp::min(block_len as usize, buf_len); + buf[buf_offset..buf_offset + copy_len] + .copy_from_slice(&block[block_offset as usize..block_offset as usize + copy_len]); + buf_offset += copy_len; + storage_pos += copy_len as u64; + } + } + + pub fn write(&mut self, pos: u64, buf: &[u8]) { + let mut buf_offset = 0; + let mut storage_pos = pos; + while buf_offset < buf.len() { + let block_id = storage_pos / BLOCK_SIZE as u64; + let block = self.blocks.entry(block_id).or_insert([0; BLOCK_SIZE]); + let block_offset = storage_pos % BLOCK_SIZE as u64; + let block_len = BLOCK_SIZE as u64 - block_offset; + let buf_len = buf.len() - buf_offset; + let copy_len = std::cmp::min(block_len as usize, buf_len); + block[block_offset as usize..block_offset as usize + copy_len] + .copy_from_slice(&buf[buf_offset..buf_offset + copy_len]); + buf_offset += copy_len; + storage_pos += copy_len as u64 + } + } +} diff --git a/safekeeper/tests/walproposer_sim/log.rs b/safekeeper/tests/walproposer_sim/log.rs new file mode 100644 index 0000000000..870f30de4f --- /dev/null +++ b/safekeeper/tests/walproposer_sim/log.rs @@ -0,0 +1,77 @@ +use std::{fmt, sync::Arc}; + +use desim::time::Timing; +use once_cell::sync::OnceCell; +use parking_lot::Mutex; +use tracing_subscriber::fmt::{format::Writer, time::FormatTime}; + +/// SimClock can be plugged into tracing logger to print simulation time. +#[derive(Clone)] +pub struct SimClock { + clock_ptr: Arc>>>, +} + +impl Default for SimClock { + fn default() -> Self { + SimClock { + clock_ptr: Arc::new(Mutex::new(None)), + } + } +} + +impl SimClock { + pub fn set_clock(&self, clock: Arc) { + *self.clock_ptr.lock() = Some(clock); + } +} + +impl FormatTime for SimClock { + fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result { + let clock = self.clock_ptr.lock(); + + if let Some(clock) = clock.as_ref() { + let now = clock.now(); + write!(w, "[{}]", now) + } else { + write!(w, "[?]") + } + } +} + +static LOGGING_DONE: OnceCell = OnceCell::new(); + +/// Returns ptr to clocks attached to tracing logger to update them when the +/// world is (re)created. +pub fn init_tracing_logger(debug_enabled: bool) -> SimClock { + LOGGING_DONE + .get_or_init(|| { + let clock = SimClock::default(); + let base_logger = tracing_subscriber::fmt() + .with_target(false) + // prefix log lines with simulated time timestamp + .with_timer(clock.clone()) + // .with_ansi(true) TODO + .with_max_level(match debug_enabled { + true => tracing::Level::DEBUG, + false => tracing::Level::WARN, + }) + .with_writer(std::io::stdout); + base_logger.init(); + + // logging::replace_panic_hook_with_tracing_panic_hook().forget(); + + if !debug_enabled { + std::panic::set_hook(Box::new(|_| {})); + } + + clock + }) + .clone() +} + +pub fn init_logger() -> SimClock { + // RUST_TRACEBACK envvar controls whether we print all logs or only warnings. + let debug_enabled = std::env::var("RUST_TRACEBACK").is_ok(); + + init_tracing_logger(debug_enabled) +} diff --git a/safekeeper/tests/walproposer_sim/mod.rs b/safekeeper/tests/walproposer_sim/mod.rs new file mode 100644 index 0000000000..ec560dcb3b --- /dev/null +++ b/safekeeper/tests/walproposer_sim/mod.rs @@ -0,0 +1,8 @@ +pub mod block_storage; +pub mod log; +pub mod safekeeper; +pub mod safekeeper_disk; +pub mod simulation; +pub mod simulation_logs; +pub mod walproposer_api; +pub mod walproposer_disk; diff --git a/safekeeper/tests/walproposer_sim/safekeeper.rs b/safekeeper/tests/walproposer_sim/safekeeper.rs new file mode 100644 index 0000000000..1945b9d0cb --- /dev/null +++ b/safekeeper/tests/walproposer_sim/safekeeper.rs @@ -0,0 +1,410 @@ +//! Safekeeper communication endpoint to WAL proposer (compute node). +//! Gets messages from the network, passes them down to consensus module and +//! sends replies back. + +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use anyhow::{bail, Result}; +use bytes::{Bytes, BytesMut}; +use camino::Utf8PathBuf; +use desim::{ + executor::{self, PollSome}, + network::TCP, + node_os::NodeOs, + proto::{AnyMessage, NetEvent, NodeEvent}, +}; +use hyper::Uri; +use safekeeper::{ + safekeeper::{ProposerAcceptorMessage, SafeKeeper, ServerInfo, UNKNOWN_SERVER_VERSION}, + state::TimelinePersistentState, + timeline::TimelineError, + wal_storage::Storage, + SafeKeeperConf, +}; +use tracing::{debug, info_span}; +use utils::{ + id::{NodeId, TenantId, TenantTimelineId, TimelineId}, + lsn::Lsn, +}; + +use super::safekeeper_disk::{DiskStateStorage, DiskWALStorage, SafekeeperDisk, TimelineDisk}; + +struct SharedState { + sk: SafeKeeper, + disk: Arc, +} + +struct GlobalMap { + timelines: HashMap, + conf: SafeKeeperConf, + disk: Arc, +} + +impl GlobalMap { + /// Restores global state from disk. + fn new(disk: Arc, conf: SafeKeeperConf) -> Result { + let mut timelines = HashMap::new(); + + for (&ttid, disk) in disk.timelines.lock().iter() { + debug!("loading timeline {}", ttid); + let state = disk.state.lock().clone(); + + if state.server.wal_seg_size == 0 { + bail!(TimelineError::UninitializedWalSegSize(ttid)); + } + + if state.server.pg_version == UNKNOWN_SERVER_VERSION { + bail!(TimelineError::UninitialinzedPgVersion(ttid)); + } + + if state.commit_lsn < state.local_start_lsn { + bail!( + "commit_lsn {} is higher than local_start_lsn {}", + state.commit_lsn, + state.local_start_lsn + ); + } + + let control_store = DiskStateStorage::new(disk.clone()); + let wal_store = DiskWALStorage::new(disk.clone(), &control_store)?; + + let sk = SafeKeeper::new(control_store, wal_store, conf.my_id)?; + timelines.insert( + ttid, + SharedState { + sk, + disk: disk.clone(), + }, + ); + } + + Ok(Self { + timelines, + conf, + disk, + }) + } + + fn create(&mut self, ttid: TenantTimelineId, server_info: ServerInfo) -> Result<()> { + if self.timelines.contains_key(&ttid) { + bail!("timeline {} already exists", ttid); + } + + debug!("creating new timeline {}", ttid); + + let commit_lsn = Lsn::INVALID; + let local_start_lsn = Lsn::INVALID; + + let state = + TimelinePersistentState::new(&ttid, server_info, vec![], commit_lsn, local_start_lsn); + + if state.server.wal_seg_size == 0 { + bail!(TimelineError::UninitializedWalSegSize(ttid)); + } + + if state.server.pg_version == UNKNOWN_SERVER_VERSION { + bail!(TimelineError::UninitialinzedPgVersion(ttid)); + } + + if state.commit_lsn < state.local_start_lsn { + bail!( + "commit_lsn {} is higher than local_start_lsn {}", + state.commit_lsn, + state.local_start_lsn + ); + } + + let disk_timeline = self.disk.put_state(&ttid, state); + let control_store = DiskStateStorage::new(disk_timeline.clone()); + let wal_store = DiskWALStorage::new(disk_timeline.clone(), &control_store)?; + + let sk = SafeKeeper::new(control_store, wal_store, self.conf.my_id)?; + + self.timelines.insert( + ttid, + SharedState { + sk, + disk: disk_timeline, + }, + ); + Ok(()) + } + + fn get(&mut self, ttid: &TenantTimelineId) -> &mut SharedState { + self.timelines.get_mut(ttid).expect("timeline must exist") + } + + fn has_tli(&self, ttid: &TenantTimelineId) -> bool { + self.timelines.contains_key(ttid) + } +} + +/// State of a single connection to walproposer. +struct ConnState { + tcp: TCP, + + greeting: bool, + ttid: TenantTimelineId, + flush_pending: bool, + + runtime: tokio::runtime::Runtime, +} + +pub fn run_server(os: NodeOs, disk: Arc) -> Result<()> { + let _enter = info_span!("safekeeper", id = os.id()).entered(); + debug!("started server"); + os.log_event("started;safekeeper".to_owned()); + let conf = SafeKeeperConf { + workdir: Utf8PathBuf::from("."), + my_id: NodeId(os.id() as u64), + listen_pg_addr: String::new(), + listen_http_addr: String::new(), + no_sync: false, + broker_endpoint: "/".parse::().unwrap(), + broker_keepalive_interval: Duration::from_secs(0), + heartbeat_timeout: Duration::from_secs(0), + remote_storage: None, + max_offloader_lag_bytes: 0, + wal_backup_enabled: false, + listen_pg_addr_tenant_only: None, + advertise_pg_addr: None, + availability_zone: None, + peer_recovery_enabled: false, + backup_parallel_jobs: 0, + pg_auth: None, + pg_tenant_only_auth: None, + http_auth: None, + current_thread_runtime: false, + }; + + let mut global = GlobalMap::new(disk, conf.clone())?; + let mut conns: HashMap = HashMap::new(); + + for (&_ttid, shared_state) in global.timelines.iter_mut() { + let flush_lsn = shared_state.sk.wal_store.flush_lsn(); + let commit_lsn = shared_state.sk.state.commit_lsn; + os.log_event(format!("tli_loaded;{};{}", flush_lsn.0, commit_lsn.0)); + } + + let node_events = os.node_events(); + let mut epoll_vec: Vec> = vec![]; + let mut epoll_idx: Vec = vec![]; + + // TODO: batch events processing (multiple events per tick) + loop { + epoll_vec.clear(); + epoll_idx.clear(); + + // node events channel + epoll_vec.push(Box::new(node_events.clone())); + epoll_idx.push(0); + + // tcp connections + for conn in conns.values() { + epoll_vec.push(Box::new(conn.tcp.recv_chan())); + epoll_idx.push(conn.tcp.connection_id()); + } + + // waiting for the next message + let index = executor::epoll_chans(&epoll_vec, -1).unwrap(); + + if index == 0 { + // got a new connection + match node_events.must_recv() { + NodeEvent::Accept(tcp) => { + conns.insert( + tcp.connection_id(), + ConnState { + tcp, + greeting: false, + ttid: TenantTimelineId::empty(), + flush_pending: false, + runtime: tokio::runtime::Builder::new_current_thread().build()?, + }, + ); + } + NodeEvent::Internal(_) => unreachable!(), + } + continue; + } + + let connection_id = epoll_idx[index]; + let conn = conns.get_mut(&connection_id).unwrap(); + let mut next_event = Some(conn.tcp.recv_chan().must_recv()); + + loop { + let event = match next_event { + Some(event) => event, + None => break, + }; + + match event { + NetEvent::Message(msg) => { + let res = conn.process_any(msg, &mut global); + if res.is_err() { + debug!("conn {:?} error: {:#}", connection_id, res.unwrap_err()); + conns.remove(&connection_id); + break; + } + } + NetEvent::Closed => { + // TODO: remove from conns? + } + } + + next_event = conn.tcp.recv_chan().try_recv(); + } + + conns.retain(|_, conn| { + let res = conn.flush(&mut global); + if res.is_err() { + debug!("conn {:?} error: {:?}", conn.tcp, res); + } + res.is_ok() + }); + } +} + +impl ConnState { + /// Process a message from the network. It can be START_REPLICATION request or a valid ProposerAcceptorMessage message. + fn process_any(&mut self, any: AnyMessage, global: &mut GlobalMap) -> Result<()> { + if let AnyMessage::Bytes(copy_data) = any { + let repl_prefix = b"START_REPLICATION "; + if !self.greeting && copy_data.starts_with(repl_prefix) { + self.process_start_replication(copy_data.slice(repl_prefix.len()..), global)?; + bail!("finished processing START_REPLICATION") + } + + let msg = ProposerAcceptorMessage::parse(copy_data)?; + debug!("got msg: {:?}", msg); + self.process(msg, global) + } else { + bail!("unexpected message, expected AnyMessage::Bytes"); + } + } + + /// Process START_REPLICATION request. + fn process_start_replication( + &mut self, + copy_data: Bytes, + global: &mut GlobalMap, + ) -> Result<()> { + // format is " " + let str = String::from_utf8(copy_data.to_vec())?; + + let mut parts = str.split(' '); + let tenant_id = parts.next().unwrap().parse::()?; + let timeline_id = parts.next().unwrap().parse::()?; + let start_lsn = parts.next().unwrap().parse::()?; + let end_lsn = parts.next().unwrap().parse::()?; + + let ttid = TenantTimelineId::new(tenant_id, timeline_id); + let shared_state = global.get(&ttid); + + // read bytes from start_lsn to end_lsn + let mut buf = vec![0; (end_lsn - start_lsn) as usize]; + shared_state.disk.wal.lock().read(start_lsn, &mut buf); + + // send bytes to the client + self.tcp.send(AnyMessage::Bytes(Bytes::from(buf))); + Ok(()) + } + + /// Get or create a timeline. + fn init_timeline( + &mut self, + ttid: TenantTimelineId, + server_info: ServerInfo, + global: &mut GlobalMap, + ) -> Result<()> { + self.ttid = ttid; + if global.has_tli(&ttid) { + return Ok(()); + } + + global.create(ttid, server_info) + } + + /// Process a ProposerAcceptorMessage. + fn process(&mut self, msg: ProposerAcceptorMessage, global: &mut GlobalMap) -> Result<()> { + if !self.greeting { + self.greeting = true; + + match msg { + ProposerAcceptorMessage::Greeting(ref greeting) => { + tracing::info!( + "start handshake with walproposer {:?} {:?}", + self.tcp, + greeting + ); + let server_info = ServerInfo { + pg_version: greeting.pg_version, + system_id: greeting.system_id, + wal_seg_size: greeting.wal_seg_size, + }; + let ttid = TenantTimelineId::new(greeting.tenant_id, greeting.timeline_id); + self.init_timeline(ttid, server_info, global)? + } + _ => { + bail!("unexpected message {msg:?} instead of greeting"); + } + } + } + + let tli = global.get(&self.ttid); + + match msg { + ProposerAcceptorMessage::AppendRequest(append_request) => { + self.flush_pending = true; + self.process_sk_msg( + tli, + &ProposerAcceptorMessage::NoFlushAppendRequest(append_request), + )?; + } + other => { + self.process_sk_msg(tli, &other)?; + } + } + + Ok(()) + } + + /// Process FlushWAL if needed. + fn flush(&mut self, global: &mut GlobalMap) -> Result<()> { + // TODO: try to add extra flushes in simulation, to verify that extra flushes don't break anything + if !self.flush_pending { + return Ok(()); + } + self.flush_pending = false; + let shared_state = global.get(&self.ttid); + self.process_sk_msg(shared_state, &ProposerAcceptorMessage::FlushWAL) + } + + /// Make safekeeper process a message and send a reply to the TCP + fn process_sk_msg( + &mut self, + shared_state: &mut SharedState, + msg: &ProposerAcceptorMessage, + ) -> Result<()> { + let mut reply = self.runtime.block_on(shared_state.sk.process_msg(msg))?; + if let Some(reply) = &mut reply { + // TODO: if this is AppendResponse, fill in proper hot standby feedback and disk consistent lsn + + let mut buf = BytesMut::with_capacity(128); + reply.serialize(&mut buf)?; + + self.tcp.send(AnyMessage::Bytes(buf.into())); + } + Ok(()) + } +} + +impl Drop for ConnState { + fn drop(&mut self) { + debug!("dropping conn: {:?}", self.tcp); + if !std::thread::panicking() { + self.tcp.close(); + } + // TODO: clean up non-fsynced WAL + } +} diff --git a/safekeeper/tests/walproposer_sim/safekeeper_disk.rs b/safekeeper/tests/walproposer_sim/safekeeper_disk.rs new file mode 100644 index 0000000000..35bca325aa --- /dev/null +++ b/safekeeper/tests/walproposer_sim/safekeeper_disk.rs @@ -0,0 +1,278 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use parking_lot::Mutex; +use safekeeper::state::TimelinePersistentState; +use utils::id::TenantTimelineId; + +use super::block_storage::BlockStorage; + +use std::{ops::Deref, time::Instant}; + +use anyhow::Result; +use bytes::{Buf, BytesMut}; +use futures::future::BoxFuture; +use postgres_ffi::{waldecoder::WalStreamDecoder, XLogSegNo}; +use safekeeper::{control_file, metrics::WalStorageMetrics, wal_storage}; +use tracing::{debug, info}; +use utils::lsn::Lsn; + +/// All safekeeper state that is usually saved to disk. +pub struct SafekeeperDisk { + pub timelines: Mutex>>, +} + +impl Default for SafekeeperDisk { + fn default() -> Self { + Self::new() + } +} + +impl SafekeeperDisk { + pub fn new() -> Self { + SafekeeperDisk { + timelines: Mutex::new(HashMap::new()), + } + } + + pub fn put_state( + &self, + ttid: &TenantTimelineId, + state: TimelinePersistentState, + ) -> Arc { + self.timelines + .lock() + .entry(*ttid) + .and_modify(|e| { + let mut mu = e.state.lock(); + *mu = state.clone(); + }) + .or_insert_with(|| { + Arc::new(TimelineDisk { + state: Mutex::new(state), + wal: Mutex::new(BlockStorage::new()), + }) + }) + .clone() + } +} + +/// Control file state and WAL storage. +pub struct TimelineDisk { + pub state: Mutex, + pub wal: Mutex, +} + +/// Implementation of `control_file::Storage` trait. +pub struct DiskStateStorage { + persisted_state: TimelinePersistentState, + disk: Arc, + last_persist_at: Instant, +} + +impl DiskStateStorage { + pub fn new(disk: Arc) -> Self { + let guard = disk.state.lock(); + let state = guard.clone(); + drop(guard); + DiskStateStorage { + persisted_state: state, + disk, + last_persist_at: Instant::now(), + } + } +} + +#[async_trait::async_trait] +impl control_file::Storage for DiskStateStorage { + /// Persist safekeeper state on disk and update internal state. + async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> { + self.persisted_state = s.clone(); + *self.disk.state.lock() = s.clone(); + Ok(()) + } + + /// Timestamp of last persist. + fn last_persist_at(&self) -> Instant { + // TODO: don't rely on it in tests + self.last_persist_at + } +} + +impl Deref for DiskStateStorage { + type Target = TimelinePersistentState; + + fn deref(&self) -> &Self::Target { + &self.persisted_state + } +} + +/// Implementation of `wal_storage::Storage` trait. +pub struct DiskWALStorage { + /// Written to disk, but possibly still in the cache and not fully persisted. + /// Also can be ahead of record_lsn, if happen to be in the middle of a WAL record. + write_lsn: Lsn, + + /// The LSN of the last WAL record written to disk. Still can be not fully flushed. + write_record_lsn: Lsn, + + /// The LSN of the last WAL record flushed to disk. + flush_record_lsn: Lsn, + + /// Decoder is required for detecting boundaries of WAL records. + decoder: WalStreamDecoder, + + /// Bytes of WAL records that are not yet written to disk. + unflushed_bytes: BytesMut, + + /// Contains BlockStorage for WAL. + disk: Arc, +} + +impl DiskWALStorage { + pub fn new(disk: Arc, state: &TimelinePersistentState) -> Result { + let write_lsn = if state.commit_lsn == Lsn(0) { + Lsn(0) + } else { + Self::find_end_of_wal(disk.clone(), state.commit_lsn)? + }; + + let flush_lsn = write_lsn; + Ok(DiskWALStorage { + write_lsn, + write_record_lsn: flush_lsn, + flush_record_lsn: flush_lsn, + decoder: WalStreamDecoder::new(flush_lsn, 16), + unflushed_bytes: BytesMut::new(), + disk, + }) + } + + fn find_end_of_wal(disk: Arc, start_lsn: Lsn) -> Result { + let mut buf = [0; 8192]; + let mut pos = start_lsn.0; + let mut decoder = WalStreamDecoder::new(start_lsn, 16); + let mut result = start_lsn; + loop { + disk.wal.lock().read(pos, &mut buf); + pos += buf.len() as u64; + decoder.feed_bytes(&buf); + + loop { + match decoder.poll_decode() { + Ok(Some(record)) => result = record.0, + Err(e) => { + debug!( + "find_end_of_wal reached end at {:?}, decode error: {:?}", + result, e + ); + return Ok(result); + } + Ok(None) => break, // need more data + } + } + } + } +} + +#[async_trait::async_trait] +impl wal_storage::Storage for DiskWALStorage { + /// LSN of last durably stored WAL record. + fn flush_lsn(&self) -> Lsn { + self.flush_record_lsn + } + + /// Write piece of WAL from buf to disk, but not necessarily sync it. + async fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> { + if self.write_lsn != startpos { + panic!("write_wal called with wrong startpos"); + } + + self.unflushed_bytes.extend_from_slice(buf); + self.write_lsn += buf.len() as u64; + + if self.decoder.available() != startpos { + info!( + "restart decoder from {} to {}", + self.decoder.available(), + startpos, + ); + self.decoder = WalStreamDecoder::new(startpos, 16); + } + self.decoder.feed_bytes(buf); + loop { + match self.decoder.poll_decode()? { + None => break, // no full record yet + Some((lsn, _rec)) => { + self.write_record_lsn = lsn; + } + } + } + + Ok(()) + } + + /// Truncate WAL at specified LSN, which must be the end of WAL record. + async fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> { + if self.write_lsn != Lsn(0) && end_pos > self.write_lsn { + panic!( + "truncate_wal called on non-written WAL, write_lsn={}, end_pos={}", + self.write_lsn, end_pos + ); + } + + self.flush_wal().await?; + + // write zeroes to disk from end_pos until self.write_lsn + let buf = [0; 8192]; + let mut pos = end_pos.0; + while pos < self.write_lsn.0 { + self.disk.wal.lock().write(pos, &buf); + pos += buf.len() as u64; + } + + self.write_lsn = end_pos; + self.write_record_lsn = end_pos; + self.flush_record_lsn = end_pos; + self.unflushed_bytes.clear(); + self.decoder = WalStreamDecoder::new(end_pos, 16); + + Ok(()) + } + + /// Durably store WAL on disk, up to the last written WAL record. + async fn flush_wal(&mut self) -> Result<()> { + if self.flush_record_lsn == self.write_record_lsn { + // no need to do extra flush + return Ok(()); + } + + let num_bytes = self.write_record_lsn.0 - self.flush_record_lsn.0; + + self.disk.wal.lock().write( + self.flush_record_lsn.0, + &self.unflushed_bytes[..num_bytes as usize], + ); + self.unflushed_bytes.advance(num_bytes as usize); + self.flush_record_lsn = self.write_record_lsn; + + Ok(()) + } + + /// Remove all segments <= given segno. Returns function doing that as we + /// want to perform it without timeline lock. + fn remove_up_to(&self, _segno_up_to: XLogSegNo) -> BoxFuture<'static, anyhow::Result<()>> { + Box::pin(async move { Ok(()) }) + } + + /// Release resources associated with the storage -- technically, close FDs. + /// Currently we don't remove timelines until restart (#3146), so need to + /// spare descriptors. This would be useful for temporary tli detach as + /// well. + fn close(&mut self) {} + + /// Get metrics for this timeline. + fn get_metrics(&self) -> WalStorageMetrics { + WalStorageMetrics::default() + } +} diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs new file mode 100644 index 0000000000..0d7aaf517b --- /dev/null +++ b/safekeeper/tests/walproposer_sim/simulation.rs @@ -0,0 +1,436 @@ +use std::{cell::Cell, str::FromStr, sync::Arc}; + +use crate::walproposer_sim::{safekeeper::run_server, walproposer_api::SimulationApi}; +use desim::{ + executor::{self, ExternalHandle}, + node_os::NodeOs, + options::{Delay, NetworkOptions}, + proto::{AnyMessage, NodeEvent}, + world::Node, + world::World, +}; +use rand::{Rng, SeedableRng}; +use tracing::{debug, info_span, warn}; +use utils::{id::TenantTimelineId, lsn::Lsn}; +use walproposer::walproposer::{Config, Wrapper}; + +use super::{ + log::SimClock, safekeeper_disk::SafekeeperDisk, walproposer_api, + walproposer_disk::DiskWalProposer, +}; + +/// Simulated safekeeper node. +pub struct SafekeeperNode { + pub node: Arc, + pub id: u32, + pub disk: Arc, + pub thread: Cell, +} + +impl SafekeeperNode { + /// Create and start a safekeeper at the specified Node. + pub fn new(node: Arc) -> Self { + let disk = Arc::new(SafekeeperDisk::new()); + let thread = Cell::new(SafekeeperNode::launch(disk.clone(), node.clone())); + + Self { + id: node.id, + node, + disk, + thread, + } + } + + fn launch(disk: Arc, node: Arc) -> ExternalHandle { + // start the server thread + node.launch(move |os| { + run_server(os, disk).expect("server should finish without errors"); + }) + } + + /// Restart the safekeeper. + pub fn restart(&self) { + let new_thread = SafekeeperNode::launch(self.disk.clone(), self.node.clone()); + let old_thread = self.thread.replace(new_thread); + old_thread.crash_stop(); + } +} + +/// Simulated walproposer node. +pub struct WalProposer { + thread: ExternalHandle, + node: Arc, + disk: Arc, + sync_safekeepers: bool, +} + +impl WalProposer { + /// Generic start function for both modes. + fn start( + os: NodeOs, + disk: Arc, + ttid: TenantTimelineId, + addrs: Vec, + lsn: Option, + ) { + let sync_safekeepers = lsn.is_none(); + + let _enter = if sync_safekeepers { + info_span!("sync", started = executor::now()).entered() + } else { + info_span!("walproposer", started = executor::now()).entered() + }; + + os.log_event(format!("started;walproposer;{}", sync_safekeepers as i32)); + + let config = Config { + ttid, + safekeepers_list: addrs, + safekeeper_reconnect_timeout: 1000, + safekeeper_connection_timeout: 5000, + sync_safekeepers, + }; + let args = walproposer_api::Args { + os, + config: config.clone(), + disk, + redo_start_lsn: lsn, + }; + let api = SimulationApi::new(args); + let wp = Wrapper::new(Box::new(api), config); + wp.start(); + } + + /// Start walproposer in a sync_safekeepers mode. + pub fn launch_sync(ttid: TenantTimelineId, addrs: Vec, node: Arc) -> Self { + debug!("sync_safekeepers started at node {}", node.id); + let disk = DiskWalProposer::new(); + let disk_wp = disk.clone(); + + // start the client thread + let handle = node.launch(move |os| { + WalProposer::start(os, disk_wp, ttid, addrs, None); + }); + + Self { + thread: handle, + node, + disk, + sync_safekeepers: true, + } + } + + /// Start walproposer in a normal mode. + pub fn launch_walproposer( + ttid: TenantTimelineId, + addrs: Vec, + node: Arc, + lsn: Lsn, + ) -> Self { + debug!("walproposer started at node {}", node.id); + let disk = DiskWalProposer::new(); + disk.lock().reset_to(lsn); + let disk_wp = disk.clone(); + + // start the client thread + let handle = node.launch(move |os| { + WalProposer::start(os, disk_wp, ttid, addrs, Some(lsn)); + }); + + Self { + thread: handle, + node, + disk, + sync_safekeepers: false, + } + } + + pub fn write_tx(&mut self, cnt: usize) { + let start_lsn = self.disk.lock().flush_rec_ptr(); + + for _ in 0..cnt { + self.disk + .lock() + .insert_logical_message("prefix", b"message") + .expect("failed to generate logical message"); + } + + let end_lsn = self.disk.lock().flush_rec_ptr(); + + // log event + self.node + .log_event(format!("write_wal;{};{};{}", start_lsn.0, end_lsn.0, cnt)); + + // now we need to set "Latch" in walproposer + self.node + .node_events() + .send(NodeEvent::Internal(AnyMessage::Just32(0))); + } + + pub fn stop(&self) { + self.thread.crash_stop(); + } +} + +/// Holds basic simulation settings, such as network options. +pub struct TestConfig { + pub network: NetworkOptions, + pub timeout: u64, + pub clock: Option, +} + +impl TestConfig { + /// Create a new TestConfig with default settings. + pub fn new(clock: Option) -> Self { + Self { + network: NetworkOptions { + keepalive_timeout: Some(2000), + connect_delay: Delay { + min: 1, + max: 5, + fail_prob: 0.0, + }, + send_delay: Delay { + min: 1, + max: 5, + fail_prob: 0.0, + }, + }, + timeout: 1_000 * 10, + clock, + } + } + + /// Start a new simulation with the specified seed. + pub fn start(&self, seed: u64) -> Test { + let world = Arc::new(World::new(seed, Arc::new(self.network.clone()))); + + if let Some(clock) = &self.clock { + clock.set_clock(world.clock()); + } + + let servers = [ + SafekeeperNode::new(world.new_node()), + SafekeeperNode::new(world.new_node()), + SafekeeperNode::new(world.new_node()), + ]; + + let server_ids = [servers[0].id, servers[1].id, servers[2].id]; + let safekeepers_addrs = server_ids.map(|id| format!("node:{}", id)).to_vec(); + + let ttid = TenantTimelineId::generate(); + + Test { + world, + servers, + sk_list: safekeepers_addrs, + ttid, + timeout: self.timeout, + } + } +} + +/// Holds simulation state. +pub struct Test { + pub world: Arc, + pub servers: [SafekeeperNode; 3], + pub sk_list: Vec, + pub ttid: TenantTimelineId, + pub timeout: u64, +} + +impl Test { + /// Start a sync_safekeepers thread and wait for it to finish. + pub fn sync_safekeepers(&self) -> anyhow::Result { + let wp = self.launch_sync_safekeepers(); + + // poll until exit or timeout + let time_limit = self.timeout; + while self.world.step() && self.world.now() < time_limit && !wp.thread.is_finished() {} + + if !wp.thread.is_finished() { + anyhow::bail!("timeout or idle stuck"); + } + + let res = wp.thread.result(); + if res.0 != 0 { + anyhow::bail!("non-zero exitcode: {:?}", res); + } + let lsn = Lsn::from_str(&res.1)?; + Ok(lsn) + } + + /// Spawn a new sync_safekeepers thread. + pub fn launch_sync_safekeepers(&self) -> WalProposer { + WalProposer::launch_sync(self.ttid, self.sk_list.clone(), self.world.new_node()) + } + + /// Spawn a new walproposer thread. + pub fn launch_walproposer(&self, lsn: Lsn) -> WalProposer { + let lsn = if lsn.0 == 0 { + // usual LSN after basebackup + Lsn(21623024) + } else { + lsn + }; + + WalProposer::launch_walproposer(self.ttid, self.sk_list.clone(), self.world.new_node(), lsn) + } + + /// Execute the simulation for the specified duration. + pub fn poll_for_duration(&self, duration: u64) { + let time_limit = std::cmp::min(self.world.now() + duration, self.timeout); + while self.world.step() && self.world.now() < time_limit {} + } + + /// Execute the simulation together with events defined in some schedule. + pub fn run_schedule(&self, schedule: &Schedule) -> anyhow::Result<()> { + // scheduling empty events so that world will stop in those points + { + let clock = self.world.clock(); + + let now = self.world.now(); + for (time, _) in schedule { + if *time < now { + continue; + } + clock.schedule_fake(*time - now); + } + } + + let mut wp = self.launch_sync_safekeepers(); + + let mut skipped_tx = 0; + let mut started_tx = 0; + + let mut schedule_ptr = 0; + + loop { + if wp.sync_safekeepers && wp.thread.is_finished() { + let res = wp.thread.result(); + if res.0 != 0 { + warn!("sync non-zero exitcode: {:?}", res); + debug!("restarting sync_safekeepers"); + // restart the sync_safekeepers + wp = self.launch_sync_safekeepers(); + continue; + } + let lsn = Lsn::from_str(&res.1)?; + debug!("sync_safekeepers finished at LSN {}", lsn); + wp = self.launch_walproposer(lsn); + debug!("walproposer started at thread {}", wp.thread.id()); + } + + let now = self.world.now(); + while schedule_ptr < schedule.len() && schedule[schedule_ptr].0 <= now { + if now != schedule[schedule_ptr].0 { + warn!("skipped event {:?} at {}", schedule[schedule_ptr], now); + } + + let action = &schedule[schedule_ptr].1; + match action { + TestAction::WriteTx(size) => { + if !wp.sync_safekeepers && !wp.thread.is_finished() { + started_tx += *size; + wp.write_tx(*size); + debug!("written {} transactions", size); + } else { + skipped_tx += size; + debug!("skipped {} transactions", size); + } + } + TestAction::RestartSafekeeper(id) => { + debug!("restarting safekeeper {}", id); + self.servers[*id].restart(); + } + TestAction::RestartWalProposer => { + debug!("restarting sync_safekeepers"); + wp.stop(); + wp = self.launch_sync_safekeepers(); + } + } + schedule_ptr += 1; + } + + if schedule_ptr == schedule.len() { + break; + } + let next_event_time = schedule[schedule_ptr].0; + + // poll until the next event + if wp.thread.is_finished() { + while self.world.step() && self.world.now() < next_event_time {} + } else { + while self.world.step() + && self.world.now() < next_event_time + && !wp.thread.is_finished() + {} + } + } + + debug!( + "finished schedule, total steps: {}", + self.world.get_thread_step_count() + ); + debug!("skipped_tx: {}", skipped_tx); + debug!("started_tx: {}", started_tx); + + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub enum TestAction { + WriteTx(usize), + RestartSafekeeper(usize), + RestartWalProposer, +} + +pub type Schedule = Vec<(u64, TestAction)>; + +pub fn generate_schedule(seed: u64) -> Schedule { + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut schedule = Vec::new(); + let mut time = 0; + + let cnt = rng.gen_range(1..100); + + for _ in 0..cnt { + time += rng.gen_range(0..500); + let action = match rng.gen_range(0..3) { + 0 => TestAction::WriteTx(rng.gen_range(1..10)), + 1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)), + 2 => TestAction::RestartWalProposer, + _ => unreachable!(), + }; + schedule.push((time, action)); + } + + schedule +} + +pub fn generate_network_opts(seed: u64) -> NetworkOptions { + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + + let timeout = rng.gen_range(100..2000); + let max_delay = rng.gen_range(1..2 * timeout); + let min_delay = rng.gen_range(1..=max_delay); + + let max_fail_prob = rng.gen_range(0.0..0.9); + let connect_fail_prob = rng.gen_range(0.0..max_fail_prob); + let send_fail_prob = rng.gen_range(0.0..connect_fail_prob); + + NetworkOptions { + keepalive_timeout: Some(timeout), + connect_delay: Delay { + min: min_delay, + max: max_delay, + fail_prob: connect_fail_prob, + }, + send_delay: Delay { + min: min_delay, + max: max_delay, + fail_prob: send_fail_prob, + }, + } +} diff --git a/safekeeper/tests/walproposer_sim/simulation_logs.rs b/safekeeper/tests/walproposer_sim/simulation_logs.rs new file mode 100644 index 0000000000..38885e5dd0 --- /dev/null +++ b/safekeeper/tests/walproposer_sim/simulation_logs.rs @@ -0,0 +1,187 @@ +use desim::proto::SimEvent; +use tracing::debug; + +#[derive(Debug, Clone, PartialEq, Eq)] +enum NodeKind { + Unknown, + Safekeeper, + WalProposer, +} + +impl Default for NodeKind { + fn default() -> Self { + Self::Unknown + } +} + +/// Simulation state of walproposer/safekeeper, derived from the simulation logs. +#[derive(Clone, Debug, Default)] +struct NodeInfo { + kind: NodeKind, + + // walproposer + is_sync: bool, + term: u64, + epoch_lsn: u64, + + // safekeeper + commit_lsn: u64, + flush_lsn: u64, +} + +impl NodeInfo { + fn init_kind(&mut self, kind: NodeKind) { + if self.kind == NodeKind::Unknown { + self.kind = kind; + } else { + assert!(self.kind == kind); + } + } + + fn started(&mut self, data: &str) { + let mut parts = data.split(';'); + assert!(parts.next().unwrap() == "started"); + match parts.next().unwrap() { + "safekeeper" => { + self.init_kind(NodeKind::Safekeeper); + } + "walproposer" => { + self.init_kind(NodeKind::WalProposer); + let is_sync: u8 = parts.next().unwrap().parse().unwrap(); + self.is_sync = is_sync != 0; + } + _ => unreachable!(), + } + } +} + +/// Global state of the simulation, derived from the simulation logs. +#[derive(Debug, Default)] +struct GlobalState { + nodes: Vec, + commit_lsn: u64, + write_lsn: u64, + max_write_lsn: u64, + + written_wal: u64, + written_records: u64, +} + +impl GlobalState { + fn new() -> Self { + Default::default() + } + + fn get(&mut self, id: u32) -> &mut NodeInfo { + let id = id as usize; + if id >= self.nodes.len() { + self.nodes.resize(id + 1, NodeInfo::default()); + } + &mut self.nodes[id] + } +} + +/// Try to find inconsistencies in the simulation log. +pub fn validate_events(events: Vec) { + const INITDB_LSN: u64 = 21623024; + + let hook = std::panic::take_hook(); + scopeguard::defer_on_success! { + std::panic::set_hook(hook); + }; + + let mut state = GlobalState::new(); + state.max_write_lsn = INITDB_LSN; + + for event in events { + debug!("{:?}", event); + + let node = state.get(event.node); + if event.data.starts_with("started;") { + node.started(&event.data); + continue; + } + assert!(node.kind != NodeKind::Unknown); + + // drop reference to unlock state + let mut node = node.clone(); + + let mut parts = event.data.split(';'); + match node.kind { + NodeKind::Safekeeper => match parts.next().unwrap() { + "tli_loaded" => { + let flush_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let commit_lsn: u64 = parts.next().unwrap().parse().unwrap(); + node.flush_lsn = flush_lsn; + node.commit_lsn = commit_lsn; + } + _ => unreachable!(), + }, + NodeKind::WalProposer => { + match parts.next().unwrap() { + "prop_elected" => { + let prop_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let prop_term: u64 = parts.next().unwrap().parse().unwrap(); + let prev_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let prev_term: u64 = parts.next().unwrap().parse().unwrap(); + + assert!(prop_lsn >= prev_lsn); + assert!(prop_term >= prev_term); + + assert!(prop_lsn >= state.commit_lsn); + + if prop_lsn > state.write_lsn { + assert!(prop_lsn <= state.max_write_lsn); + debug!( + "moving write_lsn up from {} to {}", + state.write_lsn, prop_lsn + ); + state.write_lsn = prop_lsn; + } + if prop_lsn < state.write_lsn { + debug!( + "moving write_lsn down from {} to {}", + state.write_lsn, prop_lsn + ); + state.write_lsn = prop_lsn; + } + + node.epoch_lsn = prop_lsn; + node.term = prop_term; + } + "write_wal" => { + assert!(!node.is_sync); + let start_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let end_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let cnt: u64 = parts.next().unwrap().parse().unwrap(); + + let size = end_lsn - start_lsn; + state.written_wal += size; + state.written_records += cnt; + + // TODO: If we allow writing WAL before winning the election + + assert!(start_lsn >= state.commit_lsn); + assert!(end_lsn >= start_lsn); + // assert!(start_lsn == state.write_lsn); + state.write_lsn = end_lsn; + + if end_lsn > state.max_write_lsn { + state.max_write_lsn = end_lsn; + } + } + "commit_lsn" => { + let lsn: u64 = parts.next().unwrap().parse().unwrap(); + assert!(lsn >= state.commit_lsn); + state.commit_lsn = lsn; + } + _ => unreachable!(), + } + } + _ => unreachable!(), + } + + // update the node in the state struct + *state.get(event.node) = node; + } +} diff --git a/safekeeper/tests/walproposer_sim/walproposer_api.rs b/safekeeper/tests/walproposer_sim/walproposer_api.rs new file mode 100644 index 0000000000..746cac019e --- /dev/null +++ b/safekeeper/tests/walproposer_sim/walproposer_api.rs @@ -0,0 +1,676 @@ +use std::{ + cell::{RefCell, RefMut, UnsafeCell}, + ffi::CStr, + sync::Arc, +}; + +use bytes::Bytes; +use desim::{ + executor::{self, PollSome}, + network::TCP, + node_os::NodeOs, + proto::{AnyMessage, NetEvent, NodeEvent}, + world::NodeId, +}; +use tracing::debug; +use utils::lsn::Lsn; +use walproposer::{ + api_bindings::Level, + bindings::{ + pg_atomic_uint64, NeonWALReadResult, PageserverFeedback, SafekeeperStateDesiredEvents, + WL_SOCKET_READABLE, WL_SOCKET_WRITEABLE, + }, + walproposer::{ApiImpl, Config}, +}; + +use super::walproposer_disk::DiskWalProposer; + +/// Special state for each wp->sk connection. +struct SafekeeperConn { + host: String, + port: String, + node_id: NodeId, + // socket is Some(..) equals to connection is established + socket: Option, + // connection is in progress + is_connecting: bool, + // START_WAL_PUSH is in progress + is_start_wal_push: bool, + // pointer to Safekeeper in walproposer for callbacks + raw_ptr: *mut walproposer::bindings::Safekeeper, +} + +impl SafekeeperConn { + pub fn new(host: String, port: String) -> Self { + // port number is the same as NodeId + let port_num = port.parse::().unwrap(); + Self { + host, + port, + node_id: port_num, + socket: None, + is_connecting: false, + is_start_wal_push: false, + raw_ptr: std::ptr::null_mut(), + } + } +} + +/// Simulation version of a postgres WaitEventSet. At pos 0 there is always +/// a special NodeEvents channel, which is used as a latch. +struct EventSet { + os: NodeOs, + // all pollable channels, 0 is always NodeEvent channel + chans: Vec>, + // 0 is always nullptr + sk_ptrs: Vec<*mut walproposer::bindings::Safekeeper>, + // event mask for each channel + masks: Vec, +} + +impl EventSet { + pub fn new(os: NodeOs) -> Self { + let node_events = os.node_events(); + Self { + os, + chans: vec![Box::new(node_events)], + sk_ptrs: vec![std::ptr::null_mut()], + masks: vec![WL_SOCKET_READABLE], + } + } + + /// Leaves all readable channels at the beginning of the array. + fn sort_readable(&mut self) -> usize { + let mut cnt = 1; + for i in 1..self.chans.len() { + if self.masks[i] & WL_SOCKET_READABLE != 0 { + self.chans.swap(i, cnt); + self.sk_ptrs.swap(i, cnt); + self.masks.swap(i, cnt); + cnt += 1; + } + } + cnt + } + + fn update_event_set(&mut self, conn: &SafekeeperConn, event_mask: u32) { + let index = self + .sk_ptrs + .iter() + .position(|&ptr| ptr == conn.raw_ptr) + .expect("safekeeper should exist in event set"); + self.masks[index] = event_mask; + } + + fn add_safekeeper(&mut self, sk: &SafekeeperConn, event_mask: u32) { + for ptr in self.sk_ptrs.iter() { + assert!(*ptr != sk.raw_ptr); + } + + self.chans.push(Box::new( + sk.socket + .as_ref() + .expect("socket should not be closed") + .recv_chan(), + )); + self.sk_ptrs.push(sk.raw_ptr); + self.masks.push(event_mask); + } + + fn remove_safekeeper(&mut self, sk: &SafekeeperConn) { + let index = self.sk_ptrs.iter().position(|&ptr| ptr == sk.raw_ptr); + if index.is_none() { + debug!("remove_safekeeper: sk={:?} not found", sk.raw_ptr); + return; + } + let index = index.unwrap(); + + self.chans.remove(index); + self.sk_ptrs.remove(index); + self.masks.remove(index); + + // to simulate the actual behaviour + self.refresh_event_set(); + } + + /// Updates all masks to match the result of a SafekeeperStateDesiredEvents. + fn refresh_event_set(&mut self) { + for (i, mask) in self.masks.iter_mut().enumerate() { + if i == 0 { + continue; + } + + let mut mask_sk: u32 = 0; + let mut mask_nwr: u32 = 0; + unsafe { SafekeeperStateDesiredEvents(self.sk_ptrs[i], &mut mask_sk, &mut mask_nwr) }; + + if mask_sk != *mask { + debug!( + "refresh_event_set: sk={:?}, old_mask={:#b}, new_mask={:#b}", + self.sk_ptrs[i], *mask, mask_sk + ); + *mask = mask_sk; + } + } + } + + /// Wait for events on all channels. + fn wait(&mut self, timeout_millis: i64) -> walproposer::walproposer::WaitResult { + // all channels are always writeable + for (i, mask) in self.masks.iter().enumerate() { + if *mask & WL_SOCKET_WRITEABLE != 0 { + return walproposer::walproposer::WaitResult::Network( + self.sk_ptrs[i], + WL_SOCKET_WRITEABLE, + ); + } + } + + let cnt = self.sort_readable(); + + let slice = &self.chans[0..cnt]; + match executor::epoll_chans(slice, timeout_millis) { + None => walproposer::walproposer::WaitResult::Timeout, + Some(0) => { + let msg = self.os.node_events().must_recv(); + match msg { + NodeEvent::Internal(AnyMessage::Just32(0)) => { + // got a notification about new WAL available + } + NodeEvent::Internal(_) => unreachable!(), + NodeEvent::Accept(_) => unreachable!(), + } + walproposer::walproposer::WaitResult::Latch + } + Some(index) => walproposer::walproposer::WaitResult::Network( + self.sk_ptrs[index], + WL_SOCKET_READABLE, + ), + } + } +} + +/// This struct handles all calls from walproposer into walproposer_api. +pub struct SimulationApi { + os: NodeOs, + safekeepers: RefCell>, + disk: Arc, + redo_start_lsn: Option, + shmem: UnsafeCell, + config: Config, + event_set: RefCell>, +} + +pub struct Args { + pub os: NodeOs, + pub config: Config, + pub disk: Arc, + pub redo_start_lsn: Option, +} + +impl SimulationApi { + pub fn new(args: Args) -> Self { + // initialize connection state for each safekeeper + let sk_conns = args + .config + .safekeepers_list + .iter() + .map(|s| { + SafekeeperConn::new( + s.split(':').next().unwrap().to_string(), + s.split(':').nth(1).unwrap().to_string(), + ) + }) + .collect::>(); + + Self { + os: args.os, + safekeepers: RefCell::new(sk_conns), + disk: args.disk, + redo_start_lsn: args.redo_start_lsn, + shmem: UnsafeCell::new(walproposer::bindings::WalproposerShmemState { + mutex: 0, + feedback: PageserverFeedback { + currentClusterSize: 0, + last_received_lsn: 0, + disk_consistent_lsn: 0, + remote_consistent_lsn: 0, + replytime: 0, + }, + mineLastElectedTerm: 0, + backpressureThrottlingTime: pg_atomic_uint64 { value: 0 }, + }), + config: args.config, + event_set: RefCell::new(None), + } + } + + /// Get SafekeeperConn for the given Safekeeper. + fn get_conn(&self, sk: &mut walproposer::bindings::Safekeeper) -> RefMut<'_, SafekeeperConn> { + let sk_port = unsafe { CStr::from_ptr(sk.port).to_str().unwrap() }; + let state = self.safekeepers.borrow_mut(); + RefMut::map(state, |v| { + v.iter_mut() + .find(|conn| conn.port == sk_port) + .expect("safekeeper conn not found by port") + }) + } +} + +impl ApiImpl for SimulationApi { + fn get_current_timestamp(&self) -> i64 { + debug!("get_current_timestamp"); + // PG TimestampTZ is microseconds, but simulation unit is assumed to be + // milliseconds, so add 10^3 + self.os.now() as i64 * 1000 + } + + fn conn_status( + &self, + _: &mut walproposer::bindings::Safekeeper, + ) -> walproposer::bindings::WalProposerConnStatusType { + debug!("conn_status"); + // break the connection with a 10% chance + if self.os.random(100) < 10 { + walproposer::bindings::WalProposerConnStatusType_WP_CONNECTION_BAD + } else { + walproposer::bindings::WalProposerConnStatusType_WP_CONNECTION_OK + } + } + + fn conn_connect_start(&self, sk: &mut walproposer::bindings::Safekeeper) { + debug!("conn_connect_start"); + let mut conn = self.get_conn(sk); + + assert!(conn.socket.is_none()); + let socket = self.os.open_tcp(conn.node_id); + conn.socket = Some(socket); + conn.raw_ptr = sk; + conn.is_connecting = true; + } + + fn conn_connect_poll( + &self, + _: &mut walproposer::bindings::Safekeeper, + ) -> walproposer::bindings::WalProposerConnectPollStatusType { + debug!("conn_connect_poll"); + // TODO: break the connection here + walproposer::bindings::WalProposerConnectPollStatusType_WP_CONN_POLLING_OK + } + + fn conn_send_query(&self, sk: &mut walproposer::bindings::Safekeeper, query: &str) -> bool { + debug!("conn_send_query: {}", query); + self.get_conn(sk).is_start_wal_push = true; + true + } + + fn conn_get_query_result( + &self, + _: &mut walproposer::bindings::Safekeeper, + ) -> walproposer::bindings::WalProposerExecStatusType { + debug!("conn_get_query_result"); + // TODO: break the connection here + walproposer::bindings::WalProposerExecStatusType_WP_EXEC_SUCCESS_COPYBOTH + } + + fn conn_async_read( + &self, + sk: &mut walproposer::bindings::Safekeeper, + vec: &mut Vec, + ) -> walproposer::bindings::PGAsyncReadResult { + debug!("conn_async_read"); + let mut conn = self.get_conn(sk); + + let socket = if let Some(socket) = conn.socket.as_mut() { + socket + } else { + // socket is already closed + return walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_FAIL; + }; + + let msg = socket.recv_chan().try_recv(); + + match msg { + None => { + // no message is ready + walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_TRY_AGAIN + } + Some(NetEvent::Closed) => { + // connection is closed + debug!("conn_async_read: connection is closed"); + conn.socket = None; + walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_FAIL + } + Some(NetEvent::Message(msg)) => { + // got a message + let b = match msg { + desim::proto::AnyMessage::Bytes(b) => b, + _ => unreachable!(), + }; + vec.extend_from_slice(&b); + walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS + } + } + } + + fn conn_blocking_write(&self, sk: &mut walproposer::bindings::Safekeeper, buf: &[u8]) -> bool { + let mut conn = self.get_conn(sk); + debug!("conn_blocking_write to {}: {:?}", conn.node_id, buf); + let socket = conn.socket.as_mut().unwrap(); + socket.send(desim::proto::AnyMessage::Bytes(Bytes::copy_from_slice(buf))); + true + } + + fn conn_async_write( + &self, + sk: &mut walproposer::bindings::Safekeeper, + buf: &[u8], + ) -> walproposer::bindings::PGAsyncWriteResult { + let mut conn = self.get_conn(sk); + debug!("conn_async_write to {}: {:?}", conn.node_id, buf); + if let Some(socket) = conn.socket.as_mut() { + socket.send(desim::proto::AnyMessage::Bytes(Bytes::copy_from_slice(buf))); + } else { + // connection is already closed + debug!("conn_async_write: writing to a closed socket!"); + // TODO: maybe we should return error here? + } + walproposer::bindings::PGAsyncWriteResult_PG_ASYNC_WRITE_SUCCESS + } + + fn wal_reader_allocate(&self, _: &mut walproposer::bindings::Safekeeper) -> NeonWALReadResult { + debug!("wal_reader_allocate"); + walproposer::bindings::NeonWALReadResult_NEON_WALREAD_SUCCESS + } + + fn wal_read( + &self, + _sk: &mut walproposer::bindings::Safekeeper, + buf: &mut [u8], + startpos: u64, + ) -> NeonWALReadResult { + self.disk.lock().read(startpos, buf); + walproposer::bindings::NeonWALReadResult_NEON_WALREAD_SUCCESS + } + + fn init_event_set(&self, _: &mut walproposer::bindings::WalProposer) { + debug!("init_event_set"); + let new_event_set = EventSet::new(self.os.clone()); + let old_event_set = self.event_set.replace(Some(new_event_set)); + assert!(old_event_set.is_none()); + } + + fn update_event_set(&self, sk: &mut walproposer::bindings::Safekeeper, event_mask: u32) { + debug!( + "update_event_set, sk={:?}, events_mask={:#b}", + sk as *mut walproposer::bindings::Safekeeper, event_mask + ); + let conn = self.get_conn(sk); + + self.event_set + .borrow_mut() + .as_mut() + .unwrap() + .update_event_set(&conn, event_mask); + } + + fn add_safekeeper_event_set( + &self, + sk: &mut walproposer::bindings::Safekeeper, + event_mask: u32, + ) { + debug!( + "add_safekeeper_event_set, sk={:?}, events_mask={:#b}", + sk as *mut walproposer::bindings::Safekeeper, event_mask + ); + + self.event_set + .borrow_mut() + .as_mut() + .unwrap() + .add_safekeeper(&self.get_conn(sk), event_mask); + } + + fn rm_safekeeper_event_set(&self, sk: &mut walproposer::bindings::Safekeeper) { + debug!( + "rm_safekeeper_event_set, sk={:?}", + sk as *mut walproposer::bindings::Safekeeper, + ); + + self.event_set + .borrow_mut() + .as_mut() + .unwrap() + .remove_safekeeper(&self.get_conn(sk)); + } + + fn active_state_update_event_set(&self, sk: &mut walproposer::bindings::Safekeeper) { + debug!("active_state_update_event_set"); + + assert!(sk.state == walproposer::bindings::SafekeeperState_SS_ACTIVE); + self.event_set + .borrow_mut() + .as_mut() + .unwrap() + .refresh_event_set(); + } + + fn wal_reader_events(&self, _sk: &mut walproposer::bindings::Safekeeper) -> u32 { + 0 + } + + fn wait_event_set( + &self, + _: &mut walproposer::bindings::WalProposer, + timeout_millis: i64, + ) -> walproposer::walproposer::WaitResult { + // TODO: handle multiple stages as part of the simulation (e.g. connect, start_wal_push, etc) + let mut conns = self.safekeepers.borrow_mut(); + for conn in conns.iter_mut() { + if conn.socket.is_some() && conn.is_connecting { + conn.is_connecting = false; + debug!("wait_event_set, connecting to {}:{}", conn.host, conn.port); + return walproposer::walproposer::WaitResult::Network( + conn.raw_ptr, + WL_SOCKET_READABLE | WL_SOCKET_WRITEABLE, + ); + } + if conn.socket.is_some() && conn.is_start_wal_push { + conn.is_start_wal_push = false; + debug!( + "wait_event_set, start wal push to {}:{}", + conn.host, conn.port + ); + return walproposer::walproposer::WaitResult::Network( + conn.raw_ptr, + WL_SOCKET_READABLE, + ); + } + } + drop(conns); + + let res = self + .event_set + .borrow_mut() + .as_mut() + .unwrap() + .wait(timeout_millis); + + debug!( + "wait_event_set, timeout_millis={}, res={:?}", + timeout_millis, res, + ); + res + } + + fn strong_random(&self, buf: &mut [u8]) -> bool { + debug!("strong_random"); + buf.fill(0); + true + } + + fn finish_sync_safekeepers(&self, lsn: u64) { + debug!("finish_sync_safekeepers, lsn={}", lsn); + executor::exit(0, Lsn(lsn).to_string()); + } + + fn log_internal(&self, _wp: &mut walproposer::bindings::WalProposer, level: Level, msg: &str) { + debug!("wp_log[{}] {}", level, msg); + if level == Level::Fatal || level == Level::Panic { + if msg.contains("rejects our connection request with term") { + // collected quorum with lower term, then got rejected by next connected safekeeper + executor::exit(1, msg.to_owned()); + } + if msg.contains("collected propEpochStartLsn") && msg.contains(", but basebackup LSN ") + { + // sync-safekeepers collected wrong quorum, walproposer collected another quorum + executor::exit(1, msg.to_owned()); + } + if msg.contains("failed to download WAL for logical replicaiton") { + // Recovery connection broken and recovery was failed + executor::exit(1, msg.to_owned()); + } + if msg.contains("missing majority of votes, collected") { + // Voting bug when safekeeper disconnects after voting + executor::exit(1, msg.to_owned()); + } + panic!("unknown FATAL error from walproposer: {}", msg); + } + } + + fn after_election(&self, wp: &mut walproposer::bindings::WalProposer) { + let prop_lsn = wp.propEpochStartLsn; + let prop_term = wp.propTerm; + + let mut prev_lsn: u64 = 0; + let mut prev_term: u64 = 0; + + unsafe { + let history = wp.propTermHistory.entries; + let len = wp.propTermHistory.n_entries as usize; + if len > 1 { + let entry = *history.wrapping_add(len - 2); + prev_lsn = entry.lsn; + prev_term = entry.term; + } + } + + let msg = format!( + "prop_elected;{};{};{};{}", + prop_lsn, prop_term, prev_lsn, prev_term + ); + + debug!(msg); + self.os.log_event(msg); + } + + fn get_redo_start_lsn(&self) -> u64 { + debug!("get_redo_start_lsn -> {:?}", self.redo_start_lsn); + self.redo_start_lsn.expect("redo_start_lsn is not set").0 + } + + fn get_shmem_state(&self) -> *mut walproposer::bindings::WalproposerShmemState { + self.shmem.get() + } + + fn start_streaming( + &self, + startpos: u64, + callback: &walproposer::walproposer::StreamingCallback, + ) { + let disk = &self.disk; + let disk_lsn = disk.lock().flush_rec_ptr().0; + debug!("start_streaming at {} (disk_lsn={})", startpos, disk_lsn); + if startpos < disk_lsn { + debug!("startpos < disk_lsn, it means we wrote some transaction even before streaming started"); + } + assert!(startpos <= disk_lsn); + let mut broadcasted = Lsn(startpos); + + loop { + let available = disk.lock().flush_rec_ptr(); + assert!(available >= broadcasted); + callback.broadcast(broadcasted, available); + broadcasted = available; + callback.poll(); + } + } + + fn process_safekeeper_feedback( + &self, + wp: &mut walproposer::bindings::WalProposer, + commit_lsn: u64, + ) { + debug!("process_safekeeper_feedback, commit_lsn={}", commit_lsn); + if commit_lsn > wp.lastSentCommitLsn { + self.os.log_event(format!("commit_lsn;{}", commit_lsn)); + } + } + + fn get_flush_rec_ptr(&self) -> u64 { + let lsn = self.disk.lock().flush_rec_ptr(); + debug!("get_flush_rec_ptr: {}", lsn); + lsn.0 + } + + fn recovery_download( + &self, + wp: &mut walproposer::bindings::WalProposer, + sk: &mut walproposer::bindings::Safekeeper, + ) -> bool { + let mut startpos = wp.truncateLsn; + let endpos = wp.propEpochStartLsn; + + if startpos == endpos { + debug!("recovery_download: nothing to download"); + return true; + } + + debug!("recovery_download from {} to {}", startpos, endpos,); + + let replication_prompt = format!( + "START_REPLICATION {} {} {} {}", + self.config.ttid.tenant_id, self.config.ttid.timeline_id, startpos, endpos, + ); + let async_conn = self.get_conn(sk); + + let conn = self.os.open_tcp(async_conn.node_id); + conn.send(desim::proto::AnyMessage::Bytes(replication_prompt.into())); + + let chan = conn.recv_chan(); + while startpos < endpos { + let event = chan.recv(); + match event { + NetEvent::Closed => { + debug!("connection closed in recovery"); + break; + } + NetEvent::Message(AnyMessage::Bytes(b)) => { + debug!("got recovery bytes from safekeeper"); + self.disk.lock().write(startpos, &b); + startpos += b.len() as u64; + } + NetEvent::Message(_) => unreachable!(), + } + } + + debug!("recovery finished at {}", startpos); + + startpos == endpos + } + + fn conn_finish(&self, sk: &mut walproposer::bindings::Safekeeper) { + let mut conn = self.get_conn(sk); + debug!("conn_finish to {}", conn.node_id); + if let Some(socket) = conn.socket.as_mut() { + socket.close(); + } else { + // connection is already closed + } + conn.socket = None; + } + + fn conn_error_message(&self, _sk: &mut walproposer::bindings::Safekeeper) -> String { + "connection is closed, probably".into() + } +} diff --git a/safekeeper/tests/walproposer_sim/walproposer_disk.rs b/safekeeper/tests/walproposer_sim/walproposer_disk.rs new file mode 100644 index 0000000000..aa329bd2f0 --- /dev/null +++ b/safekeeper/tests/walproposer_sim/walproposer_disk.rs @@ -0,0 +1,314 @@ +use std::{ffi::CString, sync::Arc}; + +use byteorder::{LittleEndian, WriteBytesExt}; +use crc32c::crc32c_append; +use parking_lot::{Mutex, MutexGuard}; +use postgres_ffi::{ + pg_constants::{ + RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE, XLP_LONG_HEADER, XLR_BLOCK_ID_DATA_LONG, + XLR_BLOCK_ID_DATA_SHORT, + }, + v16::{ + wal_craft_test_export::{XLogLongPageHeaderData, XLogPageHeaderData, XLOG_PAGE_MAGIC}, + xlog_utils::{ + XLogSegNoOffsetToRecPtr, XlLogicalMessage, XLOG_RECORD_CRC_OFFS, + XLOG_SIZE_OF_XLOG_LONG_PHD, XLOG_SIZE_OF_XLOG_RECORD, XLOG_SIZE_OF_XLOG_SHORT_PHD, + XLP_FIRST_IS_CONTRECORD, + }, + XLogRecord, + }, + WAL_SEGMENT_SIZE, XLOG_BLCKSZ, +}; +use utils::lsn::Lsn; + +use super::block_storage::BlockStorage; + +/// Simulation implementation of walproposer WAL storage. +pub struct DiskWalProposer { + state: Mutex, +} + +impl DiskWalProposer { + pub fn new() -> Arc { + Arc::new(DiskWalProposer { + state: Mutex::new(State { + internal_available_lsn: Lsn(0), + prev_lsn: Lsn(0), + disk: BlockStorage::new(), + }), + }) + } + + pub fn lock(&self) -> MutexGuard { + self.state.lock() + } +} + +pub struct State { + // flush_lsn + internal_available_lsn: Lsn, + // needed for WAL generation + prev_lsn: Lsn, + // actual WAL storage + disk: BlockStorage, +} + +impl State { + pub fn read(&self, pos: u64, buf: &mut [u8]) { + self.disk.read(pos, buf); + // TODO: fail on reading uninitialized data + } + + pub fn write(&mut self, pos: u64, buf: &[u8]) { + self.disk.write(pos, buf); + } + + /// Update the internal available LSN to the given value. + pub fn reset_to(&mut self, lsn: Lsn) { + self.internal_available_lsn = lsn; + } + + /// Get current LSN. + pub fn flush_rec_ptr(&self) -> Lsn { + self.internal_available_lsn + } + + /// Generate a new WAL record at the current LSN. + pub fn insert_logical_message(&mut self, prefix: &str, msg: &[u8]) -> anyhow::Result<()> { + let prefix_cstr = CString::new(prefix)?; + let prefix_bytes = prefix_cstr.as_bytes_with_nul(); + + let lm = XlLogicalMessage { + db_id: 0, + transactional: 0, + prefix_size: prefix_bytes.len() as ::std::os::raw::c_ulong, + message_size: msg.len() as ::std::os::raw::c_ulong, + }; + + let record_bytes = lm.encode(); + let rdatas: Vec<&[u8]> = vec![&record_bytes, prefix_bytes, msg]; + insert_wal_record(self, rdatas, RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE) + } +} + +fn insert_wal_record( + state: &mut State, + rdatas: Vec<&[u8]>, + rmid: u8, + info: u8, +) -> anyhow::Result<()> { + // bytes right after the header, in the same rdata block + let mut scratch = Vec::new(); + let mainrdata_len: usize = rdatas.iter().map(|rdata| rdata.len()).sum(); + + if mainrdata_len > 0 { + if mainrdata_len > 255 { + scratch.push(XLR_BLOCK_ID_DATA_LONG); + // TODO: verify endiness + let _ = scratch.write_u32::(mainrdata_len as u32); + } else { + scratch.push(XLR_BLOCK_ID_DATA_SHORT); + scratch.push(mainrdata_len as u8); + } + } + + let total_len: u32 = (XLOG_SIZE_OF_XLOG_RECORD + scratch.len() + mainrdata_len) as u32; + let size = maxalign(total_len); + assert!(size as usize > XLOG_SIZE_OF_XLOG_RECORD); + + let start_bytepos = recptr_to_bytepos(state.internal_available_lsn); + let end_bytepos = start_bytepos + size as u64; + + let start_recptr = bytepos_to_recptr(start_bytepos); + let end_recptr = bytepos_to_recptr(end_bytepos); + + assert!(recptr_to_bytepos(start_recptr) == start_bytepos); + assert!(recptr_to_bytepos(end_recptr) == end_bytepos); + + let mut crc = crc32c_append(0, &scratch); + for rdata in &rdatas { + crc = crc32c_append(crc, rdata); + } + + let mut header = XLogRecord { + xl_tot_len: total_len, + xl_xid: 0, + xl_prev: state.prev_lsn.0, + xl_info: info, + xl_rmid: rmid, + __bindgen_padding_0: [0u8; 2usize], + xl_crc: crc, + }; + + // now we have the header and can finish the crc + let header_bytes = header.encode()?; + let crc = crc32c_append(crc, &header_bytes[0..XLOG_RECORD_CRC_OFFS]); + header.xl_crc = crc; + + let mut header_bytes = header.encode()?.to_vec(); + assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_RECORD); + + header_bytes.extend_from_slice(&scratch); + + // finish rdatas + let mut rdatas = rdatas; + rdatas.insert(0, &header_bytes); + + write_walrecord_to_disk(state, total_len as u64, rdatas, start_recptr, end_recptr)?; + + state.internal_available_lsn = end_recptr; + state.prev_lsn = start_recptr; + Ok(()) +} + +fn write_walrecord_to_disk( + state: &mut State, + total_len: u64, + rdatas: Vec<&[u8]>, + start: Lsn, + end: Lsn, +) -> anyhow::Result<()> { + let mut curr_ptr = start; + let mut freespace = insert_freespace(curr_ptr); + let mut written: usize = 0; + + assert!(freespace >= std::mem::size_of::()); + + for mut rdata in rdatas { + while rdata.len() >= freespace { + assert!( + curr_ptr.segment_offset(WAL_SEGMENT_SIZE) >= XLOG_SIZE_OF_XLOG_SHORT_PHD + || freespace == 0 + ); + + state.write(curr_ptr.0, &rdata[..freespace]); + rdata = &rdata[freespace..]; + written += freespace; + curr_ptr = Lsn(curr_ptr.0 + freespace as u64); + + let mut new_page = XLogPageHeaderData { + xlp_magic: XLOG_PAGE_MAGIC as u16, + xlp_info: XLP_BKP_REMOVABLE, + xlp_tli: 1, + xlp_pageaddr: curr_ptr.0, + xlp_rem_len: (total_len - written as u64) as u32, + ..Default::default() // Put 0 in padding fields. + }; + if new_page.xlp_rem_len > 0 { + new_page.xlp_info |= XLP_FIRST_IS_CONTRECORD; + } + + if curr_ptr.segment_offset(WAL_SEGMENT_SIZE) == 0 { + new_page.xlp_info |= XLP_LONG_HEADER; + let long_page = XLogLongPageHeaderData { + std: new_page, + xlp_sysid: 0, + xlp_seg_size: WAL_SEGMENT_SIZE as u32, + xlp_xlog_blcksz: XLOG_BLCKSZ as u32, + }; + let header_bytes = long_page.encode()?; + assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_LONG_PHD); + state.write(curr_ptr.0, &header_bytes); + curr_ptr = Lsn(curr_ptr.0 + header_bytes.len() as u64); + } else { + let header_bytes = new_page.encode()?; + assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_SHORT_PHD); + state.write(curr_ptr.0, &header_bytes); + curr_ptr = Lsn(curr_ptr.0 + header_bytes.len() as u64); + } + freespace = insert_freespace(curr_ptr); + } + + assert!( + curr_ptr.segment_offset(WAL_SEGMENT_SIZE) >= XLOG_SIZE_OF_XLOG_SHORT_PHD + || rdata.is_empty() + ); + state.write(curr_ptr.0, rdata); + curr_ptr = Lsn(curr_ptr.0 + rdata.len() as u64); + written += rdata.len(); + freespace -= rdata.len(); + } + + assert!(written == total_len as usize); + curr_ptr.0 = maxalign(curr_ptr.0); + assert!(curr_ptr == end); + Ok(()) +} + +fn maxalign(size: T) -> T +where + T: std::ops::BitAnd + + std::ops::Add + + std::ops::Not + + From, +{ + (size + T::from(7)) & !T::from(7) +} + +fn insert_freespace(ptr: Lsn) -> usize { + if ptr.block_offset() == 0 { + 0 + } else { + (XLOG_BLCKSZ as u64 - ptr.block_offset()) as usize + } +} + +const XLP_BKP_REMOVABLE: u16 = 0x0004; +const USABLE_BYTES_IN_PAGE: u64 = (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64; +const USABLE_BYTES_IN_SEGMENT: u64 = ((WAL_SEGMENT_SIZE / XLOG_BLCKSZ) as u64 + * USABLE_BYTES_IN_PAGE) + - (XLOG_SIZE_OF_XLOG_RECORD - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64; + +fn bytepos_to_recptr(bytepos: u64) -> Lsn { + let fullsegs = bytepos / USABLE_BYTES_IN_SEGMENT; + let mut bytesleft = bytepos % USABLE_BYTES_IN_SEGMENT; + + let seg_offset = if bytesleft < (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64 { + // fits on first page of segment + bytesleft + XLOG_SIZE_OF_XLOG_SHORT_PHD as u64 + } else { + // account for the first page on segment with long header + bytesleft -= (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64; + let fullpages = bytesleft / USABLE_BYTES_IN_PAGE; + bytesleft %= USABLE_BYTES_IN_PAGE; + + XLOG_BLCKSZ as u64 + + fullpages * XLOG_BLCKSZ as u64 + + bytesleft + + XLOG_SIZE_OF_XLOG_SHORT_PHD as u64 + }; + + Lsn(XLogSegNoOffsetToRecPtr( + fullsegs, + seg_offset as u32, + WAL_SEGMENT_SIZE, + )) +} + +fn recptr_to_bytepos(ptr: Lsn) -> u64 { + let fullsegs = ptr.segment_number(WAL_SEGMENT_SIZE); + let offset = ptr.segment_offset(WAL_SEGMENT_SIZE) as u64; + + let fullpages = offset / XLOG_BLCKSZ as u64; + let offset = offset % XLOG_BLCKSZ as u64; + + if fullpages == 0 { + fullsegs * USABLE_BYTES_IN_SEGMENT + + if offset > 0 { + assert!(offset >= XLOG_SIZE_OF_XLOG_SHORT_PHD as u64); + offset - XLOG_SIZE_OF_XLOG_SHORT_PHD as u64 + } else { + 0 + } + } else { + fullsegs * USABLE_BYTES_IN_SEGMENT + + (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64 + + (fullpages - 1) * USABLE_BYTES_IN_PAGE + + if offset > 0 { + assert!(offset >= XLOG_SIZE_OF_XLOG_SHORT_PHD as u64); + offset - XLOG_SIZE_OF_XLOG_SHORT_PHD as u64 + } else { + 0 + } + } +} From a8eb4042baa6ca1ae4268a1f1b22a89941b0d942 Mon Sep 17 00:00:00 2001 From: John Spray Date: Tue, 13 Feb 2024 07:00:50 +0000 Subject: [PATCH 57/81] tests: test_secondary_mode_eviction: avoid use of mocked statvfs (#6698) ## Problem Test sometimes fails with `used_blocks > total_blocks`, because when using mocked statvfs with the total blocks set to the size of data on disk before starting, we are implicitly asserting that nothing at all can be written to disk between startup and calling statvfs. Related: https://github.com/neondatabase/neon/issues/6511 ## Summary of changes - Use HTTP API to invoke disk usage eviction instead of mocked statvfs --- .../regress/test_disk_usage_eviction.py | 33 +++---------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/test_runner/regress/test_disk_usage_eviction.py b/test_runner/regress/test_disk_usage_eviction.py index 061c57c88b..eb4e370ea7 100644 --- a/test_runner/regress/test_disk_usage_eviction.py +++ b/test_runner/regress/test_disk_usage_eviction.py @@ -893,37 +893,14 @@ def test_secondary_mode_eviction(eviction_env_ha: EvictionEnv): # in its heatmap ps_secondary.http_client().tenant_secondary_download(tenant_id) - # Configure the secondary pageserver to have a phony small disk size - ps_secondary.stop() total_size, _, _ = env.timelines_du(ps_secondary) - blocksize = 512 - total_blocks = (total_size + (blocksize - 1)) // blocksize + evict_bytes = total_size // 3 - min_avail_bytes = total_size // 3 - - env.pageserver_start_with_disk_usage_eviction( - ps_secondary, - period="1s", - max_usage_pct=100, - min_avail_bytes=min_avail_bytes, - mock_behavior={ - "type": "Success", - "blocksize": blocksize, - "total_blocks": total_blocks, - # Only count layer files towards used bytes in the mock_statvfs. - # This avoids accounting for metadata files & tenant conf in the tests. - "name_filter": ".*__.*", - }, - eviction_order=EvictionOrder.ABSOLUTE_ORDER, - ) - - def relieved_log_message(): - assert ps_secondary.log_contains(".*disk usage pressure relieved") - - wait_until(10, 1, relieved_log_message) + response = ps_secondary.http_client().disk_usage_eviction_run({"evict_bytes": evict_bytes}) + log.info(f"{response}") post_eviction_total_size, _, _ = env.timelines_du(ps_secondary) assert ( - total_size - post_eviction_total_size >= min_avail_bytes - ), "we requested at least min_avail_bytes worth of free space" + total_size - post_eviction_total_size >= evict_bytes + ), "we requested at least evict_bytes worth of free space" From 331935df91abe03a1e8a081bc96b6ef871f71bb1 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Tue, 13 Feb 2024 17:58:58 +0100 Subject: [PATCH 58/81] Proxy: send cancel notifications to all instances (#6719) ## Problem If cancel request ends up on the wrong proxy instance, it doesn't take an effect. ## Summary of changes Send redis notifications to all proxy pods about the cancel request. Related issue: https://github.com/neondatabase/neon/issues/5839, https://github.com/neondatabase/cloud/issues/10262 --- Cargo.lock | 7 +- Cargo.toml | 2 +- libs/pq_proto/Cargo.toml | 1 + libs/pq_proto/src/lib.rs | 3 +- proxy/src/bin/proxy.rs | 32 ++++- proxy/src/cancellation.rs | 109 ++++++++++++++--- proxy/src/config.rs | 1 + proxy/src/metrics.rs | 9 ++ proxy/src/proxy.rs | 16 +-- proxy/src/rate_limiter.rs | 2 +- proxy/src/rate_limiter/limiter.rs | 38 ++++++ proxy/src/redis.rs | 1 + proxy/src/redis/notifications.rs | 197 +++++++++++++++++++++++------- proxy/src/redis/publisher.rs | 80 ++++++++++++ proxy/src/serverless.rs | 13 +- proxy/src/serverless/websocket.rs | 6 +- workspace_hack/Cargo.toml | 4 +- 17 files changed, 432 insertions(+), 89 deletions(-) create mode 100644 proxy/src/redis/publisher.rs diff --git a/Cargo.lock b/Cargo.lock index f11c774016..45a313a72b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2263,11 +2263,11 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.8.2" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown 0.13.2", + "hashbrown 0.14.0", ] [[package]] @@ -3952,6 +3952,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "rand 0.8.5", + "serde", "thiserror", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 8df9ca9988..8952f7627f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ futures-core = "0.3" futures-util = "0.3" git-version = "0.3" hashbrown = "0.13" -hashlink = "0.8.1" +hashlink = "0.8.4" hdrhistogram = "7.5.2" hex = "0.4" hex-literal = "0.4" diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index b286eb0358..6eeb3bafef 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -13,5 +13,6 @@ rand.workspace = true tokio.workspace = true tracing.workspace = true thiserror.workspace = true +serde.workspace = true workspace_hack.workspace = true diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index c52a21bcd3..522b65f5d1 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -7,6 +7,7 @@ pub mod framed; use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use serde::{Deserialize, Serialize}; use std::{borrow::Cow, collections::HashMap, fmt, io, str}; // re-export for use in utils pageserver_feedback.rs @@ -123,7 +124,7 @@ impl StartupMessageParams { } } -#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub struct CancelKeyData { pub backend_pid: i32, pub cancel_key: i32, diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 00a229c135..b3d4fc0411 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,6 +1,8 @@ use futures::future::Either; use proxy::auth; use proxy::auth::backend::MaybeOwned; +use proxy::cancellation::CancelMap; +use proxy::cancellation::CancellationHandler; use proxy::config::AuthenticationConfig; use proxy::config::CacheOptions; use proxy::config::HttpConfig; @@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; use proxy::redis::notifications; +use proxy::redis::publisher::RedisPublisherClient; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -22,6 +25,7 @@ use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; use tokio::net::TcpListener; +use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::info; @@ -129,6 +133,9 @@ struct ProxyCliArgs { /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] endpoint_rps_limit: Vec, + /// Redis rate limiter max number of requests per second. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] + redis_rps_limit: Vec, /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`. #[clap(long, default_value_t = 100)] initial_limit: usize, @@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); + let cancel_map = CancelMap::default(); + let redis_publisher = match &args.redis_notifications { + Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( + url, + args.region.clone(), + &config.redis_rps_limit, + )?))), + None => None, + }; + let cancellation_handler = Arc::new(CancellationHandler::new( + cancel_map.clone(), + redis_publisher, + )); // client facing tasks. these will exit on error or on cancellation // cancellation returns Ok(()) @@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> { proxy_listener, cancellation_token.clone(), endpoint_rate_limiter.clone(), + cancellation_handler.clone(), )); // TODO: rename the argument to something like serverless. @@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> { serverless_listener, cancellation_token.clone(), endpoint_rate_limiter.clone(), + cancellation_handler.clone(), )); } @@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> { let cache = api.caches.project_info.clone(); if let Some(url) = args.redis_notifications { info!("Starting redis notifications listener ({url})"); - maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone())); + maintenance_tasks.spawn(notifications::task_main( + url.to_owned(), + cache.clone(), + cancel_map.clone(), + args.region.clone(), + )); } maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } @@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); RateBucketInfo::validate(&mut endpoint_rps_limit)?; + let mut redis_rps_limit = args.redis_rps_limit.clone(); + RateBucketInfo::validate(&mut redis_rps_limit)?; let config = Box::leak(Box::new(ProxyConfig { tls_config, @@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { require_client_ip: args.require_client_ip, disable_ip_check_for_http: args.disable_ip_check_for_http, endpoint_rps_limit, + redis_rps_limit, handshake_timeout: args.handshake_timeout, // TODO: add this argument region: args.region.clone(), diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index fe614628d8..93a77bc4ae 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,16 +1,28 @@ +use async_trait::async_trait; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; use thiserror::Error; use tokio::net::TcpStream; +use tokio::sync::Mutex; use tokio_postgres::{CancelToken, NoTls}; use tracing::info; +use uuid::Uuid; -use crate::error::ReportableError; +use crate::{ + error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS, + redis::publisher::RedisPublisherClient, +}; + +pub type CancelMap = Arc>>; /// Enables serving `CancelRequest`s. -#[derive(Default)] -pub struct CancelMap(DashMap>); +/// +/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances. +pub struct CancellationHandler { + map: CancelMap, + redis_client: Option>>, +} #[derive(Debug, Error)] pub enum CancelError { @@ -32,15 +44,43 @@ impl ReportableError for CancelError { } } -impl CancelMap { +impl CancellationHandler { + pub fn new(map: CancelMap, redis_client: Option>>) -> Self { + Self { map, redis_client } + } /// Cancel a running query for the corresponding connection. - pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> { + pub async fn cancel_session( + &self, + key: CancelKeyData, + session_id: Uuid, + ) -> Result<(), CancelError> { + let from = "from_client"; // NB: we should immediately release the lock after cloning the token. - let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else { + let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { tracing::warn!("query cancellation key not found: {key}"); + if let Some(redis_client) = &self.redis_client { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + info!("publishing cancellation key to Redis"); + match redis_client.lock().await.try_publish(key, session_id).await { + Ok(()) => { + info!("cancellation key successfuly published to Redis"); + } + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + return Err(CancelError::IO(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))); + } + } + } return Ok(()); }; - + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); info!("cancelling query per user's request using key {key}"); cancel_closure.try_cancel_query().await } @@ -57,7 +97,7 @@ impl CancelMap { // Random key collisions are unlikely to happen here, but they're still possible, // which is why we have to take care not to rewrite an existing key. - match self.0.entry(key) { + match self.map.entry(key) { dashmap::mapref::entry::Entry::Occupied(_) => continue, dashmap::mapref::entry::Entry::Vacant(e) => { e.insert(None); @@ -69,18 +109,46 @@ impl CancelMap { info!("registered new query cancellation key {key}"); Session { key, - cancel_map: self, + cancellation_handler: self, } } #[cfg(test)] fn contains(&self, session: &Session) -> bool { - self.0.contains_key(&session.key) + self.map.contains_key(&session.key) } #[cfg(test)] fn is_empty(&self) -> bool { - self.0.is_empty() + self.map.is_empty() + } +} + +#[async_trait] +pub trait NotificationsCancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>; +} + +#[async_trait] +impl NotificationsCancellationHandler for CancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> { + let from = "from_redis"; + let cancel_closure = self.map.get(&key).and_then(|x| x.clone()); + match cancel_closure { + Some(cancel_closure) => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); + cancel_closure.try_cancel_query().await + } + None => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + tracing::warn!("query cancellation key not found: {key}"); + Ok(()) + } + } } } @@ -115,7 +183,7 @@ pub struct Session { /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. - cancel_map: Arc, + cancellation_handler: Arc, } impl Session { @@ -123,7 +191,9 @@ impl Session { /// This enables query cancellation in `crate::proxy::prepare_client_connection`. pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); - self.cancel_map.0.insert(self.key, Some(cancel_closure)); + self.cancellation_handler + .map + .insert(self.key, Some(cancel_closure)); self.key } @@ -131,7 +201,7 @@ impl Session { impl Drop for Session { fn drop(&mut self) { - self.cancel_map.0.remove(&self.key); + self.cancellation_handler.map.remove(&self.key); info!("dropped query cancellation key {}", &self.key); } } @@ -142,13 +212,16 @@ mod tests { #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { - let cancel_map: Arc = Default::default(); + let cancellation_handler = Arc::new(CancellationHandler { + map: CancelMap::default(), + redis_client: None, + }); - let session = cancel_map.clone().get_session(); - assert!(cancel_map.contains(&session)); + let session = cancellation_handler.clone().get_session(); + assert!(cancellation_handler.contains(&session)); drop(session); // Check that the session has been dropped. - assert!(cancel_map.is_empty()); + assert!(cancellation_handler.is_empty()); Ok(()) } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 5fcb537834..9f276c3c24 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -21,6 +21,7 @@ pub struct ProxyConfig { pub require_client_ip: bool, pub disable_ip_check_for_http: bool, pub endpoint_rps_limit: Vec, + pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index f7f162a075..66031f5eb2 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy = Lazy::new(|| { .unwrap() }); +pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "proxy_cancellation_requests_total", + "Number of cancellation requests (per found/not_found).", + &["source", "kind"], + ) + .unwrap() +}); + #[derive(Clone)] pub struct LatencyTimer { // time since the stopwatch was started diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 5f65de4c98..ce77098a5f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -10,7 +10,7 @@ pub mod wake_compute; use crate::{ auth, - cancellation::{self, CancelMap}, + cancellation::{self, CancellationHandler}, compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, @@ -62,6 +62,7 @@ pub async fn task_main( listener: tokio::net::TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("proxy has shut down"); @@ -72,7 +73,6 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); - let cancel_map = Arc::new(CancelMap::default()); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -80,7 +80,7 @@ pub async fn task_main( let (socket, peer_addr) = accept_result?; let session_id = uuid::Uuid::new_v4(); - let cancel_map = Arc::clone(&cancel_map); + let cancellation_handler = Arc::clone(&cancellation_handler); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); let session_span = info_span!( @@ -113,7 +113,7 @@ pub async fn task_main( let res = handle_client( config, &mut ctx, - cancel_map, + cancellation_handler, socket, ClientMode::Tcp, endpoint_rate_limiter, @@ -227,7 +227,7 @@ impl ReportableError for ClientRequestError { pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - cancel_map: Arc, + cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, @@ -253,8 +253,8 @@ pub async fn handle_client( match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancel_map - .cancel_session(cancel_key_data) + return Ok(cancellation_handler + .cancel_session(cancel_key_data, ctx.session_id) .await .map(|()| None)?) } @@ -315,7 +315,7 @@ pub async fn handle_client( .or_else(|e| stream.throw_error(e)) .await?; - let session = cancel_map.get_session(); + let session = cancellation_handler.get_session(); prepare_client_connection(&node, &session, &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index b26386d159..f0da4ead23 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -4,4 +4,4 @@ mod limiter; pub use aimd::Aimd; pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig}; pub use limiter::Limiter; -pub use limiter::{EndpointRateLimiter, RateBucketInfo}; +pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter}; diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index cbae72711c..3181060e2f 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -22,6 +22,44 @@ use super::{ RateLimiterConfig, }; +pub struct RedisRateLimiter { + data: Vec, + info: &'static [RateBucketInfo], +} + +impl RedisRateLimiter { + pub fn new(info: &'static [RateBucketInfo]) -> Self { + Self { + data: vec![ + RateBucket { + start: Instant::now(), + count: 0, + }; + info.len() + ], + info, + } + } + + /// Check that number of connections is below `max_rps` rps. + pub fn check(&mut self) -> bool { + let now = Instant::now(); + + let should_allow_request = self + .data + .iter_mut() + .zip(self.info) + .all(|(bucket, info)| bucket.should_allow_request(info, now)); + + if should_allow_request { + // only increment the bucket counts if the request will actually be accepted + self.data.iter_mut().for_each(RateBucket::inc); + } + + should_allow_request + } +} + // Simple per-endpoint rate limiter. // // Check that number of connections to the endpoint is below `max_rps` rps. diff --git a/proxy/src/redis.rs b/proxy/src/redis.rs index c2a91bed97..35d6db074e 100644 --- a/proxy/src/redis.rs +++ b/proxy/src/redis.rs @@ -1 +1,2 @@ pub mod notifications; +pub mod publisher; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 158884aa17..b8297a206c 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -1,38 +1,44 @@ use std::{convert::Infallible, sync::Arc}; use futures::StreamExt; +use pq_proto::CancelKeyData; use redis::aio::PubSub; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::{ cache::project_info::ProjectInfoCache, + cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler}, intern::{ProjectIdInt, RoleNameInt}, }; -const CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); -struct ConsoleRedisClient { +struct RedisConsumerClient { client: redis::Client, } -impl ConsoleRedisClient { +impl RedisConsumerClient { pub fn new(url: &str) -> anyhow::Result { let client = redis::Client::open(url)?; Ok(Self { client }) } async fn try_connect(&self) -> anyhow::Result { let mut conn = self.client.get_async_connection().await?.into_pubsub(); - tracing::info!("subscribing to a channel `{CHANNEL_NAME}`"); - conn.subscribe(CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); + conn.subscribe(CPLANE_CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); + conn.subscribe(PROXY_CHANNEL_NAME).await?; Ok(conn) } } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(tag = "topic", content = "data")] -enum Notification { +pub(crate) enum Notification { #[serde( rename = "/allowed_ips_updated", deserialize_with = "deserialize_json_string" @@ -45,16 +51,25 @@ enum Notification { deserialize_with = "deserialize_json_string" )] PasswordUpdate { password_update: PasswordUpdate }, + #[serde(rename = "/cancel_session")] + Cancel(CancelSession), } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] -struct AllowedIpsUpdate { +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct AllowedIpsUpdate { project_id: ProjectIdInt, } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] -struct PasswordUpdate { +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct PasswordUpdate { project_id: ProjectIdInt, role_name: RoleNameInt, } +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct CancelSession { + pub region_id: Option, + pub cancel_key_data: CancelKeyData, + pub session_id: Uuid, +} + fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, @@ -64,6 +79,88 @@ where serde_json::from_str(&s).map_err(::custom) } +struct MessageHandler< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, +> { + cache: Arc, + cancellation_handler: Arc, + region_id: String, +} + +impl< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, + > MessageHandler +{ + pub fn new(cache: Arc, cancellation_handler: Arc, region_id: String) -> Self { + Self { + cache, + cancellation_handler, + region_id, + } + } + pub fn disable_ttl(&self) { + self.cache.disable_ttl(); + } + pub fn enable_ttl(&self) { + self.cache.enable_ttl(); + } + #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))] + async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> { + use Notification::*; + let payload: String = msg.get_payload()?; + tracing::debug!(?payload, "received a message payload"); + + let msg: Notification = match serde_json::from_str(&payload) { + Ok(msg) => msg, + Err(e) => { + tracing::error!("broken message: {e}"); + return Ok(()); + } + }; + tracing::debug!(?msg, "received a message"); + match msg { + Cancel(cancel_session) => { + tracing::Span::current().record( + "session_id", + &tracing::field::display(cancel_session.session_id), + ); + if let Some(cancel_region) = cancel_session.region_id { + // If the message is not for this region, ignore it. + if cancel_region != self.region_id { + return Ok(()); + } + } + // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. + match self + .cancellation_handler + .cancel_session_no_publish(cancel_session.cancel_key_data) + .await + { + Ok(()) => {} + Err(e) => { + tracing::error!("failed to cancel session: {e}"); + } + } + } + _ => { + invalidate_cache(self.cache.clone(), msg.clone()); + // It might happen that the invalid entry is on the way to be cached. + // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. + // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. + let cache = self.cache.clone(); + tokio::spawn(async move { + tokio::time::sleep(INVALIDATION_LAG).await; + invalidate_cache(cache, msg); + }); + } + } + + Ok(()) + } +} + fn invalidate_cache(cache: Arc, msg: Notification) { use Notification::*; match msg { @@ -74,50 +171,33 @@ fn invalidate_cache(cache: Arc, msg: Notification) { password_update.project_id, password_update.role_name, ), + Cancel(_) => unreachable!("cancel message should be handled separately"), } } -#[tracing::instrument(skip(cache))] -fn handle_message(msg: redis::Msg, cache: Arc) -> anyhow::Result<()> -where - C: ProjectInfoCache + Send + Sync + 'static, -{ - let payload: String = msg.get_payload()?; - tracing::debug!(?payload, "received a message payload"); - - let msg: Notification = match serde_json::from_str(&payload) { - Ok(msg) => msg, - Err(e) => { - tracing::error!("broken message: {e}"); - return Ok(()); - } - }; - tracing::debug!(?msg, "received a message"); - invalidate_cache(cache.clone(), msg.clone()); - // It might happen that the invalid entry is on the way to be cached. - // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. - // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. - tokio::spawn(async move { - tokio::time::sleep(INVALIDATION_LAG).await; - invalidate_cache(cache, msg.clone()); - }); - - Ok(()) -} - /// Handle console's invalidation messages. #[tracing::instrument(name = "console_notifications", skip_all)] -pub async fn task_main(url: String, cache: Arc) -> anyhow::Result +pub async fn task_main( + url: String, + cache: Arc, + cancel_map: CancelMap, + region_id: String, +) -> anyhow::Result where C: ProjectInfoCache + Send + Sync + 'static, { cache.enable_ttl(); + let handler = MessageHandler::new( + cache, + Arc::new(CancellationHandler::new(cancel_map, None)), + region_id, + ); loop { - let redis = ConsoleRedisClient::new(&url)?; + let redis = RedisConsumerClient::new(&url)?; let conn = match redis.try_connect().await { Ok(conn) => { - cache.disable_ttl(); + handler.disable_ttl(); conn } Err(e) => { @@ -130,7 +210,7 @@ where }; let mut stream = conn.into_on_message(); while let Some(msg) = stream.next().await { - match handle_message(msg, cache.clone()) { + match handler.handle_message(msg).await { Ok(()) => {} Err(e) => { tracing::error!("failed to handle message: {e}, will try to reconnect"); @@ -138,7 +218,7 @@ where } } } - cache.enable_ttl(); + handler.enable_ttl(); } } @@ -198,6 +278,33 @@ mod tests { } ); + Ok(()) + } + #[test] + fn parse_cancel_session() -> anyhow::Result<()> { + let cancel_key_data = CancelKeyData { + backend_pid: 42, + cancel_key: 41, + }; + let uuid = uuid::Uuid::new_v4(); + let msg = Notification::Cancel(CancelSession { + cancel_key_data, + region_id: None, + session_id: uuid, + }); + let text = serde_json::to_string(&msg)?; + let result: Notification = serde_json::from_str(&text)?; + assert_eq!(msg, result); + + let msg = Notification::Cancel(CancelSession { + cancel_key_data, + region_id: Some("region".to_string()), + session_id: uuid, + }); + let text = serde_json::to_string(&msg)?; + let result: Notification = serde_json::from_str(&text)?; + assert_eq!(msg, result,); + Ok(()) } } diff --git a/proxy/src/redis/publisher.rs b/proxy/src/redis/publisher.rs new file mode 100644 index 0000000000..f85593afdd --- /dev/null +++ b/proxy/src/redis/publisher.rs @@ -0,0 +1,80 @@ +use pq_proto::CancelKeyData; +use redis::AsyncCommands; +use uuid::Uuid; + +use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; + +use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}; + +pub struct RedisPublisherClient { + client: redis::Client, + publisher: Option, + region_id: String, + limiter: RedisRateLimiter, +} + +impl RedisPublisherClient { + pub fn new( + url: &str, + region_id: String, + info: &'static [RateBucketInfo], + ) -> anyhow::Result { + let client = redis::Client::open(url)?; + Ok(Self { + client, + publisher: None, + region_id, + limiter: RedisRateLimiter::new(info), + }) + } + pub async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping cancellation message"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + match self.publish(cancel_key_data, session_id).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + self.publisher = None; + } + } + tracing::info!("Publisher is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.publish(cancel_key_data, session_id).await + } + + async fn publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + let conn = self + .publisher + .as_mut() + .ok_or_else(|| anyhow::anyhow!("not connected"))?; + let payload = serde_json::to_string(&Notification::Cancel(CancelSession { + region_id: Some(self.region_id.clone()), + cancel_key_data, + session_id, + }))?; + conn.publish(PROXY_CHANNEL_NAME, payload).await?; + Ok(()) + } + pub async fn try_connect(&mut self) -> anyhow::Result<()> { + match self.client.get_async_connection().await { + Ok(conn) => { + self.publisher = Some(conn); + } + Err(e) => { + tracing::error!("failed to connect to redis: {e}"); + return Err(e.into()); + } + } + Ok(()) + } +} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a20600b94a..ee3e91495b 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; -use crate::{cancellation::CancelMap, config::ProxyConfig}; +use crate::{cancellation::CancellationHandler, config::ProxyConfig}; use futures::StreamExt; use hyper::{ server::{ @@ -50,6 +50,7 @@ pub async fn task_main( ws_listener: TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("websocket server has shut down"); @@ -115,7 +116,7 @@ pub async fn task_main( let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - + let cancellation_handler = cancellation_handler.clone(); async move { let peer_addr = match client_addr { Some(addr) => addr, @@ -127,9 +128,9 @@ pub async fn task_main( let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellation_handler = cancellation_handler.clone(); async move { - let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); request_handler( @@ -137,7 +138,7 @@ pub async fn task_main( config, backend, ws_connections, - cancel_map, + cancellation_handler, session_id, peer_addr.ip(), endpoint_rate_limiter, @@ -205,7 +206,7 @@ async fn request_handler( config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, - cancel_map: Arc, + cancellation_handler: Arc, session_id: uuid::Uuid, peer_addr: IpAddr, endpoint_rate_limiter: Arc, @@ -232,7 +233,7 @@ async fn request_handler( config, ctx, websocket, - cancel_map, + cancellation_handler, host, endpoint_rate_limiter, ) diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 062dd440b2..24f2bb7e8c 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -1,5 +1,5 @@ use crate::{ - cancellation::CancelMap, + cancellation::CancellationHandler, config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, @@ -133,7 +133,7 @@ pub async fn serve_websocket( config: &'static ProxyConfig, mut ctx: RequestMonitoring, websocket: HyperWebsocket, - cancel_map: Arc, + cancellation_handler: Arc, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { @@ -141,7 +141,7 @@ pub async fn serve_websocket( let res = handle_client( config, &mut ctx, - cancel_map, + cancellation_handler, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, endpoint_rate_limiter, diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 8e9cc43152..e808fabbe7 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -38,7 +38,7 @@ futures-io = { version = "0.3" } futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } @@ -91,7 +91,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } libc = { version = "0.2", features = ["extra_traits", "use_std"] } From 7fa732c96c6382fd0468991b40f922348e653d3c Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 13 Feb 2024 18:46:25 +0100 Subject: [PATCH 59/81] refactor(virtual_file): take owned buffer in VirtualFile::write_all (#6664) Building atop #6660 , this PR converts VirtualFile::write_all to owned buffers. Part of https://github.com/neondatabase/neon/issues/6663 --- pageserver/src/deletion_queue.rs | 4 +- pageserver/src/tenant.rs | 4 +- pageserver/src/tenant/blob_io.rs | 26 ++++---- pageserver/src/tenant/metadata.rs | 2 +- pageserver/src/tenant/secondary/downloader.rs | 2 +- .../src/tenant/storage_layer/delta_layer.rs | 30 +++------ .../src/tenant/storage_layer/image_layer.rs | 30 +++------ pageserver/src/virtual_file.rs | 66 ++++++++++++------- 8 files changed, 81 insertions(+), 83 deletions(-) diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index da1da9331a..9046fe881b 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -234,7 +234,7 @@ impl DeletionHeader { let header_bytes = serde_json::to_vec(self).context("serialize deletion header")?; let header_path = conf.deletion_header_path(); let temp_path = path_with_suffix_extension(&header_path, TEMP_SUFFIX); - VirtualFile::crashsafe_overwrite(&header_path, &temp_path, &header_bytes) + VirtualFile::crashsafe_overwrite(&header_path, &temp_path, header_bytes) .await .maybe_fatal_err("save deletion header")?; @@ -325,7 +325,7 @@ impl DeletionList { let temp_path = path_with_suffix_extension(&path, TEMP_SUFFIX); let bytes = serde_json::to_vec(self).expect("Failed to serialize deletion list"); - VirtualFile::crashsafe_overwrite(&path, &temp_path, &bytes) + VirtualFile::crashsafe_overwrite(&path, &temp_path, bytes) .await .maybe_fatal_err("save deletion list") .map_err(Into::into) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index d946c57118..9f1f188bf2 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -2880,7 +2880,7 @@ impl Tenant { let config_path = config_path.to_owned(); tokio::task::spawn_blocking(move || { Handle::current().block_on(async move { - let conf_content = conf_content.as_bytes(); + let conf_content = conf_content.into_bytes(); VirtualFile::crashsafe_overwrite(&config_path, &temp_path, conf_content) .await .with_context(|| { @@ -2917,7 +2917,7 @@ impl Tenant { let target_config_path = target_config_path.to_owned(); tokio::task::spawn_blocking(move || { Handle::current().block_on(async move { - let conf_content = conf_content.as_bytes(); + let conf_content = conf_content.into_bytes(); VirtualFile::crashsafe_overwrite(&target_config_path, &temp_path, conf_content) .await .with_context(|| { diff --git a/pageserver/src/tenant/blob_io.rs b/pageserver/src/tenant/blob_io.rs index e2ff12665a..ec70bdc679 100644 --- a/pageserver/src/tenant/blob_io.rs +++ b/pageserver/src/tenant/blob_io.rs @@ -131,27 +131,23 @@ impl BlobWriter { &mut self, src_buf: B, ) -> (B::Buf, Result<(), Error>) { - let src_buf_len = src_buf.bytes_init(); - let (src_buf, res) = if src_buf_len > 0 { - let src_buf = src_buf.slice(0..src_buf_len); - let res = self.inner.write_all(&src_buf).await; - let src_buf = Slice::into_inner(src_buf); - (src_buf, res) - } else { - let res = self.inner.write_all(&[]).await; - (Slice::into_inner(src_buf.slice_full()), res) + let (src_buf, res) = self.inner.write_all(src_buf).await; + let nbytes = match res { + Ok(nbytes) => nbytes, + Err(e) => return (src_buf, Err(e)), }; - if let Ok(()) = &res { - self.offset += src_buf_len as u64; - } - (src_buf, res) + self.offset += nbytes as u64; + (src_buf, Ok(())) } #[inline(always)] /// Flushes the internal buffer to the underlying `VirtualFile`. pub async fn flush_buffer(&mut self) -> Result<(), Error> { - self.inner.write_all(&self.buf).await?; - self.buf.clear(); + let buf = std::mem::take(&mut self.buf); + let (mut buf, res) = self.inner.write_all(buf).await; + res?; + buf.clear(); + self.buf = buf; Ok(()) } diff --git a/pageserver/src/tenant/metadata.rs b/pageserver/src/tenant/metadata.rs index 6fb86c65e2..dcbe781f90 100644 --- a/pageserver/src/tenant/metadata.rs +++ b/pageserver/src/tenant/metadata.rs @@ -279,7 +279,7 @@ pub async fn save_metadata( let path = conf.metadata_path(tenant_shard_id, timeline_id); let temp_path = path_with_suffix_extension(&path, TEMP_FILE_SUFFIX); let metadata_bytes = data.to_bytes().context("serialize metadata")?; - VirtualFile::crashsafe_overwrite(&path, &temp_path, &metadata_bytes) + VirtualFile::crashsafe_overwrite(&path, &temp_path, metadata_bytes) .await .context("write metadata")?; Ok(()) diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 0666e104f8..c23416a7f0 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -486,7 +486,7 @@ impl<'a> TenantDownloader<'a> { let heatmap_path_bg = heatmap_path.clone(); tokio::task::spawn_blocking(move || { tokio::runtime::Handle::current().block_on(async move { - VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, &heatmap_bytes).await + VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, heatmap_bytes).await }) }) .await diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index 7a5dc7a59f..9a7bcbcebe 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -461,7 +461,8 @@ impl DeltaLayerWriterInner { file.seek(SeekFrom::Start(index_start_blk as u64 * PAGE_SZ as u64)) .await?; for buf in block_buf.blocks { - file.write_all(buf.as_ref()).await?; + let (_buf, res) = file.write_all(buf).await; + res?; } assert!(self.lsn_range.start < self.lsn_range.end); // Fill in the summary on blk 0 @@ -476,17 +477,12 @@ impl DeltaLayerWriterInner { index_root_blk, }; - let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new(); + let mut buf = Vec::with_capacity(PAGE_SZ); + // TODO: could use smallvec here but it's a pain with Slice Summary::ser_into(&summary, &mut buf)?; - if buf.spilled() { - // This is bad as we only have one free block for the summary - warn!( - "Used more than one page size for summary buffer: {}", - buf.len() - ); - } file.seek(SeekFrom::Start(0)).await?; - file.write_all(&buf).await?; + let (_buf, res) = file.write_all(buf).await; + res?; let metadata = file .metadata() @@ -679,18 +675,12 @@ impl DeltaLayer { let new_summary = rewrite(actual_summary); - let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new(); + let mut buf = Vec::with_capacity(PAGE_SZ); + // TODO: could use smallvec here, but it's a pain with Slice Summary::ser_into(&new_summary, &mut buf).context("serialize")?; - if buf.spilled() { - // The code in DeltaLayerWriterInner just warn!()s for this. - // It should probably error out as well. - return Err(RewriteSummaryError::Other(anyhow::anyhow!( - "Used more than one page size for summary buffer: {}", - buf.len() - ))); - } file.seek(SeekFrom::Start(0)).await?; - file.write_all(&buf).await?; + let (_buf, res) = file.write_all(buf).await; + res?; Ok(()) } } diff --git a/pageserver/src/tenant/storage_layer/image_layer.rs b/pageserver/src/tenant/storage_layer/image_layer.rs index 1ad195032d..458131b572 100644 --- a/pageserver/src/tenant/storage_layer/image_layer.rs +++ b/pageserver/src/tenant/storage_layer/image_layer.rs @@ -341,18 +341,12 @@ impl ImageLayer { let new_summary = rewrite(actual_summary); - let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new(); + let mut buf = Vec::with_capacity(PAGE_SZ); + // TODO: could use smallvec here but it's a pain with Slice Summary::ser_into(&new_summary, &mut buf).context("serialize")?; - if buf.spilled() { - // The code in ImageLayerWriterInner just warn!()s for this. - // It should probably error out as well. - return Err(RewriteSummaryError::Other(anyhow::anyhow!( - "Used more than one page size for summary buffer: {}", - buf.len() - ))); - } file.seek(SeekFrom::Start(0)).await?; - file.write_all(&buf).await?; + let (_buf, res) = file.write_all(buf).await; + res?; Ok(()) } } @@ -555,7 +549,8 @@ impl ImageLayerWriterInner { .await?; let (index_root_blk, block_buf) = self.tree.finish()?; for buf in block_buf.blocks { - file.write_all(buf.as_ref()).await?; + let (_buf, res) = file.write_all(buf).await; + res?; } // Fill in the summary on blk 0 @@ -570,17 +565,12 @@ impl ImageLayerWriterInner { index_root_blk, }; - let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new(); + let mut buf = Vec::with_capacity(PAGE_SZ); + // TODO: could use smallvec here but it's a pain with Slice Summary::ser_into(&summary, &mut buf)?; - if buf.spilled() { - // This is bad as we only have one free block for the summary - warn!( - "Used more than one page size for summary buffer: {}", - buf.len() - ); - } file.seek(SeekFrom::Start(0)).await?; - file.write_all(&buf).await?; + let (_buf, res) = file.write_all(buf).await; + res?; let metadata = file .metadata() diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 059a6596d3..6cff748d42 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -19,7 +19,7 @@ use once_cell::sync::OnceCell; use pageserver_api::shard::TenantShardId; use std::fs::{self, File}; use std::io::{Error, ErrorKind, Seek, SeekFrom}; -use tokio_epoll_uring::IoBufMut; +use tokio_epoll_uring::{BoundedBuf, IoBufMut, Slice}; use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; use std::os::unix::fs::FileExt; @@ -410,10 +410,10 @@ impl VirtualFile { /// step, the tmp path is renamed to the final path. As renames are /// atomic, a crash during the write operation will never leave behind a /// partially written file. - pub async fn crashsafe_overwrite( + pub async fn crashsafe_overwrite( final_path: &Utf8Path, tmp_path: &Utf8Path, - content: &[u8], + content: B, ) -> std::io::Result<()> { let Some(final_path_parent) = final_path.parent() else { return Err(std::io::Error::from_raw_os_error( @@ -430,7 +430,8 @@ impl VirtualFile { .create_new(true), ) .await?; - file.write_all(content).await?; + let (_content, res) = file.write_all(content).await; + res?; file.sync_all().await?; drop(file); // before the rename, that's important! // renames are atomic @@ -601,23 +602,36 @@ impl VirtualFile { Ok(()) } - pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), Error> { + /// Writes `buf.slice(0..buf.bytes_init())`. + /// Returns the IoBuf that is underlying the BoundedBuf `buf`. + /// I.e., the returned value's `bytes_init()` method returns something different than the `bytes_init()` that was passed in. + /// It's quite brittle and easy to mis-use, so, we return the size in the Ok() variant. + pub async fn write_all(&mut self, buf: B) -> (B::Buf, Result) { + let nbytes = buf.bytes_init(); + if nbytes == 0 { + return (Slice::into_inner(buf.slice_full()), Ok(0)); + } + let mut buf = buf.slice(0..nbytes); while !buf.is_empty() { - match self.write(buf).await { + // TODO: push `Slice` further down + match self.write(&buf).await { Ok(0) => { - return Err(Error::new( - std::io::ErrorKind::WriteZero, - "failed to write whole buffer", - )); + return ( + Slice::into_inner(buf), + Err(Error::new( + std::io::ErrorKind::WriteZero, + "failed to write whole buffer", + )), + ); } Ok(n) => { - buf = &buf[n..]; + buf = buf.slice(n..); } Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {} - Err(e) => return Err(e), + Err(e) => return (Slice::into_inner(buf), Err(e)), } } - Ok(()) + (Slice::into_inner(buf), Ok(nbytes)) } async fn write(&mut self, buf: &[u8]) -> Result { @@ -676,7 +690,6 @@ where F: FnMut(tokio_epoll_uring::Slice, u64) -> Fut, Fut: std::future::Future, std::io::Result)>, { - use tokio_epoll_uring::BoundedBuf; let mut buf: tokio_epoll_uring::Slice = buf.slice_full(); // includes all the uninitialized memory while buf.bytes_total() != 0 { let res; @@ -1063,10 +1076,19 @@ mod tests { MaybeVirtualFile::File(file) => file.seek(pos), } } - async fn write_all(&mut self, buf: &[u8]) -> Result<(), Error> { + async fn write_all(&mut self, buf: B) -> Result<(), Error> { match self { - MaybeVirtualFile::VirtualFile(file) => file.write_all(buf).await, - MaybeVirtualFile::File(file) => file.write_all(buf), + MaybeVirtualFile::VirtualFile(file) => { + let (_buf, res) = file.write_all(buf).await; + res.map(|_| ()) + } + MaybeVirtualFile::File(file) => { + let buf_len = buf.bytes_init(); + if buf_len == 0 { + return Ok(()); + } + file.write_all(&buf.slice(0..buf_len)) + } } } @@ -1141,7 +1163,7 @@ mod tests { .to_owned(), ) .await?; - file_a.write_all(b"foobar").await?; + file_a.write_all(b"foobar".to_vec()).await?; // cannot read from a file opened in write-only mode let _ = file_a.read_string().await.unwrap_err(); @@ -1150,7 +1172,7 @@ mod tests { let mut file_a = openfunc(path_a, OpenOptions::new().read(true).to_owned()).await?; // cannot write to a file opened in read-only mode - let _ = file_a.write_all(b"bar").await.unwrap_err(); + let _ = file_a.write_all(b"bar".to_vec()).await.unwrap_err(); // Try simple read assert_eq!("foobar", file_a.read_string().await?); @@ -1293,7 +1315,7 @@ mod tests { let path = testdir.join("myfile"); let tmp_path = testdir.join("myfile.tmp"); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo") + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap()); @@ -1302,7 +1324,7 @@ mod tests { assert!(!tmp_path.exists()); drop(file); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar") + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap()); @@ -1324,7 +1346,7 @@ mod tests { std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap(); assert!(tmp_path.exists()); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo") + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec()) .await .unwrap(); From b6e070bf85c6f4fa204d36ae2d761db30b47d277 Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Tue, 13 Feb 2024 20:41:17 +0200 Subject: [PATCH 60/81] Do not perform fast exit for catalog pages in redo filter (#6730) ## Problem See https://github.com/neondatabase/neon/issues/6674 Current implementation of `neon_redo_read_buffer_filter` performs fast exist for catalog pages: ``` /* * Out of an abundance of caution, we always run redo on shared catalogs, * regardless of whether the block is stored in shared buffers. See also * this function's top comment. */ if (!OidIsValid(NInfoGetDbOid(rinfo))) return false; */ as a result last written lsn and relation size for FSM fork are not correctly updated for catalog relations. ## Summary of changes Do not perform fast path return for catalog relations. ## Checklist before requesting a review - [ ] I have performed a self-review of my code. - [ ] If it is a core feature, I have added thorough tests. - [ ] Do we need to implement analytics? if so did you add the relevant metrics to the dashboard? - [ ] If this PR requires public announcement, mark it with /release-notes label and add several sentences in this section. ## Checklist before merging - [ ] Do not forget to reformat commit message to not include the above checklist Co-authored-by: Konstantin Knizhnik --- pgxn/neon/pagestore_smgr.c | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index 63e8b8dc1f..213e396328 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -3079,14 +3079,6 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id) XLogRecGetBlockTag(record, block_id, &rinfo, &forknum, &blkno); #endif - /* - * Out of an abundance of caution, we always run redo on shared catalogs, - * regardless of whether the block is stored in shared buffers. See also - * this function's top comment. - */ - if (!OidIsValid(NInfoGetDbOid(rinfo))) - return false; - CopyNRelFileInfoToBufTag(tag, rinfo); tag.forkNum = forknum; tag.blockNum = blkno; @@ -3100,17 +3092,28 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id) */ LWLockAcquire(partitionLock, LW_SHARED); - /* Try to find the relevant buffer */ - buffer = BufTableLookup(&tag, hash); - - no_redo_needed = buffer < 0; + /* + * Out of an abundance of caution, we always run redo on shared catalogs, + * regardless of whether the block is stored in shared buffers. See also + * this function's top comment. + */ + if (!OidIsValid(NInfoGetDbOid(rinfo))) + { + no_redo_needed = false; + } + else + { + /* Try to find the relevant buffer */ + buffer = BufTableLookup(&tag, hash); + no_redo_needed = buffer < 0; + } /* In both cases st lwlsn past this WAL record */ SetLastWrittenLSNForBlock(end_recptr, rinfo, forknum, blkno); /* * we don't have the buffer in memory, update lwLsn past this record, also - * evict page fro file cache + * evict page from file cache */ if (no_redo_needed) lfc_evict(rinfo, forknum, blkno); From ee7bbdda0e58af4350a6886544cd75f3cc1b2de9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Wed, 14 Feb 2024 02:12:00 +0100 Subject: [PATCH 61/81] Create new metric for directory counts (#6736) There is O(n^2) issues due to how we store these directories (#6626), so it's good to keep an eye on them and ensure the numbers stay low. The new per-timeline metric `pageserver_directory_entries_count` isn't perfect, namely we don't calculate it every time we attach the timeline, but only if there is an actual change. Also, it is a collective metric over multiple scalars. Lastly, we only emit the metric if it is above a certain threshold. However, the metric still give a feel for the general size of the timeline. We care less for small values as the metric is mainly there to detect and track tenants with large directory counts. We also expose the directory counts in `TimelineInfo` so that one can get the detailed size distribution directly via the pageserver's API. Related: #6642 , https://github.com/neondatabase/cloud/issues/10273 --- libs/pageserver_api/src/models.rs | 2 + libs/pageserver_api/src/reltag.rs | 1 + pageserver/src/http/routes.rs | 1 + pageserver/src/metrics.rs | 34 +++++++++++++++- pageserver/src/pgdatadir_mapping.rs | 62 +++++++++++++++++++++++++++++ pageserver/src/tenant/timeline.rs | 39 +++++++++++++++++- test_runner/fixtures/metrics.py | 1 + 7 files changed, 137 insertions(+), 3 deletions(-) diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 46324efd43..1226eaa312 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -494,6 +494,8 @@ pub struct TimelineInfo { pub current_logical_size: u64, pub current_logical_size_is_accurate: bool, + pub directory_entries_counts: Vec, + /// Sum of the size of all layer files. /// If a layer is present in both local FS and S3, it counts only once. pub current_physical_size: Option, // is None when timeline is Unloaded diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index 8eb848a514..38693ab847 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -124,6 +124,7 @@ impl RelTag { Ord, strum_macros::EnumIter, strum_macros::FromRepr, + enum_map::Enum, )] #[repr(u8)] pub enum SlruKind { diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 4be8ee9892..c354cc9ab6 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -422,6 +422,7 @@ async fn build_timeline_info_common( tenant::timeline::logical_size::Accuracy::Approximate => false, tenant::timeline::logical_size::Accuracy::Exact => true, }, + directory_entries_counts: timeline.get_directory_metrics().to_vec(), current_physical_size, current_logical_size_non_incremental: None, timeline_dir_layer_file_size_sum: None, diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 98c98ef6e7..c2b1eafc3a 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -602,6 +602,15 @@ pub(crate) mod initial_logical_size { }); } +static DIRECTORY_ENTRIES_COUNT: Lazy = Lazy::new(|| { + register_uint_gauge_vec!( + "pageserver_directory_entries_count", + "Sum of the entries in pageserver-stored directory listings", + &["tenant_id", "shard_id", "timeline_id"] + ) + .expect("failed to define a metric") +}); + pub(crate) static TENANT_STATE_METRIC: Lazy = Lazy::new(|| { register_uint_gauge_vec!( "pageserver_tenant_states_count", @@ -1809,6 +1818,7 @@ pub(crate) struct TimelineMetrics { resident_physical_size_gauge: UIntGauge, /// copy of LayeredTimeline.current_logical_size pub current_logical_size_gauge: UIntGauge, + pub directory_entries_count_gauge: Lazy UIntGauge>>, pub num_persistent_files_created: IntCounter, pub persistent_bytes_written: IntCounter, pub evictions: IntCounter, @@ -1818,12 +1828,12 @@ pub(crate) struct TimelineMetrics { impl TimelineMetrics { pub fn new( tenant_shard_id: &TenantShardId, - timeline_id: &TimelineId, + timeline_id_raw: &TimelineId, evictions_with_low_residence_duration_builder: EvictionsWithLowResidenceDurationBuilder, ) -> Self { let tenant_id = tenant_shard_id.tenant_id.to_string(); let shard_id = format!("{}", tenant_shard_id.shard_slug()); - let timeline_id = timeline_id.to_string(); + let timeline_id = timeline_id_raw.to_string(); let flush_time_histo = StorageTimeMetrics::new( StorageTimeOperation::LayerFlush, &tenant_id, @@ -1876,6 +1886,22 @@ impl TimelineMetrics { let current_logical_size_gauge = CURRENT_LOGICAL_SIZE .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id]) .unwrap(); + // TODO use impl Trait syntax here once we have ability to use it: https://github.com/rust-lang/rust/issues/63065 + let directory_entries_count_gauge_closure = { + let tenant_shard_id = *tenant_shard_id; + let timeline_id_raw = *timeline_id_raw; + move || { + let tenant_id = tenant_shard_id.tenant_id.to_string(); + let shard_id = format!("{}", tenant_shard_id.shard_slug()); + let timeline_id = timeline_id_raw.to_string(); + let gauge: UIntGauge = DIRECTORY_ENTRIES_COUNT + .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id]) + .unwrap(); + gauge + } + }; + let directory_entries_count_gauge: Lazy UIntGauge>> = + Lazy::new(Box::new(directory_entries_count_gauge_closure)); let num_persistent_files_created = NUM_PERSISTENT_FILES_CREATED .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id]) .unwrap(); @@ -1902,6 +1928,7 @@ impl TimelineMetrics { last_record_gauge, resident_physical_size_gauge, current_logical_size_gauge, + directory_entries_count_gauge, num_persistent_files_created, persistent_bytes_written, evictions, @@ -1944,6 +1971,9 @@ impl Drop for TimelineMetrics { RESIDENT_PHYSICAL_SIZE.remove_label_values(&[tenant_id, &shard_id, timeline_id]); } let _ = CURRENT_LOGICAL_SIZE.remove_label_values(&[tenant_id, &shard_id, timeline_id]); + if let Some(metric) = Lazy::get(&DIRECTORY_ENTRIES_COUNT) { + let _ = metric.remove_label_values(&[tenant_id, &shard_id, timeline_id]); + } let _ = NUM_PERSISTENT_FILES_CREATED.remove_label_values(&[tenant_id, &shard_id, timeline_id]); let _ = PERSISTENT_BYTES_WRITTEN.remove_label_values(&[tenant_id, &shard_id, timeline_id]); diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index f1d18c0146..5f80ea9b5e 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -14,6 +14,7 @@ use crate::span::debug_assert_current_span_has_tenant_and_timeline_id_no_shard_i use crate::walrecord::NeonWalRecord; use anyhow::{ensure, Context}; use bytes::{Buf, Bytes, BytesMut}; +use enum_map::Enum; use pageserver_api::key::{ dbdir_key_range, is_rel_block_key, is_slru_block_key, rel_block_to_key, rel_dir_to_key, rel_key_range, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key, @@ -155,6 +156,7 @@ impl Timeline { pending_updates: HashMap::new(), pending_deletions: Vec::new(), pending_nblocks: 0, + pending_directory_entries: Vec::new(), lsn, } } @@ -868,6 +870,7 @@ pub struct DatadirModification<'a> { pending_updates: HashMap>, pending_deletions: Vec<(Range, Lsn)>, pending_nblocks: i64, + pending_directory_entries: Vec<(DirectoryKind, usize)>, } impl<'a> DatadirModification<'a> { @@ -899,6 +902,7 @@ impl<'a> DatadirModification<'a> { let buf = DbDirectory::ser(&DbDirectory { dbdirs: HashMap::new(), })?; + self.pending_directory_entries.push((DirectoryKind::Db, 0)); self.put(DBDIR_KEY, Value::Image(buf.into())); // Create AuxFilesDirectory @@ -907,16 +911,24 @@ impl<'a> DatadirModification<'a> { let buf = TwoPhaseDirectory::ser(&TwoPhaseDirectory { xids: HashSet::new(), })?; + self.pending_directory_entries + .push((DirectoryKind::TwoPhase, 0)); self.put(TWOPHASEDIR_KEY, Value::Image(buf.into())); let buf: Bytes = SlruSegmentDirectory::ser(&SlruSegmentDirectory::default())?.into(); let empty_dir = Value::Image(buf); self.put(slru_dir_to_key(SlruKind::Clog), empty_dir.clone()); + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(SlruKind::Clog), 0)); self.put( slru_dir_to_key(SlruKind::MultiXactMembers), empty_dir.clone(), ); + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(SlruKind::Clog), 0)); self.put(slru_dir_to_key(SlruKind::MultiXactOffsets), empty_dir); + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(SlruKind::MultiXactOffsets), 0)); Ok(()) } @@ -1017,6 +1029,7 @@ impl<'a> DatadirModification<'a> { let buf = RelDirectory::ser(&RelDirectory { rels: HashSet::new(), })?; + self.pending_directory_entries.push((DirectoryKind::Rel, 0)); self.put( rel_dir_to_key(spcnode, dbnode), Value::Image(Bytes::from(buf)), @@ -1039,6 +1052,8 @@ impl<'a> DatadirModification<'a> { if !dir.xids.insert(xid) { anyhow::bail!("twophase file for xid {} already exists", xid); } + self.pending_directory_entries + .push((DirectoryKind::TwoPhase, dir.xids.len())); self.put( TWOPHASEDIR_KEY, Value::Image(Bytes::from(TwoPhaseDirectory::ser(&dir)?)), @@ -1074,6 +1089,8 @@ impl<'a> DatadirModification<'a> { let mut dir = DbDirectory::des(&buf)?; if dir.dbdirs.remove(&(spcnode, dbnode)).is_some() { let buf = DbDirectory::ser(&dir)?; + self.pending_directory_entries + .push((DirectoryKind::Db, dir.dbdirs.len())); self.put(DBDIR_KEY, Value::Image(buf.into())); } else { warn!( @@ -1111,6 +1128,8 @@ impl<'a> DatadirModification<'a> { // Didn't exist. Update dbdir dbdir.dbdirs.insert((rel.spcnode, rel.dbnode), false); let buf = DbDirectory::ser(&dbdir).context("serialize db")?; + self.pending_directory_entries + .push((DirectoryKind::Db, dbdir.dbdirs.len())); self.put(DBDIR_KEY, Value::Image(buf.into())); // and create the RelDirectory @@ -1125,6 +1144,10 @@ impl<'a> DatadirModification<'a> { if !rel_dir.rels.insert((rel.relnode, rel.forknum)) { return Err(RelationError::AlreadyExists); } + + self.pending_directory_entries + .push((DirectoryKind::Rel, rel_dir.rels.len())); + self.put( rel_dir_key, Value::Image(Bytes::from( @@ -1216,6 +1239,9 @@ impl<'a> DatadirModification<'a> { let buf = self.get(dir_key, ctx).await?; let mut dir = RelDirectory::des(&buf)?; + self.pending_directory_entries + .push((DirectoryKind::Rel, dir.rels.len())); + if dir.rels.remove(&(rel.relnode, rel.forknum)) { self.put(dir_key, Value::Image(Bytes::from(RelDirectory::ser(&dir)?))); } else { @@ -1251,6 +1277,8 @@ impl<'a> DatadirModification<'a> { if !dir.segments.insert(segno) { anyhow::bail!("slru segment {kind:?}/{segno} already exists"); } + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(kind), dir.segments.len())); self.put( dir_key, Value::Image(Bytes::from(SlruSegmentDirectory::ser(&dir)?)), @@ -1295,6 +1323,8 @@ impl<'a> DatadirModification<'a> { if !dir.segments.remove(&segno) { warn!("slru segment {:?}/{} does not exist", kind, segno); } + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(kind), dir.segments.len())); self.put( dir_key, Value::Image(Bytes::from(SlruSegmentDirectory::ser(&dir)?)), @@ -1325,6 +1355,8 @@ impl<'a> DatadirModification<'a> { if !dir.xids.remove(&xid) { warn!("twophase file for xid {} does not exist", xid); } + self.pending_directory_entries + .push((DirectoryKind::TwoPhase, dir.xids.len())); self.put( TWOPHASEDIR_KEY, Value::Image(Bytes::from(TwoPhaseDirectory::ser(&dir)?)), @@ -1340,6 +1372,8 @@ impl<'a> DatadirModification<'a> { let buf = AuxFilesDirectory::ser(&AuxFilesDirectory { files: HashMap::new(), })?; + self.pending_directory_entries + .push((DirectoryKind::AuxFiles, 0)); self.put(AUX_FILES_KEY, Value::Image(Bytes::from(buf))); Ok(()) } @@ -1366,6 +1400,9 @@ impl<'a> DatadirModification<'a> { } else { dir.files.insert(path, Bytes::copy_from_slice(content)); } + self.pending_directory_entries + .push((DirectoryKind::AuxFiles, dir.files.len())); + self.put( AUX_FILES_KEY, Value::Image(Bytes::from( @@ -1427,6 +1464,10 @@ impl<'a> DatadirModification<'a> { self.pending_nblocks = 0; } + for (kind, count) in std::mem::take(&mut self.pending_directory_entries) { + writer.update_directory_entries_count(kind, count as u64); + } + Ok(()) } @@ -1464,6 +1505,10 @@ impl<'a> DatadirModification<'a> { writer.update_current_logical_size(pending_nblocks * i64::from(BLCKSZ)); } + for (kind, count) in std::mem::take(&mut self.pending_directory_entries) { + writer.update_directory_entries_count(kind, count as u64); + } + Ok(()) } @@ -1588,6 +1633,23 @@ struct SlruSegmentDirectory { segments: HashSet, } +#[derive(Copy, Clone, PartialEq, Eq, Debug, enum_map::Enum)] +#[repr(u8)] +pub(crate) enum DirectoryKind { + Db, + TwoPhase, + Rel, + AuxFiles, + SlruSegment(SlruKind), +} + +impl DirectoryKind { + pub(crate) const KINDS_NUM: usize = ::LENGTH; + pub(crate) fn offset(&self) -> usize { + self.into_usize() + } +} + static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]); #[allow(clippy::bool_assert_comparison)] diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 625be7a644..87cf0ac6ea 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -14,6 +14,7 @@ use enumset::EnumSet; use fail::fail_point; use futures::stream::StreamExt; use itertools::Itertools; +use once_cell::sync::Lazy; use pageserver_api::{ keyspace::{key_range_size, KeySpaceAccum}, models::{ @@ -34,17 +35,22 @@ use tokio_util::sync::CancellationToken; use tracing::*; use utils::sync::gate::Gate; -use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet}; use std::ops::{Deref, Range}; use std::pin::pin; use std::sync::atomic::Ordering as AtomicOrdering; use std::sync::{Arc, Mutex, RwLock, Weak}; use std::time::{Duration, Instant, SystemTime}; +use std::{ + array, + collections::{BTreeMap, BinaryHeap, HashMap, HashSet}, + sync::atomic::AtomicU64, +}; use std::{ cmp::{max, min, Ordering}, ops::ControlFlow, }; +use crate::pgdatadir_mapping::DirectoryKind; use crate::tenant::timeline::logical_size::CurrentLogicalSize; use crate::tenant::{ layer_map::{LayerMap, SearchResult}, @@ -258,6 +264,8 @@ pub struct Timeline { // in `crate::page_service` writes these metrics. pub(crate) query_metrics: crate::metrics::SmgrQueryTimePerTimeline, + directory_metrics: [AtomicU64; DirectoryKind::KINDS_NUM], + /// Ensures layers aren't frozen by checkpointer between /// [`Timeline::get_layer_for_write`] and layer reads. /// Locked automatically by [`TimelineWriter`] and checkpointer. @@ -790,6 +798,10 @@ impl Timeline { self.metrics.resident_physical_size_get() } + pub(crate) fn get_directory_metrics(&self) -> [u64; DirectoryKind::KINDS_NUM] { + array::from_fn(|idx| self.directory_metrics[idx].load(AtomicOrdering::Relaxed)) + } + /// /// Wait until WAL has been received and processed up to this LSN. /// @@ -1496,6 +1508,8 @@ impl Timeline { &timeline_id, ), + directory_metrics: array::from_fn(|_| AtomicU64::new(0)), + flush_loop_state: Mutex::new(FlushLoopState::NotStarted), layer_flush_start_tx, @@ -2264,6 +2278,29 @@ impl Timeline { } } + pub(crate) fn update_directory_entries_count(&self, kind: DirectoryKind, count: u64) { + self.directory_metrics[kind.offset()].store(count, AtomicOrdering::Relaxed); + let aux_metric = + self.directory_metrics[DirectoryKind::AuxFiles.offset()].load(AtomicOrdering::Relaxed); + + let sum_of_entries = self + .directory_metrics + .iter() + .map(|v| v.load(AtomicOrdering::Relaxed)) + .sum(); + // Set a high general threshold and a lower threshold for the auxiliary files, + // as we can have large numbers of relations in the db directory. + const SUM_THRESHOLD: u64 = 5000; + const AUX_THRESHOLD: u64 = 1000; + if sum_of_entries >= SUM_THRESHOLD || aux_metric >= AUX_THRESHOLD { + self.metrics + .directory_entries_count_gauge + .set(sum_of_entries); + } else if let Some(metric) = Lazy::get(&self.metrics.directory_entries_count_gauge) { + metric.set(sum_of_entries); + } + } + async fn find_layer(&self, layer_file_name: &str) -> Option { let guard = self.layers.read().await; for historic_layer in guard.layer_map().iter_historic_layers() { diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index ef41774289..418370c3ab 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -96,5 +96,6 @@ PAGESERVER_PER_TENANT_METRICS: Tuple[str, ...] = ( "pageserver_evictions_total", "pageserver_evictions_with_low_residence_duration_total", *PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS, + # "pageserver_directory_entries_count", -- only used if above a certain threshold # "pageserver_broken_tenants_count" -- used only for broken ) From a5114a99b275b52fc7a512e62a7f80a5a103433d Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 14 Feb 2024 10:34:58 +0200 Subject: [PATCH 62/81] Create a symlink from pg_dynshmem to /dev/shm See included comment and issue https://github.com/neondatabase/autoscaling/issues/800 for details. This has no effect, unless you set "dynamic_shared_memory_type = mmap" in postgresql.conf. --- compute_tools/src/compute.rs | 44 +++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 993b5725a4..83db8e09ec 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::env; use std::fs; use std::io::BufRead; -use std::os::unix::fs::PermissionsExt; +use std::os::unix::fs::{symlink, PermissionsExt}; use std::path::Path; use std::process::{Command, Stdio}; use std::str::FromStr; @@ -634,6 +634,48 @@ impl ComputeNode { // Update pg_hba.conf received with basebackup. update_pg_hba(pgdata_path)?; + // Place pg_dynshmem under /dev/shm. This allows us to use + // 'dynamic_shared_memory_type = mmap' so that the files are placed in + // /dev/shm, similar to how 'dynamic_shared_memory_type = posix' works. + // + // Why on earth don't we just stick to the 'posix' default, you might + // ask. It turns out that making large allocations with 'posix' doesn't + // work very well with autoscaling. The behavior we want is that: + // + // 1. You can make large DSM allocations, larger than the current RAM + // size of the VM, without errors + // + // 2. If the allocated memory is really used, the VM is scaled up + // automatically to accommodate that + // + // We try to make that possible by having swap in the VM. But with the + // default 'posix' DSM implementation, we fail step 1, even when there's + // plenty of swap available. PostgreSQL uses posix_fallocate() to create + // the shmem segment, which is really just a file in /dev/shm in Linux, + // but posix_fallocate() on tmpfs returns ENOMEM if the size is larger + // than available RAM. + // + // Using 'dynamic_shared_memory_type = mmap' works around that, because + // the Postgres 'mmap' DSM implementation doesn't use + // posix_fallocate(). Instead, it uses repeated calls to write(2) to + // fill the file with zeros. It's weird that that differs between + // 'posix' and 'mmap', but we take advantage of it. When the file is + // filled slowly with write(2), the kernel allows it to grow larger, as + // long as there's swap available. + // + // In short, using 'dynamic_shared_memory_type = mmap' allows us one DSM + // segment to be larger than currently available RAM. But because we + // don't want to store it on a real file, which the kernel would try to + // flush to disk, so symlink pg_dynshm to /dev/shm. + // + // We don't set 'dynamic_shared_memory_type = mmap' here, we let the + // control plane control that option. If 'mmap' is not used, this + // symlink doesn't affect anything. + // + // See https://github.com/neondatabase/autoscaling/issues/800 + std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?; + symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?; + match spec.mode { ComputeMode::Primary => {} ComputeMode::Replica | ComputeMode::Static(..) => { From a97b54e3b9e692532962d65b89b7e5f67a9c28a4 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 14 Feb 2024 10:35:59 +0200 Subject: [PATCH 63/81] Cherry-pick Postgres bugfix to 'mmap' DSM implementation Cherry-pick Upstream commit fbf9a7ac4d to neon stable branches. We'll get it in the next PostgreSQL minor release anyway, but we need it now, if we want to start using the 'mmap' implementation. See https://github.com/neondatabase/autoscaling/issues/800 for the plans on doing that. --- vendor/postgres-v14 | 2 +- vendor/postgres-v15 | 2 +- vendor/postgres-v16 | 2 +- vendor/revisions.json | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 018fb05201..9dd9956c55 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 018fb052011081dc2733d3118d12e5c36df6eba1 +Subproject commit 9dd9956c55ffbbd9abe77d10382453757fedfcf5 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index 6ee78a3c29..ca2def9993 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit 6ee78a3c29e33cafd85ba09568b6b5eb031d29b9 +Subproject commit ca2def999368d9df098a637234ad5a9003189463 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 550cdd26d4..9c37a49884 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 550cdd26d445afdd26b15aa93c8c2f3dc52f8361 +Subproject commit 9c37a4988463a97d9cacb321acf3828b09823269 diff --git a/vendor/revisions.json b/vendor/revisions.json index 91ebb8cb34..72bc0d7e0d 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,5 +1,5 @@ { - "postgres-v16": "550cdd26d445afdd26b15aa93c8c2f3dc52f8361", - "postgres-v15": "6ee78a3c29e33cafd85ba09568b6b5eb031d29b9", - "postgres-v14": "018fb052011081dc2733d3118d12e5c36df6eba1" + "postgres-v16": "9c37a4988463a97d9cacb321acf3828b09823269", + "postgres-v15": "ca2def999368d9df098a637234ad5a9003189463", + "postgres-v14": "9dd9956c55ffbbd9abe77d10382453757fedfcf5" } From a9ec4eb4fc7777a529ff8c5ede814dd657390e58 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 14 Feb 2024 10:26:32 +0000 Subject: [PATCH 64/81] hold cancel session (#6750) ## Problem In a recent refactor, we accidentally dropped the cancel session early ## Summary of changes Hold the cancel session during proxy passthrough --- proxy/src/proxy.rs | 1 + proxy/src/proxy/passthrough.rs | 2 ++ 2 files changed, 3 insertions(+) diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index ce77098a5f..8a9445303a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -331,6 +331,7 @@ pub async fn handle_client( compute: node, req: _request_gauge, conn: _client_gauge, + cancel: session, })) } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index c98f68d8d1..73c170fc0b 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,4 +1,5 @@ use crate::{ + cancellation, compute::PostgresConnection, console::messages::MetricsAuxInfo, metrics::NUM_BYTES_PROXIED_COUNTER, @@ -57,6 +58,7 @@ pub struct ProxyPassthrough { pub req: IntCounterPairGuard, pub conn: IntCounterPairGuard, + pub cancel: cancellation::Session, } impl ProxyPassthrough { From f39b0fce9b24a049208e74cc7d2a6b006b487839 Mon Sep 17 00:00:00 2001 From: John Spray Date: Wed, 14 Feb 2024 10:57:01 +0000 Subject: [PATCH 65/81] Revert #6666 "tests: try to make restored-datadir comparison tests not flaky" (#6751) The #6666 change appears to have made the test fail more often. PR https://github.com/neondatabase/neon/pull/6712 should re-instate this change, along with its change to make the overall flow more reliable. This reverts commit 568f91420a9c677e77aeb736cb3f995a85f0b106. --- test_runner/fixtures/neon_fixtures.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 26f2b999a6..04af73c327 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3967,27 +3967,24 @@ def list_files_to_compare(pgdata_dir: Path) -> List[str]: # pg is the existing and running compute node, that we want to compare with a basebackup def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint: Endpoint): - pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version) - # Get the timeline ID. We need it for the 'basebackup' command timeline_id = TimelineId(endpoint.safe_psql("SHOW neon.timeline_id")[0][0]) + # many tests already checkpoint, but do it just in case + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("CHECKPOINT") + + # wait for pageserver to catch up + wait_for_last_flush_lsn(env, endpoint, endpoint.tenant_id, timeline_id) # stop postgres to ensure that files won't change endpoint.stop() - # Read the shutdown checkpoint's LSN - pg_controldata_path = os.path.join(pg_bin.pg_bin_path, "pg_controldata") - cmd = f"{pg_controldata_path} -D {endpoint.pgdata_dir}" - result = subprocess.run(cmd, capture_output=True, text=True, shell=True) - checkpoint_lsn = re.findall( - "Latest checkpoint location:\\s+([0-9A-F]+/[0-9A-F]+)", result.stdout - )[0] - log.debug(f"last checkpoint at {checkpoint_lsn}") - # Take a basebackup from pageserver restored_dir_path = env.repo_dir / f"{endpoint.endpoint_id}_restored_datadir" restored_dir_path.mkdir(exist_ok=True) + pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version) psql_path = os.path.join(pg_bin.pg_bin_path, "psql") pageserver_id = env.attachment_service.locate(endpoint.tenant_id)[0]["node_id"] @@ -3995,7 +3992,7 @@ def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint {psql_path} \ --no-psqlrc \ postgres://localhost:{env.get_pageserver(pageserver_id).service_port.pg} \ - -c 'basebackup {endpoint.tenant_id} {timeline_id} {checkpoint_lsn}' \ + -c 'basebackup {endpoint.tenant_id} {timeline_id}' \ | tar -x -C {restored_dir_path} """ From df5d588f63fd329c701c37e61f77d9524ebcb19b Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 14 Feb 2024 15:22:41 +0100 Subject: [PATCH 66/81] refactor(VirtualFile::crashsafe_overwrite): avoid Handle::block_on in callers (#6731) Some callers of `VirtualFile::crashsafe_overwrite` call it on the executor thread, thereby potentially stalling it. Others are more diligent and wrap it in `spawn_blocking(..., Handle::block_on, ... )` to avoid stalling the executor thread. However, because `crashsafe_overwrite` uses VirtualFile::open_with_options internally, we spawn a new thread-local `tokio-epoll-uring::System` in the blocking pool thread that's used for the `spawn_blocking` call. This PR refactors the situation such that we do the `spawn_blocking` inside `VirtualFile::crashsafe_overwrite`. This unifies the situation for the better: 1. Callers who didn't wrap in `spawn_blocking(..., Handle::block_on, ...)` before no longer stall the executor. 2. Callers who did it before now can avoid the `block_on`, resolving the problem with the short-lived `tokio-epoll-uring::System`s in the blocking pool threads. A future PR will build on top of this and divert to tokio-epoll-uring if it's configures as the IO engine. Changes ------- - Convert implementation to std::fs and move it into `crashsafe.rs` - Yes, I know, Safekeepers (cc @arssher ) added `durable_rename` and `fsync_async_opt` recently. However, `crashsafe_overwrite` is different in the sense that it's higher level, i.e., it's more like `std::fs::write` and the Safekeeper team's code is more building block style. - The consequence is that we don't use the VirtualFile file descriptor cache anymore. - I don't think it's a big deal because we have plenty of slack wrt production file descriptor limit rlimit (see [this dashboard](https://neonprod.grafana.net/d/e4a40325-9acf-4aa0-8fd9-f6322b3f30bd/pageserver-open-file-descriptors?orgId=1)) - Use `tokio::task::spawn_blocking` in `VirtualFile::crashsafe_overwrite` to call the new `crashsafe::overwrite` API. - Inspect all callers to remove any double-`spawn_blocking` - spawn_blocking requires the captures data to be 'static + Send. So, refactor the callers. We'll need this for future tokio-epoll-uring support anyway, because tokio-epoll-uring requires owned buffers. Related Issues -------------- - overall epic to enable write path to tokio-epoll-uring: #6663 - this is also kind of relevant to the tokio-epoll-uring System creation failures that we encountered in staging, investigation being tracked in #6667 - why is it relevant? Because this PR removes two uses of `spawn_blocking+Handle::block_on` --- libs/utils/src/crashsafe.rs | 44 +++++++++++- pageserver/src/deletion_queue.rs | 5 +- pageserver/src/tenant.rs | 33 +++------ pageserver/src/tenant/metadata.rs | 2 +- pageserver/src/tenant/secondary/downloader.rs | 11 +-- pageserver/src/virtual_file.rs | 72 ++++++++----------- 6 files changed, 89 insertions(+), 78 deletions(-) diff --git a/libs/utils/src/crashsafe.rs b/libs/utils/src/crashsafe.rs index 1c72e9cae9..756b19138c 100644 --- a/libs/utils/src/crashsafe.rs +++ b/libs/utils/src/crashsafe.rs @@ -1,7 +1,7 @@ use std::{ borrow::Cow, fs::{self, File}, - io, + io::{self, Write}, }; use camino::{Utf8Path, Utf8PathBuf}; @@ -161,6 +161,48 @@ pub async fn durable_rename( Ok(()) } +/// Writes a file to the specified `final_path` in a crash safe fasion, using [`std::fs`]. +/// +/// The file is first written to the specified `tmp_path`, and in a second +/// step, the `tmp_path` is renamed to the `final_path`. Intermediary fsync +/// and atomic rename guarantee that, if we crash at any point, there will never +/// be a partially written file at `final_path` (but maybe at `tmp_path`). +/// +/// Callers are responsible for serializing calls of this function for a given `final_path`. +/// If they don't, there may be an error due to conflicting `tmp_path`, or there will +/// be no error and the content of `final_path` will be the "winner" caller's `content`. +/// I.e., the atomticity guarantees still hold. +pub fn overwrite( + final_path: &Utf8Path, + tmp_path: &Utf8Path, + content: &[u8], +) -> std::io::Result<()> { + let Some(final_path_parent) = final_path.parent() else { + return Err(std::io::Error::from_raw_os_error( + nix::errno::Errno::EINVAL as i32, + )); + }; + std::fs::remove_file(tmp_path).or_else(crate::fs_ext::ignore_not_found)?; + let mut file = std::fs::OpenOptions::new() + .write(true) + // Use `create_new` so that, if we race with ourselves or something else, + // we bail out instead of causing damage. + .create_new(true) + .open(tmp_path)?; + file.write_all(content)?; + file.sync_all()?; + drop(file); // don't keep the fd open for longer than we have to + + std::fs::rename(tmp_path, final_path)?; + + let final_parent_dirfd = std::fs::OpenOptions::new() + .read(true) + .open(final_path_parent)?; + + final_parent_dirfd.sync_all()?; + Ok(()) +} + #[cfg(test)] mod tests { diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index 9046fe881b..e0c40ea1b0 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -234,7 +234,7 @@ impl DeletionHeader { let header_bytes = serde_json::to_vec(self).context("serialize deletion header")?; let header_path = conf.deletion_header_path(); let temp_path = path_with_suffix_extension(&header_path, TEMP_SUFFIX); - VirtualFile::crashsafe_overwrite(&header_path, &temp_path, header_bytes) + VirtualFile::crashsafe_overwrite(header_path, temp_path, header_bytes) .await .maybe_fatal_err("save deletion header")?; @@ -325,7 +325,8 @@ impl DeletionList { let temp_path = path_with_suffix_extension(&path, TEMP_SUFFIX); let bytes = serde_json::to_vec(self).expect("Failed to serialize deletion list"); - VirtualFile::crashsafe_overwrite(&path, &temp_path, bytes) + + VirtualFile::crashsafe_overwrite(path, temp_path, bytes) .await .maybe_fatal_err("save deletion list") .map_err(Into::into) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 9f1f188bf2..1f3bc13472 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -28,7 +28,6 @@ use remote_storage::GenericRemoteStorage; use std::fmt; use storage_broker::BrokerClientChannel; use tokio::io::BufReader; -use tokio::runtime::Handle; use tokio::sync::watch; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -2878,17 +2877,10 @@ impl Tenant { let tenant_shard_id = *tenant_shard_id; let config_path = config_path.to_owned(); - tokio::task::spawn_blocking(move || { - Handle::current().block_on(async move { - let conf_content = conf_content.into_bytes(); - VirtualFile::crashsafe_overwrite(&config_path, &temp_path, conf_content) - .await - .with_context(|| { - format!("write tenant {tenant_shard_id} config to {config_path}") - }) - }) - }) - .await??; + let conf_content = conf_content.into_bytes(); + VirtualFile::crashsafe_overwrite(config_path.clone(), temp_path, conf_content) + .await + .with_context(|| format!("write tenant {tenant_shard_id} config to {config_path}"))?; Ok(()) } @@ -2915,17 +2907,12 @@ impl Tenant { let tenant_shard_id = *tenant_shard_id; let target_config_path = target_config_path.to_owned(); - tokio::task::spawn_blocking(move || { - Handle::current().block_on(async move { - let conf_content = conf_content.into_bytes(); - VirtualFile::crashsafe_overwrite(&target_config_path, &temp_path, conf_content) - .await - .with_context(|| { - format!("write tenant {tenant_shard_id} config to {target_config_path}") - }) - }) - }) - .await??; + let conf_content = conf_content.into_bytes(); + VirtualFile::crashsafe_overwrite(target_config_path.clone(), temp_path, conf_content) + .await + .with_context(|| { + format!("write tenant {tenant_shard_id} config to {target_config_path}") + })?; Ok(()) } diff --git a/pageserver/src/tenant/metadata.rs b/pageserver/src/tenant/metadata.rs index dcbe781f90..233acfd431 100644 --- a/pageserver/src/tenant/metadata.rs +++ b/pageserver/src/tenant/metadata.rs @@ -279,7 +279,7 @@ pub async fn save_metadata( let path = conf.metadata_path(tenant_shard_id, timeline_id); let temp_path = path_with_suffix_extension(&path, TEMP_FILE_SUFFIX); let metadata_bytes = data.to_bytes().context("serialize metadata")?; - VirtualFile::crashsafe_overwrite(&path, &temp_path, metadata_bytes) + VirtualFile::crashsafe_overwrite(path, temp_path, metadata_bytes) .await .context("write metadata")?; Ok(()) diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index c23416a7f0..c8288acc20 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -484,14 +484,9 @@ impl<'a> TenantDownloader<'a> { let temp_path = path_with_suffix_extension(&heatmap_path, TEMP_FILE_SUFFIX); let context_msg = format!("write tenant {tenant_shard_id} heatmap to {heatmap_path}"); let heatmap_path_bg = heatmap_path.clone(); - tokio::task::spawn_blocking(move || { - tokio::runtime::Handle::current().block_on(async move { - VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, heatmap_bytes).await - }) - }) - .await - .expect("Blocking task is never aborted") - .maybe_fatal_err(&context_msg)?; + VirtualFile::crashsafe_overwrite(heatmap_path_bg, temp_path, heatmap_bytes) + .await + .maybe_fatal_err(&context_msg)?; tracing::debug!("Wrote local heatmap to {}", heatmap_path); diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 6cff748d42..2a8c22430b 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -19,14 +19,13 @@ use once_cell::sync::OnceCell; use pageserver_api::shard::TenantShardId; use std::fs::{self, File}; use std::io::{Error, ErrorKind, Seek, SeekFrom}; -use tokio_epoll_uring::{BoundedBuf, IoBufMut, Slice}; +use tokio_epoll_uring::{BoundedBuf, IoBuf, IoBufMut, Slice}; use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; use std::os::unix::fs::FileExt; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use tokio::time::Instant; -use utils::fs_ext; pub use pageserver_api::models::virtual_file as api; pub(crate) mod io_engine; @@ -404,47 +403,34 @@ impl VirtualFile { Ok(vfile) } - /// Writes a file to the specified `final_path` in a crash safe fasion + /// Async version of [`::utils::crashsafe::overwrite`]. /// - /// The file is first written to the specified tmp_path, and in a second - /// step, the tmp path is renamed to the final path. As renames are - /// atomic, a crash during the write operation will never leave behind a - /// partially written file. - pub async fn crashsafe_overwrite( - final_path: &Utf8Path, - tmp_path: &Utf8Path, + /// # NB: + /// + /// Doesn't actually use the [`VirtualFile`] file descriptor cache, but, + /// it did at an earlier time. + /// And it will use this module's [`io_engine`] in the near future, so, leaving it here. + pub async fn crashsafe_overwrite + Send, Buf: IoBuf + Send>( + final_path: Utf8PathBuf, + tmp_path: Utf8PathBuf, content: B, ) -> std::io::Result<()> { - let Some(final_path_parent) = final_path.parent() else { - return Err(std::io::Error::from_raw_os_error( - nix::errno::Errno::EINVAL as i32, - )); - }; - std::fs::remove_file(tmp_path).or_else(fs_ext::ignore_not_found)?; - let mut file = Self::open_with_options( - tmp_path, - OpenOptions::new() - .write(true) - // Use `create_new` so that, if we race with ourselves or something else, - // we bail out instead of causing damage. - .create_new(true), - ) - .await?; - let (_content, res) = file.write_all(content).await; - res?; - file.sync_all().await?; - drop(file); // before the rename, that's important! - // renames are atomic - std::fs::rename(tmp_path, final_path)?; - // Only open final path parent dirfd now, so that this operation only - // ever holds one VirtualFile fd at a time. That's important because - // the current `find_victim_slot` impl might pick the same slot for both - // VirtualFile., and it eventually does a blocking write lock instead of - // try_lock. - let final_parent_dirfd = - Self::open_with_options(final_path_parent, OpenOptions::new().read(true)).await?; - final_parent_dirfd.sync_all().await?; - Ok(()) + // TODO: use tokio_epoll_uring if configured as `io_engine`. + // See https://github.com/neondatabase/neon/issues/6663 + + tokio::task::spawn_blocking(move || { + let slice_storage; + let content_len = content.bytes_init(); + let content = if content.bytes_init() > 0 { + slice_storage = Some(content.slice(0..content_len)); + slice_storage.as_deref().expect("just set it to Some()") + } else { + &[] + }; + utils::crashsafe::overwrite(&final_path, &tmp_path, content) + }) + .await + .expect("blocking task is never aborted") } /// Call File::sync_all() on the underlying File. @@ -1315,7 +1301,7 @@ mod tests { let path = testdir.join("myfile"); let tmp_path = testdir.join("myfile.tmp"); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec()) + VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap()); @@ -1324,7 +1310,7 @@ mod tests { assert!(!tmp_path.exists()); drop(file); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar".to_vec()) + VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap()); @@ -1346,7 +1332,7 @@ mod tests { std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap(); assert!(tmp_path.exists()); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec()) + VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec()) .await .unwrap(); From 774a6e74757d1b1d1e3c75ab103bdd38587a38f1 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 14 Feb 2024 15:59:06 +0100 Subject: [PATCH 67/81] refactor(virtual_file) make write_all_at take owned buffers (#6673) context: https://github.com/neondatabase/neon/issues/6663 Building atop #6664, this PR switches `write_all_at` to take owned buffers. The main challenge here is the `EphemeralFile::mutable_tail`, for which I'm picking the ugly solution of an `Option` that is `None` while the IO is in flight. After this, we will be able to switch `write_at` to take owned buffers and call tokio-epoll-uring's `write` function with that owned buffer. That'll be done in #6378. --- pageserver/src/tenant/ephemeral_file.rs | 51 ++++++++++++++++++------- pageserver/src/virtual_file.rs | 50 +++++++++++++++++------- 2 files changed, 74 insertions(+), 27 deletions(-) diff --git a/pageserver/src/tenant/ephemeral_file.rs b/pageserver/src/tenant/ephemeral_file.rs index 6b8cd77d78..2bedbf7f61 100644 --- a/pageserver/src/tenant/ephemeral_file.rs +++ b/pageserver/src/tenant/ephemeral_file.rs @@ -6,6 +6,7 @@ use crate::context::RequestContext; use crate::page_cache::{self, PAGE_SZ}; use crate::tenant::block_io::{BlockCursor, BlockLease, BlockReader}; use crate::virtual_file::{self, VirtualFile}; +use bytes::BytesMut; use camino::Utf8PathBuf; use pageserver_api::shard::TenantShardId; use std::cmp::min; @@ -26,7 +27,10 @@ pub struct EphemeralFile { /// An ephemeral file is append-only. /// We keep the last page, which can still be modified, in [`Self::mutable_tail`]. /// The other pages, which can no longer be modified, are accessed through the page cache. - mutable_tail: [u8; PAGE_SZ], + /// + /// None <=> IO is ongoing. + /// Size is fixed to PAGE_SZ at creation time and must not be changed. + mutable_tail: Option, } impl EphemeralFile { @@ -60,7 +64,7 @@ impl EphemeralFile { _timeline_id: timeline_id, file, len: 0, - mutable_tail: [0u8; PAGE_SZ], + mutable_tail: Some(BytesMut::zeroed(PAGE_SZ)), }) } @@ -103,7 +107,13 @@ impl EphemeralFile { }; } else { debug_assert_eq!(blknum as u64, self.len / PAGE_SZ as u64); - Ok(BlockLease::EphemeralFileMutableTail(&self.mutable_tail)) + Ok(BlockLease::EphemeralFileMutableTail( + self.mutable_tail + .as_deref() + .expect("we're not doing IO, it must be Some()") + .try_into() + .expect("we ensure that it's always PAGE_SZ"), + )) } } @@ -135,21 +145,27 @@ impl EphemeralFile { ) -> Result<(), io::Error> { let mut src_remaining = src; while !src_remaining.is_empty() { - let dst_remaining = &mut self.ephemeral_file.mutable_tail[self.off..]; + let dst_remaining = &mut self + .ephemeral_file + .mutable_tail + .as_deref_mut() + .expect("IO is not yet ongoing")[self.off..]; let n = min(dst_remaining.len(), src_remaining.len()); dst_remaining[..n].copy_from_slice(&src_remaining[..n]); self.off += n; src_remaining = &src_remaining[n..]; if self.off == PAGE_SZ { - match self + let mutable_tail = std::mem::take(&mut self.ephemeral_file.mutable_tail) + .expect("IO is not yet ongoing"); + let (mutable_tail, res) = self .ephemeral_file .file - .write_all_at( - &self.ephemeral_file.mutable_tail, - self.blknum as u64 * PAGE_SZ as u64, - ) - .await - { + .write_all_at(mutable_tail, self.blknum as u64 * PAGE_SZ as u64) + .await; + // TODO: If we panic before we can put the mutable_tail back, subsequent calls will fail. + // I.e., the IO isn't retryable if we panic. + self.ephemeral_file.mutable_tail = Some(mutable_tail); + match res { Ok(_) => { // Pre-warm the page cache with what we just wrote. // This isn't necessary for coherency/correctness, but it's how we've always done it. @@ -169,7 +185,12 @@ impl EphemeralFile { Ok(page_cache::ReadBufResult::NotFound(mut write_guard)) => { let buf: &mut [u8] = write_guard.deref_mut(); debug_assert_eq!(buf.len(), PAGE_SZ); - buf.copy_from_slice(&self.ephemeral_file.mutable_tail); + buf.copy_from_slice( + self.ephemeral_file + .mutable_tail + .as_deref() + .expect("IO is not ongoing"), + ); let _ = write_guard.mark_valid(); // pre-warm successful } @@ -181,7 +202,11 @@ impl EphemeralFile { // Zero the buffer for re-use. // Zeroing is critical for correcntess because the write_blob code below // and similarly read_blk expect zeroed pages. - self.ephemeral_file.mutable_tail.fill(0); + self.ephemeral_file + .mutable_tail + .as_deref_mut() + .expect("IO is not ongoing") + .fill(0); // This block is done, move to next one. self.blknum += 1; self.off = 0; diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 2a8c22430b..858fc0ef64 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -568,24 +568,37 @@ impl VirtualFile { } // Copied from https://doc.rust-lang.org/1.72.0/src/std/os/unix/fs.rs.html#219-235 - pub async fn write_all_at(&self, mut buf: &[u8], mut offset: u64) -> Result<(), Error> { + pub async fn write_all_at( + &self, + buf: B, + mut offset: u64, + ) -> (B::Buf, Result<(), Error>) { + let buf_len = buf.bytes_init(); + if buf_len == 0 { + return (Slice::into_inner(buf.slice_full()), Ok(())); + } + let mut buf = buf.slice(0..buf_len); while !buf.is_empty() { - match self.write_at(buf, offset).await { + // TODO: push `buf` further down + match self.write_at(&buf, offset).await { Ok(0) => { - return Err(Error::new( - std::io::ErrorKind::WriteZero, - "failed to write whole buffer", - )); + return ( + Slice::into_inner(buf), + Err(Error::new( + std::io::ErrorKind::WriteZero, + "failed to write whole buffer", + )), + ); } Ok(n) => { - buf = &buf[n..]; + buf = buf.slice(n..); offset += n as u64; } Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {} - Err(e) => return Err(e), + Err(e) => return (Slice::into_inner(buf), Err(e)), } } - Ok(()) + (Slice::into_inner(buf), Ok(())) } /// Writes `buf.slice(0..buf.bytes_init())`. @@ -1050,10 +1063,19 @@ mod tests { MaybeVirtualFile::File(file) => file.read_exact_at(&mut buf, offset).map(|()| buf), } } - async fn write_all_at(&self, buf: &[u8], offset: u64) -> Result<(), Error> { + async fn write_all_at(&self, buf: B, offset: u64) -> Result<(), Error> { match self { - MaybeVirtualFile::VirtualFile(file) => file.write_all_at(buf, offset).await, - MaybeVirtualFile::File(file) => file.write_all_at(buf, offset), + MaybeVirtualFile::VirtualFile(file) => { + let (_buf, res) = file.write_all_at(buf, offset).await; + res + } + MaybeVirtualFile::File(file) => { + let buf_len = buf.bytes_init(); + if buf_len == 0 { + return Ok(()); + } + file.write_all_at(&buf.slice(0..buf_len), offset) + } } } async fn seek(&mut self, pos: SeekFrom) -> Result { @@ -1200,8 +1222,8 @@ mod tests { .to_owned(), ) .await?; - file_b.write_all_at(b"BAR", 3).await?; - file_b.write_all_at(b"FOO", 0).await?; + file_b.write_all_at(b"BAR".to_vec(), 3).await?; + file_b.write_all_at(b"FOO".to_vec(), 0).await?; assert_eq!(file_b.read_string_at(2, 3).await?, "OBA"); From 840abe395413508db40d0428e30f09343c051fed Mon Sep 17 00:00:00 2001 From: John Spray Date: Wed, 14 Feb 2024 15:01:16 +0000 Subject: [PATCH 68/81] pageserver: store aux files as deltas (#6742) ## Problem Aux files were stored with an O(N^2) cost, since on each modification the entire map is re-written as a page image. This addresses one axis of the inefficiency in logical replication's use of storage (https://github.com/neondatabase/neon/issues/6626). It will still be writing a large amount of duplicative data if writing the same slot's state every 15 seconds, but the impact will be O(N) instead of O(N^2). ## Summary of changes - Introduce `NeonWalRecord::AuxFile` - In `DatadirModification`, if the AUX_FILES_KEY has already been set, then write a delta instead of an image --- pageserver/src/pgdatadir_mapping.rs | 162 +++++++++++++++++++++++---- pageserver/src/tenant.rs | 41 ++++--- pageserver/src/walrecord.rs | 5 + pageserver/src/walredo.rs | 2 +- pageserver/src/walredo/apply_neon.rs | 70 +++++++++++- 5 files changed, 242 insertions(+), 38 deletions(-) diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index 5f80ea9b5e..0ff03303d4 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -156,6 +156,7 @@ impl Timeline { pending_updates: HashMap::new(), pending_deletions: Vec::new(), pending_nblocks: 0, + pending_aux_files: None, pending_directory_entries: Vec::new(), lsn, } @@ -870,6 +871,14 @@ pub struct DatadirModification<'a> { pending_updates: HashMap>, pending_deletions: Vec<(Range, Lsn)>, pending_nblocks: i64, + + // If we already wrote any aux file changes in this modification, stash the latest dir. If set, + // [`Self::put_file`] may assume that it is safe to emit a delta rather than checking + // if AUX_FILES_KEY is already set. + pending_aux_files: Option, + + /// For special "directory" keys that store key-value maps, track the size of the map + /// if it was updated in this modification. pending_directory_entries: Vec<(DirectoryKind, usize)>, } @@ -1384,31 +1393,76 @@ impl<'a> DatadirModification<'a> { content: &[u8], ctx: &RequestContext, ) -> anyhow::Result<()> { - let mut dir = match self.get(AUX_FILES_KEY, ctx).await { - Ok(buf) => AuxFilesDirectory::des(&buf)?, - Err(e) => { - // This is expected: historical databases do not have the key. - debug!("Failed to get info about AUX files: {}", e); - AuxFilesDirectory { - files: HashMap::new(), + let file_path = path.to_string(); + let content = if content.is_empty() { + None + } else { + Some(Bytes::copy_from_slice(content)) + }; + + let dir = if let Some(mut dir) = self.pending_aux_files.take() { + // We already updated aux files in `self`: emit a delta and update our latest value + + self.put( + AUX_FILES_KEY, + Value::WalRecord(NeonWalRecord::AuxFile { + file_path: file_path.clone(), + content: content.clone(), + }), + ); + + dir.upsert(file_path, content); + dir + } else { + // Check if the AUX_FILES_KEY is initialized + match self.get(AUX_FILES_KEY, ctx).await { + Ok(dir_bytes) => { + let mut dir = AuxFilesDirectory::des(&dir_bytes)?; + // Key is already set, we may append a delta + self.put( + AUX_FILES_KEY, + Value::WalRecord(NeonWalRecord::AuxFile { + file_path: file_path.clone(), + content: content.clone(), + }), + ); + dir.upsert(file_path, content); + dir + } + Err( + e @ (PageReconstructError::AncestorStopping(_) + | PageReconstructError::Cancelled + | PageReconstructError::AncestorLsnTimeout(_)), + ) => { + // Important that we do not interpret a shutdown error as "not found" and thereby + // reset the map. + return Err(e.into()); + } + // FIXME: PageReconstructError doesn't have an explicit variant for key-not-found, so + // we are assuming that all _other_ possible errors represents a missing key. If some + // other error occurs, we may incorrectly reset the map of aux files. + Err(PageReconstructError::Other(_) | PageReconstructError::WalRedo(_)) => { + // Key is missing, we must insert an image as the basis for subsequent deltas. + + let mut dir = AuxFilesDirectory { + files: HashMap::new(), + }; + dir.upsert(file_path, content); + self.put( + AUX_FILES_KEY, + Value::Image(Bytes::from( + AuxFilesDirectory::ser(&dir).context("serialize")?, + )), + ); + dir } } }; - let path = path.to_string(); - if content.is_empty() { - dir.files.remove(&path); - } else { - dir.files.insert(path, Bytes::copy_from_slice(content)); - } + self.pending_directory_entries .push((DirectoryKind::AuxFiles, dir.files.len())); + self.pending_aux_files = Some(dir); - self.put( - AUX_FILES_KEY, - Value::Image(Bytes::from( - AuxFilesDirectory::ser(&dir).context("serialize")?, - )), - ); Ok(()) } @@ -1618,8 +1672,18 @@ struct RelDirectory { } #[derive(Debug, Serialize, Deserialize, Default)] -struct AuxFilesDirectory { - files: HashMap, +pub(crate) struct AuxFilesDirectory { + pub(crate) files: HashMap, +} + +impl AuxFilesDirectory { + pub(crate) fn upsert(&mut self, key: String, value: Option) { + if let Some(value) = value { + self.files.insert(key, value); + } else { + self.files.remove(&key); + } + } } #[derive(Debug, Serialize, Deserialize)] @@ -1655,8 +1719,60 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]); #[allow(clippy::bool_assert_comparison)] #[cfg(test)] mod tests { - //use super::repo_harness::*; - //use super::*; + use hex_literal::hex; + use utils::id::TimelineId; + + use super::*; + + use crate::{tenant::harness::TenantHarness, DEFAULT_PG_VERSION}; + + /// Test a round trip of aux file updates, from DatadirModification to reading back from the Timeline + #[tokio::test] + async fn aux_files_round_trip() -> anyhow::Result<()> { + let name = "aux_files_round_trip"; + let harness = TenantHarness::create(name)?; + + pub const TIMELINE_ID: TimelineId = + TimelineId::from_array(hex!("11223344556677881122334455667788")); + + let (tenant, ctx) = harness.load().await; + let tline = tenant + .create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx) + .await?; + let tline = tline.raw_timeline().unwrap(); + + // First modification: insert two keys + let mut modification = tline.begin_modification(Lsn(0x1000)); + modification.put_file("foo/bar1", b"content1", &ctx).await?; + modification.set_lsn(Lsn(0x1008))?; + modification.put_file("foo/bar2", b"content2", &ctx).await?; + modification.commit(&ctx).await?; + let expect_1008 = HashMap::from([ + ("foo/bar1".to_string(), Bytes::from_static(b"content1")), + ("foo/bar2".to_string(), Bytes::from_static(b"content2")), + ]); + + let readback = tline.list_aux_files(Lsn(0x1008), &ctx).await?; + assert_eq!(readback, expect_1008); + + // Second modification: update one key, remove the other + let mut modification = tline.begin_modification(Lsn(0x2000)); + modification.put_file("foo/bar1", b"content3", &ctx).await?; + modification.set_lsn(Lsn(0x2008))?; + modification.put_file("foo/bar2", b"", &ctx).await?; + modification.commit(&ctx).await?; + let expect_2008 = + HashMap::from([("foo/bar1".to_string(), Bytes::from_static(b"content3"))]); + + let readback = tline.list_aux_files(Lsn(0x2008), &ctx).await?; + assert_eq!(readback, expect_2008); + + // Reading back in time works + let readback = tline.list_aux_files(Lsn(0x1008), &ctx).await?; + assert_eq!(readback, expect_1008); + + Ok(()) + } /* fn assert_current_logical_size(timeline: &DatadirTimeline, lsn: Lsn) { diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 1f3bc13472..44a446d697 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -3901,6 +3901,7 @@ pub(crate) mod harness { use utils::lsn::Lsn; use crate::deletion_queue::mock::MockDeletionQueue; + use crate::walredo::apply_neon; use crate::{ config::PageServerConf, repository::Key, tenant::Tenant, walrecord::NeonWalRecord, }; @@ -4160,20 +4161,34 @@ pub(crate) mod harness { records: Vec<(Lsn, NeonWalRecord)>, _pg_version: u32, ) -> anyhow::Result { - let s = format!( - "redo for {} to get to {}, with {} and {} records", - key, - lsn, - if base_img.is_some() { - "base image" - } else { - "no base image" - }, - records.len() - ); - println!("{s}"); + let records_neon = records.iter().all(|r| apply_neon::can_apply_in_neon(&r.1)); - Ok(TEST_IMG(&s)) + if records_neon { + // For Neon wal records, we can decode without spawning postgres, so do so. + let base_img = base_img.expect("Neon WAL redo requires base image").1; + let mut page = BytesMut::new(); + page.extend_from_slice(&base_img); + for (_record_lsn, record) in records { + apply_neon::apply_in_neon(&record, key, &mut page)?; + } + Ok(page.freeze()) + } else { + // We never spawn a postgres walredo process in unit tests: just log what we might have done. + let s = format!( + "redo for {} to get to {}, with {} and {} records", + key, + lsn, + if base_img.is_some() { + "base image" + } else { + "no base image" + }, + records.len() + ); + println!("{s}"); + + Ok(TEST_IMG(&s)) + } } } } diff --git a/pageserver/src/walrecord.rs b/pageserver/src/walrecord.rs index ff6bc9194b..1b7777a544 100644 --- a/pageserver/src/walrecord.rs +++ b/pageserver/src/walrecord.rs @@ -44,6 +44,11 @@ pub enum NeonWalRecord { moff: MultiXactOffset, members: Vec, }, + /// Update the map of AUX files, either writing or dropping an entry + AuxFile { + file_path: String, + content: Option, + }, } impl NeonWalRecord { diff --git a/pageserver/src/walredo.rs b/pageserver/src/walredo.rs index 98a6a0bb6c..35cbefb92c 100644 --- a/pageserver/src/walredo.rs +++ b/pageserver/src/walredo.rs @@ -22,7 +22,7 @@ mod process; /// Code to apply [`NeonWalRecord`]s. -mod apply_neon; +pub(crate) mod apply_neon; use crate::config::PageServerConf; use crate::metrics::{ diff --git a/pageserver/src/walredo/apply_neon.rs b/pageserver/src/walredo/apply_neon.rs index 52899349c4..6ce90e0c47 100644 --- a/pageserver/src/walredo/apply_neon.rs +++ b/pageserver/src/walredo/apply_neon.rs @@ -1,7 +1,8 @@ +use crate::pgdatadir_mapping::AuxFilesDirectory; use crate::walrecord::NeonWalRecord; use anyhow::Context; use byteorder::{ByteOrder, LittleEndian}; -use bytes::BytesMut; +use bytes::{BufMut, BytesMut}; use pageserver_api::key::{key_to_rel_block, key_to_slru_block, Key}; use pageserver_api::reltag::SlruKind; use postgres_ffi::pg_constants; @@ -12,6 +13,7 @@ use postgres_ffi::v14::nonrelfile_utils::{ }; use postgres_ffi::BLCKSZ; use tracing::*; +use utils::bin_ser::BeSer; /// Can this request be served by neon redo functions /// or we need to pass it to wal-redo postgres process? @@ -230,6 +232,72 @@ pub(crate) fn apply_in_neon( LittleEndian::write_u32(&mut page[memberoff..memberoff + 4], member.xid); } } + NeonWalRecord::AuxFile { file_path, content } => { + let mut dir = AuxFilesDirectory::des(page)?; + dir.upsert(file_path.clone(), content.clone()); + + page.clear(); + let mut writer = page.writer(); + dir.ser_into(&mut writer)?; + } } Ok(()) } + +#[cfg(test)] +mod test { + use bytes::Bytes; + use pageserver_api::key::AUX_FILES_KEY; + + use super::*; + use std::collections::HashMap; + + use crate::{pgdatadir_mapping::AuxFilesDirectory, walrecord::NeonWalRecord}; + + /// Test [`apply_in_neon`]'s handling of NeonWalRecord::AuxFile + #[test] + fn apply_aux_file_deltas() -> anyhow::Result<()> { + let base_dir = AuxFilesDirectory { + files: HashMap::from([ + ("two".to_string(), Bytes::from_static(b"content0")), + ("three".to_string(), Bytes::from_static(b"contentX")), + ]), + }; + let base_image = AuxFilesDirectory::ser(&base_dir)?; + + let deltas = vec![ + // Insert + NeonWalRecord::AuxFile { + file_path: "one".to_string(), + content: Some(Bytes::from_static(b"content1")), + }, + // Update + NeonWalRecord::AuxFile { + file_path: "two".to_string(), + content: Some(Bytes::from_static(b"content99")), + }, + // Delete + NeonWalRecord::AuxFile { + file_path: "three".to_string(), + content: None, + }, + ]; + + let file_path = AUX_FILES_KEY; + let mut page = BytesMut::from_iter(base_image); + + for record in deltas { + apply_in_neon(&record, file_path, &mut page)?; + } + + let reconstructed = AuxFilesDirectory::des(&page)?; + let expect = HashMap::from([ + ("one".to_string(), Bytes::from_static(b"content1")), + ("two".to_string(), Bytes::from_static(b"content99")), + ]); + + assert_eq!(reconstructed.files, expect); + + Ok(()) + } +} From 7d3cdc05d486ee1a1ef5ec8d7137949bcf7d036e Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 14 Feb 2024 18:01:15 +0100 Subject: [PATCH 69/81] fix(pageserver): pagebench doesn't work with released artifacts (#6757) The canonical release artifact of neon.git is the Docker image with all the binaries in them: ``` docker pull neondatabase/neon:release-4854 docker create --name extract neondatabase/neon:release-4854 docker cp extract:/usr/local/bin/pageserver ./pageserver.release-4854 chmod +x pageserver.release-4854 cp -a pageserver.release-4854 ./target/release/pageserver ``` Before this PR, these artifacts didn't expose the `keyspace` API, thereby preventing `pagebench get-page-latest-lsn` from working. Having working pagebench is useful, e.g., for experiments in staging. So, expose the API, but don't document it, as it's not part of the interface with control plane. --- pageserver/src/http/routes.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index c354cc9ab6..ab546c873a 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -2214,7 +2214,7 @@ pub fn make_router( ) .get( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/keyspace", - |r| testing_api_handler("read out the keyspace", r, timeline_collect_keyspace), + |r| api_handler(r, timeline_collect_keyspace), ) .put("/v1/io_engine", |r| api_handler(r, put_io_engine_handler)) .any(handler_404)) From a2d0d44b4248769c30fff79ef70f42e3174f4023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Wed, 14 Feb 2024 19:16:05 +0100 Subject: [PATCH 70/81] Remove unused allow's (#6760) These allow's became redundant some time ago so remove them, or address them if addressing is very simple. --- control_plane/attachment_service/src/persistence.rs | 2 -- libs/metrics/src/lib.rs | 1 - libs/postgres_ffi/src/lib.rs | 2 +- libs/remote_storage/src/local_fs.rs | 1 - libs/utils/benches/benchmarks.rs | 2 -- pageserver/src/deletion_queue.rs | 1 - pageserver/src/disk_usage_eviction_task.rs | 6 ------ pageserver/src/task_mgr.rs | 5 ----- pageserver/src/tenant.rs | 1 - pageserver/src/tenant/disk_btree.rs | 1 - pageserver/src/tenant/timeline/eviction_task.rs | 2 +- s3_scrubber/src/cloud_admin_api.rs | 6 +----- 12 files changed, 3 insertions(+), 27 deletions(-) diff --git a/control_plane/attachment_service/src/persistence.rs b/control_plane/attachment_service/src/persistence.rs index 457dc43232..5b3b032bc9 100644 --- a/control_plane/attachment_service/src/persistence.rs +++ b/control_plane/attachment_service/src/persistence.rs @@ -381,7 +381,6 @@ impl Persistence { // // We create the child shards here, so that they will be available for increment_generation calls // if some pageserver holding a child shard needs to restart before the overall tenant split is complete. - #[allow(dead_code)] pub(crate) async fn begin_shard_split( &self, old_shard_count: ShardCount, @@ -449,7 +448,6 @@ impl Persistence { // When we finish shard splitting, we must atomically clean up the old shards // and insert the new shards, and clear the splitting marker. - #[allow(dead_code)] pub(crate) async fn complete_shard_split( &self, split_tenant_id: TenantId, diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index b57fd9f33b..18786106d1 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -115,7 +115,6 @@ pub fn set_build_info_metric(revision: &str, build_tag: &str) { // performed by the process. // We know the size of the block, so we can determine the I/O bytes out of it. // The value might be not 100% exact, but should be fine for Prometheus metrics in this case. -#[allow(clippy::unnecessary_cast)] fn update_rusage_metrics() { let rusage_stats = get_rusage_stats(); diff --git a/libs/postgres_ffi/src/lib.rs b/libs/postgres_ffi/src/lib.rs index d10ebfe277..aa6845b9b1 100644 --- a/libs/postgres_ffi/src/lib.rs +++ b/libs/postgres_ffi/src/lib.rs @@ -3,7 +3,7 @@ #![allow(non_snake_case)] // bindgen creates some unsafe code with no doc comments. #![allow(clippy::missing_safety_doc)] -// noted at 1.63 that in many cases there's a u32 -> u32 transmutes in bindgen code. +// noted at 1.63 that in many cases there's u32 -> u32 transmutes in bindgen code. #![allow(clippy::useless_transmute)] // modules included with the postgres_ffi macro depend on the types of the specific version's // types, and trigger a too eager lint. diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index f53ba9db07..e88111e8e2 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -435,7 +435,6 @@ impl RemoteStorage for LocalFs { Ok(()) } - #[allow(clippy::diverging_sub_expression)] async fn time_travel_recover( &self, _prefix: Option<&RemotePath>, diff --git a/libs/utils/benches/benchmarks.rs b/libs/utils/benches/benchmarks.rs index 98d839ca55..44eb36387c 100644 --- a/libs/utils/benches/benchmarks.rs +++ b/libs/utils/benches/benchmarks.rs @@ -1,5 +1,3 @@ -#![allow(unused)] - use criterion::{criterion_group, criterion_main, Criterion}; use utils::id; diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index e0c40ea1b0..f8f2866a3b 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -835,7 +835,6 @@ mod test { } impl ControlPlaneGenerationsApi for MockControlPlane { - #[allow(clippy::diverging_sub_expression)] // False positive via async_trait async fn re_attach(&self) -> Result, RetryForeverError> { unimplemented!() } diff --git a/pageserver/src/disk_usage_eviction_task.rs b/pageserver/src/disk_usage_eviction_task.rs index d5f5a20683..b1c6f35704 100644 --- a/pageserver/src/disk_usage_eviction_task.rs +++ b/pageserver/src/disk_usage_eviction_task.rs @@ -351,7 +351,6 @@ pub enum IterationOutcome { Finished(IterationOutcomeFinished), } -#[allow(dead_code)] #[derive(Debug, Serialize)] pub struct IterationOutcomeFinished { /// The actual usage observed before we started the iteration. @@ -366,7 +365,6 @@ pub struct IterationOutcomeFinished { } #[derive(Debug, Serialize)] -#[allow(dead_code)] struct AssumedUsage { /// The expected value for `after`, after phase 2. projected_after: U, @@ -374,14 +372,12 @@ struct AssumedUsage { failed: LayerCount, } -#[allow(dead_code)] #[derive(Debug, Serialize)] struct PlannedUsage { respecting_tenant_min_resident_size: U, fallback_to_global_lru: Option, } -#[allow(dead_code)] #[derive(Debug, Default, Serialize)] struct LayerCount { file_sizes: u64, @@ -565,7 +561,6 @@ pub(crate) struct EvictionSecondaryLayer { #[derive(Clone)] pub(crate) enum EvictionLayer { Attached(Layer), - #[allow(dead_code)] Secondary(EvictionSecondaryLayer), } @@ -1105,7 +1100,6 @@ mod filesystem_level_usage { use super::DiskUsageEvictionTaskConfig; #[derive(Debug, Clone, Copy)] - #[allow(dead_code)] pub struct Usage<'a> { config: &'a DiskUsageEvictionTaskConfig, diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index 3cec5fa850..6317b0a7ae 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -30,10 +30,6 @@ //! only a single tenant or timeline. //! -// Clippy 1.60 incorrectly complains about the tokio::task_local!() macro. -// Silence it. See https://github.com/rust-lang/rust-clippy/issues/9224. -#![allow(clippy::declare_interior_mutable_const)] - use std::collections::HashMap; use std::fmt; use std::future::Future; @@ -312,7 +308,6 @@ struct MutableTaskState { } struct PageServerTask { - #[allow(dead_code)] // unused currently task_id: PageserverTaskId, kind: TaskKind, diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 44a446d697..dc9b8247a5 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -4360,7 +4360,6 @@ mod tests { ctx: &RequestContext, ) -> anyhow::Result<()> { let mut lsn = start_lsn; - #[allow(non_snake_case)] { let writer = tline.writer().await; // Create a relation on the timeline diff --git a/pageserver/src/tenant/disk_btree.rs b/pageserver/src/tenant/disk_btree.rs index 06a04bf536..9f104aff86 100644 --- a/pageserver/src/tenant/disk_btree.rs +++ b/pageserver/src/tenant/disk_btree.rs @@ -36,7 +36,6 @@ use crate::{ pub const VALUE_SZ: usize = 5; pub const MAX_VALUE: u64 = 0x007f_ffff_ffff; -#[allow(dead_code)] pub const PAGE_SZ: usize = 8192; #[derive(Clone, Copy, Debug)] diff --git a/pageserver/src/tenant/timeline/eviction_task.rs b/pageserver/src/tenant/timeline/eviction_task.rs index d87f78e35f..33ba234a63 100644 --- a/pageserver/src/tenant/timeline/eviction_task.rs +++ b/pageserver/src/tenant/timeline/eviction_task.rs @@ -196,13 +196,13 @@ impl Timeline { ControlFlow::Continue(()) => (), } - #[allow(dead_code)] #[derive(Debug, Default)] struct EvictionStats { candidates: usize, evicted: usize, errors: usize, not_evictable: usize, + #[allow(dead_code)] skipped_for_shutdown: usize, } diff --git a/s3_scrubber/src/cloud_admin_api.rs b/s3_scrubber/src/cloud_admin_api.rs index 151421c84f..45cac23690 100644 --- a/s3_scrubber/src/cloud_admin_api.rs +++ b/s3_scrubber/src/cloud_admin_api.rs @@ -1,11 +1,7 @@ -#![allow(unused)] - -use std::str::FromStr; use std::time::Duration; use chrono::{DateTime, Utc}; use hex::FromHex; -use pageserver::tenant::Tenant; use reqwest::{header, Client, StatusCode, Url}; use serde::Deserialize; use tokio::sync::Semaphore; @@ -290,7 +286,7 @@ impl CloudAdminApiClient { tokio::time::sleep(Duration::from_millis(500)).await; continue; } - status => { + _status => { return Err(Error::new( "List active projects".to_string(), ErrorKind::ResponseStatus(response.status()), From c7538a2c20178ecd32662de3200cfe9fff19e8a3 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Wed, 14 Feb 2024 19:43:52 +0100 Subject: [PATCH 71/81] Proxy: remove fail fast logic to connect to compute (#6759) ## Problem Flaky tests ## Summary of changes Remove failfast logic --- proxy/src/proxy/connect_compute.rs | 35 ++++++++++++++--------------- proxy/src/proxy/tests.rs | 36 ------------------------------ 2 files changed, 17 insertions(+), 54 deletions(-) diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 6e57caf998..c76e2ff6d9 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -122,25 +122,24 @@ where error!(error = ?err, "could not connect to compute node"); - let node_info = - if err.get_error_kind() == crate::error::ErrorKind::Postgres || !node_info.cached() { - // If the error is Postgres, that means that we managed to connect to the compute node, but there was an error. - // Do not need to retrieve a new node_info, just return the old one. - if !err.should_retry(num_retries) { - return Err(err.into()); - } - node_info - } else { - // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node - info!("compute node's state has likely changed; requesting a wake-up"); - ctx.latency_timer.cache_miss(); - let old_node_info = invalidate_cache(node_info); - let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; - node_info.reuse_settings(old_node_info); + let node_info = if !node_info.cached() { + // If we just recieved this from cplane and dodn't get it from cache, we shouldn't retry. + // Do not need to retrieve a new node_info, just return the old one. + if !err.should_retry(num_retries) { + return Err(err.into()); + } + node_info + } else { + // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node + info!("compute node's state has likely changed; requesting a wake-up"); + ctx.latency_timer.cache_miss(); + let old_node_info = invalidate_cache(node_info); + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + node_info.reuse_settings(old_node_info); - mechanism.update_connect_config(&mut node_info.config); - node_info - }; + mechanism.update_connect_config(&mut node_info.config); + node_info + }; // now that we have a new node, try connect to it repeatedly. // this can error for a few reasons, for instance: diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index efbd661bbf..1a01f32339 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -375,8 +375,6 @@ enum ConnectAction { Connect, Retry, Fail, - RetryPg, - FailPg, } #[derive(Clone)] @@ -466,14 +464,6 @@ impl ConnectMechanism for TestConnectMechanism { retryable: false, kind: ErrorKind::Compute, }), - ConnectAction::FailPg => Err(TestConnectError { - retryable: false, - kind: ErrorKind::Postgres, - }), - ConnectAction::RetryPg => Err(TestConnectError { - retryable: true, - kind: ErrorKind::Postgres, - }), x => panic!("expecting action {:?}, connect is called instead", x), } } @@ -572,32 +562,6 @@ async fn connect_to_compute_retry() { mechanism.verify(); } -#[tokio::test] -async fn connect_to_compute_retry_pg() { - let _ = env_logger::try_init(); - use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Wake, RetryPg, Connect]); - let user_info = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, &user_info, false) - .await - .unwrap(); - mechanism.verify(); -} - -#[tokio::test] -async fn connect_to_compute_fail_pg() { - let _ = env_logger::try_init(); - use ConnectAction::*; - let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Wake, FailPg]); - let user_info = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, &user_info, false) - .await - .unwrap_err(); - mechanism.verify(); -} - /// Test that we don't retry if the error is not retryable. #[tokio::test] async fn connect_to_compute_non_retry_1() { From fff2468aa2780edb3941f9851e19ee0bfb1fafd1 Mon Sep 17 00:00:00 2001 From: Shayan Hosseini Date: Wed, 14 Feb 2024 10:45:05 -0800 Subject: [PATCH 72/81] Add resource consume test funcs (#6747) ## Problem Building on #5875 to add handy test functions for autoscaling. Resolves #5609 ## Summary of changes This PR makes the following changes to #5875: - Enable `neon_test_utils` extension in the compute node docker image, so we could use it in the e2e tests (as discussed with @kelvich). - Removed test functions related to disk as we don't use them for autoscaling. - Fix the warning with printf-ing unsigned long variables. --------- Co-authored-by: Heikki Linnakangas --- Dockerfile.compute-node | 4 + pgxn/neon_test_utils/neon_test_utils--1.0.sql | 18 +++ pgxn/neon_test_utils/neon_test_utils.control | 1 + pgxn/neon_test_utils/neontest.c | 118 ++++++++++++++++++ .../sql_regress/expected/neon-test-utils.out | 28 +++++ test_runner/sql_regress/parallel_schedule | 1 + .../sql_regress/sql/neon-test-utils.sql | 11 ++ 7 files changed, 181 insertions(+) create mode 100644 test_runner/sql_regress/expected/neon-test-utils.out create mode 100644 test_runner/sql_regress/sql/neon-test-utils.sql diff --git a/Dockerfile.compute-node b/Dockerfile.compute-node index cc7a110008..4eb6dc91c0 100644 --- a/Dockerfile.compute-node +++ b/Dockerfile.compute-node @@ -820,6 +820,10 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \ PG_CONFIG=/usr/local/pgsql/bin/pg_config \ -C pgxn/neon_utils \ -s install && \ + make -j $(getconf _NPROCESSORS_ONLN) \ + PG_CONFIG=/usr/local/pgsql/bin/pg_config \ + -C pgxn/neon_test_utils \ + -s install && \ make -j $(getconf _NPROCESSORS_ONLN) \ PG_CONFIG=/usr/local/pgsql/bin/pg_config \ -C pgxn/neon_rmgr \ diff --git a/pgxn/neon_test_utils/neon_test_utils--1.0.sql b/pgxn/neon_test_utils/neon_test_utils--1.0.sql index 402981a9a6..23340e352e 100644 --- a/pgxn/neon_test_utils/neon_test_utils--1.0.sql +++ b/pgxn/neon_test_utils/neon_test_utils--1.0.sql @@ -7,6 +7,24 @@ AS 'MODULE_PATHNAME', 'test_consume_xids' LANGUAGE C STRICT PARALLEL UNSAFE; +CREATE FUNCTION test_consume_cpu(seconds int) +RETURNS VOID +AS 'MODULE_PATHNAME', 'test_consume_cpu' +LANGUAGE C STRICT +PARALLEL UNSAFE; + +CREATE FUNCTION test_consume_memory(megabytes int) +RETURNS VOID +AS 'MODULE_PATHNAME', 'test_consume_memory' +LANGUAGE C STRICT +PARALLEL UNSAFE; + +CREATE FUNCTION test_release_memory(megabytes int DEFAULT NULL) +RETURNS VOID +AS 'MODULE_PATHNAME', 'test_release_memory' +LANGUAGE C +PARALLEL UNSAFE; + CREATE FUNCTION clear_buffer_cache() RETURNS VOID AS 'MODULE_PATHNAME', 'clear_buffer_cache' diff --git a/pgxn/neon_test_utils/neon_test_utils.control b/pgxn/neon_test_utils/neon_test_utils.control index 94e6720503..5219571f11 100644 --- a/pgxn/neon_test_utils/neon_test_utils.control +++ b/pgxn/neon_test_utils/neon_test_utils.control @@ -3,3 +3,4 @@ comment = 'helpers for neon testing and debugging' default_version = '1.0' module_pathname = '$libdir/neon_test_utils' relocatable = true +trusted = true diff --git a/pgxn/neon_test_utils/neontest.c b/pgxn/neon_test_utils/neontest.c index aa644efd40..7c618848e2 100644 --- a/pgxn/neon_test_utils/neontest.c +++ b/pgxn/neon_test_utils/neontest.c @@ -21,10 +21,12 @@ #include "miscadmin.h" #include "storage/buf_internals.h" #include "storage/bufmgr.h" +#include "storage/fd.h" #include "utils/builtins.h" #include "utils/pg_lsn.h" #include "utils/rel.h" #include "utils/varlena.h" +#include "utils/wait_event.h" #include "../neon/pagestore_client.h" PG_MODULE_MAGIC; @@ -32,6 +34,9 @@ PG_MODULE_MAGIC; extern void _PG_init(void); PG_FUNCTION_INFO_V1(test_consume_xids); +PG_FUNCTION_INFO_V1(test_consume_cpu); +PG_FUNCTION_INFO_V1(test_consume_memory); +PG_FUNCTION_INFO_V1(test_release_memory); PG_FUNCTION_INFO_V1(clear_buffer_cache); PG_FUNCTION_INFO_V1(get_raw_page_at_lsn); PG_FUNCTION_INFO_V1(get_raw_page_at_lsn_ex); @@ -97,6 +102,119 @@ test_consume_xids(PG_FUNCTION_ARGS) PG_RETURN_VOID(); } + +/* + * test_consume_cpu(seconds int). Keeps one CPU busy for the given number of seconds. + */ +Datum +test_consume_cpu(PG_FUNCTION_ARGS) +{ + int32 seconds = PG_GETARG_INT32(0); + TimestampTz start; + uint64 total_iterations = 0; + + start = GetCurrentTimestamp(); + + for (;;) + { + TimestampTz elapsed; + + elapsed = GetCurrentTimestamp() - start; + if (elapsed > (TimestampTz) seconds * USECS_PER_SEC) + break; + + /* keep spinning */ + for (int i = 0; i < 1000000; i++) + total_iterations++; + elog(DEBUG2, "test_consume_cpu(): %lu iterations in total", total_iterations); + + CHECK_FOR_INTERRUPTS(); + } + + PG_RETURN_VOID(); +} + +static MemoryContext consume_cxt = NULL; +static slist_head consumed_memory_chunks; +static int64 num_memory_chunks; + +/* + * test_consume_memory(megabytes int). + * + * Consume given amount of memory. The allocation is made in TopMemoryContext, + * so it outlives the function, until you call test_release_memory to + * explicitly release it, or close the session. + */ +Datum +test_consume_memory(PG_FUNCTION_ARGS) +{ + int32 megabytes = PG_GETARG_INT32(0); + + /* + * Consume the memory in a new memory context, so that it's convenient to + * release and to display it separately in a possible memory context dump. + */ + if (consume_cxt == NULL) + consume_cxt = AllocSetContextCreate(TopMemoryContext, + "test_consume_memory", + ALLOCSET_DEFAULT_SIZES); + + for (int32 i = 0; i < megabytes; i++) + { + char *p; + + p = MemoryContextAllocZero(consume_cxt, 1024 * 1024); + + /* touch the memory, so that it's really allocated by the kernel */ + for (int j = 0; j < 1024 * 1024; j += 1024) + p[j] = j % 0xFF; + + slist_push_head(&consumed_memory_chunks, (slist_node *) p); + num_memory_chunks++; + } + + PG_RETURN_VOID(); +} + +/* + * test_release_memory(megabytes int). NULL releases all + */ +Datum +test_release_memory(PG_FUNCTION_ARGS) +{ + TimestampTz start; + + if (PG_ARGISNULL(0)) + { + if (consume_cxt) + { + MemoryContextDelete(consume_cxt); + consume_cxt = NULL; + num_memory_chunks = 0; + } + } + else + { + int32 chunks_to_release = PG_GETARG_INT32(0); + + if (chunks_to_release > num_memory_chunks) + { + elog(WARNING, "only %lu MB is consumed, releasing it all", num_memory_chunks); + chunks_to_release = num_memory_chunks; + } + + for (int32 i = 0; i < chunks_to_release; i++) + { + slist_node *chunk = slist_pop_head_node(&consumed_memory_chunks); + + pfree(chunk); + num_memory_chunks--; + } + } + + PG_RETURN_VOID(); +} + /* * Flush the buffer cache, evicting all pages that are not currently pinned. */ diff --git a/test_runner/sql_regress/expected/neon-test-utils.out b/test_runner/sql_regress/expected/neon-test-utils.out new file mode 100644 index 0000000000..7d1634a6b8 --- /dev/null +++ b/test_runner/sql_regress/expected/neon-test-utils.out @@ -0,0 +1,28 @@ +-- Test the test utils in pgxn/neon_test_utils. We don't test that +-- these actually consume resources like they should - that would be +-- tricky - but at least we check that they don't crash. +CREATE EXTENSION neon_test_utils; +select test_consume_cpu(1); + test_consume_cpu +------------------ + +(1 row) + +select test_consume_memory(20); -- Allocate 20 MB + test_consume_memory +--------------------- + +(1 row) + +select test_release_memory(5); -- Release 5 MB + test_release_memory +--------------------- + +(1 row) + +select test_release_memory(); -- Release the remaining 15 MB + test_release_memory +--------------------- + +(1 row) + diff --git a/test_runner/sql_regress/parallel_schedule b/test_runner/sql_regress/parallel_schedule index 569c7b5066..d9508d1c90 100644 --- a/test_runner/sql_regress/parallel_schedule +++ b/test_runner/sql_regress/parallel_schedule @@ -7,4 +7,5 @@ test: neon-cid test: neon-rel-truncate test: neon-clog +test: neon-test-utils test: neon-vacuum-full diff --git a/test_runner/sql_regress/sql/neon-test-utils.sql b/test_runner/sql_regress/sql/neon-test-utils.sql new file mode 100644 index 0000000000..c5ca6c624b --- /dev/null +++ b/test_runner/sql_regress/sql/neon-test-utils.sql @@ -0,0 +1,11 @@ +-- Test the test utils in pgxn/neon_test_utils. We don't test that +-- these actually consume resources like they should - that would be +-- tricky - but at least we check that they don't crash. + +CREATE EXTENSION neon_test_utils; + +select test_consume_cpu(1); + +select test_consume_memory(20); -- Allocate 20 MB +select test_release_memory(5); -- Release 5 MB +select test_release_memory(); -- Release the remaining 15 MB From 024372a3db071c945cbdd7f4cc1b759e56386534 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 14 Feb 2024 20:17:12 +0100 Subject: [PATCH 73/81] Revert "refactor(VirtualFile::crashsafe_overwrite): avoid Handle::block_on in callers" (#6765) Reverts neondatabase/neon#6731 On high tenant count Pageservers in staging, memory and CPU usage shoots to 100% with this change. (NB: staging currently has tokio-epoll-uring enabled) Will analyze tomorrow. https://neondb.slack.com/archives/C03H1K0PGKH/p1707933875639379?thread_ts=1707929541.125329&cid=C03H1K0PGKH --- libs/utils/src/crashsafe.rs | 44 +----------- pageserver/src/deletion_queue.rs | 5 +- pageserver/src/tenant.rs | 33 ++++++--- pageserver/src/tenant/metadata.rs | 2 +- pageserver/src/tenant/secondary/downloader.rs | 11 ++- pageserver/src/virtual_file.rs | 72 +++++++++++-------- 6 files changed, 78 insertions(+), 89 deletions(-) diff --git a/libs/utils/src/crashsafe.rs b/libs/utils/src/crashsafe.rs index 756b19138c..1c72e9cae9 100644 --- a/libs/utils/src/crashsafe.rs +++ b/libs/utils/src/crashsafe.rs @@ -1,7 +1,7 @@ use std::{ borrow::Cow, fs::{self, File}, - io::{self, Write}, + io, }; use camino::{Utf8Path, Utf8PathBuf}; @@ -161,48 +161,6 @@ pub async fn durable_rename( Ok(()) } -/// Writes a file to the specified `final_path` in a crash safe fasion, using [`std::fs`]. -/// -/// The file is first written to the specified `tmp_path`, and in a second -/// step, the `tmp_path` is renamed to the `final_path`. Intermediary fsync -/// and atomic rename guarantee that, if we crash at any point, there will never -/// be a partially written file at `final_path` (but maybe at `tmp_path`). -/// -/// Callers are responsible for serializing calls of this function for a given `final_path`. -/// If they don't, there may be an error due to conflicting `tmp_path`, or there will -/// be no error and the content of `final_path` will be the "winner" caller's `content`. -/// I.e., the atomticity guarantees still hold. -pub fn overwrite( - final_path: &Utf8Path, - tmp_path: &Utf8Path, - content: &[u8], -) -> std::io::Result<()> { - let Some(final_path_parent) = final_path.parent() else { - return Err(std::io::Error::from_raw_os_error( - nix::errno::Errno::EINVAL as i32, - )); - }; - std::fs::remove_file(tmp_path).or_else(crate::fs_ext::ignore_not_found)?; - let mut file = std::fs::OpenOptions::new() - .write(true) - // Use `create_new` so that, if we race with ourselves or something else, - // we bail out instead of causing damage. - .create_new(true) - .open(tmp_path)?; - file.write_all(content)?; - file.sync_all()?; - drop(file); // don't keep the fd open for longer than we have to - - std::fs::rename(tmp_path, final_path)?; - - let final_parent_dirfd = std::fs::OpenOptions::new() - .read(true) - .open(final_path_parent)?; - - final_parent_dirfd.sync_all()?; - Ok(()) -} - #[cfg(test)] mod tests { diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index f8f2866a3b..81938b14b3 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -234,7 +234,7 @@ impl DeletionHeader { let header_bytes = serde_json::to_vec(self).context("serialize deletion header")?; let header_path = conf.deletion_header_path(); let temp_path = path_with_suffix_extension(&header_path, TEMP_SUFFIX); - VirtualFile::crashsafe_overwrite(header_path, temp_path, header_bytes) + VirtualFile::crashsafe_overwrite(&header_path, &temp_path, header_bytes) .await .maybe_fatal_err("save deletion header")?; @@ -325,8 +325,7 @@ impl DeletionList { let temp_path = path_with_suffix_extension(&path, TEMP_SUFFIX); let bytes = serde_json::to_vec(self).expect("Failed to serialize deletion list"); - - VirtualFile::crashsafe_overwrite(path, temp_path, bytes) + VirtualFile::crashsafe_overwrite(&path, &temp_path, bytes) .await .maybe_fatal_err("save deletion list") .map_err(Into::into) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index dc9b8247a5..88f4ae7086 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -28,6 +28,7 @@ use remote_storage::GenericRemoteStorage; use std::fmt; use storage_broker::BrokerClientChannel; use tokio::io::BufReader; +use tokio::runtime::Handle; use tokio::sync::watch; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -2877,10 +2878,17 @@ impl Tenant { let tenant_shard_id = *tenant_shard_id; let config_path = config_path.to_owned(); - let conf_content = conf_content.into_bytes(); - VirtualFile::crashsafe_overwrite(config_path.clone(), temp_path, conf_content) - .await - .with_context(|| format!("write tenant {tenant_shard_id} config to {config_path}"))?; + tokio::task::spawn_blocking(move || { + Handle::current().block_on(async move { + let conf_content = conf_content.into_bytes(); + VirtualFile::crashsafe_overwrite(&config_path, &temp_path, conf_content) + .await + .with_context(|| { + format!("write tenant {tenant_shard_id} config to {config_path}") + }) + }) + }) + .await??; Ok(()) } @@ -2907,12 +2915,17 @@ impl Tenant { let tenant_shard_id = *tenant_shard_id; let target_config_path = target_config_path.to_owned(); - let conf_content = conf_content.into_bytes(); - VirtualFile::crashsafe_overwrite(target_config_path.clone(), temp_path, conf_content) - .await - .with_context(|| { - format!("write tenant {tenant_shard_id} config to {target_config_path}") - })?; + tokio::task::spawn_blocking(move || { + Handle::current().block_on(async move { + let conf_content = conf_content.into_bytes(); + VirtualFile::crashsafe_overwrite(&target_config_path, &temp_path, conf_content) + .await + .with_context(|| { + format!("write tenant {tenant_shard_id} config to {target_config_path}") + }) + }) + }) + .await??; Ok(()) } diff --git a/pageserver/src/tenant/metadata.rs b/pageserver/src/tenant/metadata.rs index 233acfd431..dcbe781f90 100644 --- a/pageserver/src/tenant/metadata.rs +++ b/pageserver/src/tenant/metadata.rs @@ -279,7 +279,7 @@ pub async fn save_metadata( let path = conf.metadata_path(tenant_shard_id, timeline_id); let temp_path = path_with_suffix_extension(&path, TEMP_FILE_SUFFIX); let metadata_bytes = data.to_bytes().context("serialize metadata")?; - VirtualFile::crashsafe_overwrite(path, temp_path, metadata_bytes) + VirtualFile::crashsafe_overwrite(&path, &temp_path, metadata_bytes) .await .context("write metadata")?; Ok(()) diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index c8288acc20..c23416a7f0 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -484,9 +484,14 @@ impl<'a> TenantDownloader<'a> { let temp_path = path_with_suffix_extension(&heatmap_path, TEMP_FILE_SUFFIX); let context_msg = format!("write tenant {tenant_shard_id} heatmap to {heatmap_path}"); let heatmap_path_bg = heatmap_path.clone(); - VirtualFile::crashsafe_overwrite(heatmap_path_bg, temp_path, heatmap_bytes) - .await - .maybe_fatal_err(&context_msg)?; + tokio::task::spawn_blocking(move || { + tokio::runtime::Handle::current().block_on(async move { + VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, heatmap_bytes).await + }) + }) + .await + .expect("Blocking task is never aborted") + .maybe_fatal_err(&context_msg)?; tracing::debug!("Wrote local heatmap to {}", heatmap_path); diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 858fc0ef64..45c3e19cfc 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -19,13 +19,14 @@ use once_cell::sync::OnceCell; use pageserver_api::shard::TenantShardId; use std::fs::{self, File}; use std::io::{Error, ErrorKind, Seek, SeekFrom}; -use tokio_epoll_uring::{BoundedBuf, IoBuf, IoBufMut, Slice}; +use tokio_epoll_uring::{BoundedBuf, IoBufMut, Slice}; use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; use std::os::unix::fs::FileExt; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use tokio::time::Instant; +use utils::fs_ext; pub use pageserver_api::models::virtual_file as api; pub(crate) mod io_engine; @@ -403,34 +404,47 @@ impl VirtualFile { Ok(vfile) } - /// Async version of [`::utils::crashsafe::overwrite`]. + /// Writes a file to the specified `final_path` in a crash safe fasion /// - /// # NB: - /// - /// Doesn't actually use the [`VirtualFile`] file descriptor cache, but, - /// it did at an earlier time. - /// And it will use this module's [`io_engine`] in the near future, so, leaving it here. - pub async fn crashsafe_overwrite + Send, Buf: IoBuf + Send>( - final_path: Utf8PathBuf, - tmp_path: Utf8PathBuf, + /// The file is first written to the specified tmp_path, and in a second + /// step, the tmp path is renamed to the final path. As renames are + /// atomic, a crash during the write operation will never leave behind a + /// partially written file. + pub async fn crashsafe_overwrite( + final_path: &Utf8Path, + tmp_path: &Utf8Path, content: B, ) -> std::io::Result<()> { - // TODO: use tokio_epoll_uring if configured as `io_engine`. - // See https://github.com/neondatabase/neon/issues/6663 - - tokio::task::spawn_blocking(move || { - let slice_storage; - let content_len = content.bytes_init(); - let content = if content.bytes_init() > 0 { - slice_storage = Some(content.slice(0..content_len)); - slice_storage.as_deref().expect("just set it to Some()") - } else { - &[] - }; - utils::crashsafe::overwrite(&final_path, &tmp_path, content) - }) - .await - .expect("blocking task is never aborted") + let Some(final_path_parent) = final_path.parent() else { + return Err(std::io::Error::from_raw_os_error( + nix::errno::Errno::EINVAL as i32, + )); + }; + std::fs::remove_file(tmp_path).or_else(fs_ext::ignore_not_found)?; + let mut file = Self::open_with_options( + tmp_path, + OpenOptions::new() + .write(true) + // Use `create_new` so that, if we race with ourselves or something else, + // we bail out instead of causing damage. + .create_new(true), + ) + .await?; + let (_content, res) = file.write_all(content).await; + res?; + file.sync_all().await?; + drop(file); // before the rename, that's important! + // renames are atomic + std::fs::rename(tmp_path, final_path)?; + // Only open final path parent dirfd now, so that this operation only + // ever holds one VirtualFile fd at a time. That's important because + // the current `find_victim_slot` impl might pick the same slot for both + // VirtualFile., and it eventually does a blocking write lock instead of + // try_lock. + let final_parent_dirfd = + Self::open_with_options(final_path_parent, OpenOptions::new().read(true)).await?; + final_parent_dirfd.sync_all().await?; + Ok(()) } /// Call File::sync_all() on the underlying File. @@ -1323,7 +1337,7 @@ mod tests { let path = testdir.join("myfile"); let tmp_path = testdir.join("myfile.tmp"); - VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec()) + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap()); @@ -1332,7 +1346,7 @@ mod tests { assert!(!tmp_path.exists()); drop(file); - VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec()) + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap()); @@ -1354,7 +1368,7 @@ mod tests { std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap(); assert!(tmp_path.exists()); - VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec()) + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec()) .await .unwrap(); From 80854b98ff0dad7b385c972523ac03352d10a938 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Thu, 15 Feb 2024 01:24:07 +0200 Subject: [PATCH 74/81] move timeouts and cancellation handling to remote_storage (#6697) Cancellation and timeouts are handled at remote_storage callsites, if they are. However they should always be handled, because we've had transient problems with remote storage connections. - Add cancellation token to the `trait RemoteStorage` methods - For `download*`, `list*` methods there is `DownloadError::{Cancelled,Timeout}` - For the rest now using `anyhow::Error`, it will have root cause `remote_storage::TimeoutOrCancel::{Cancel,Timeout}` - Both types have `::is_permanent` equivalent which should be passed to `backoff::retry` - New generic RemoteStorageConfig option `timeout`, defaults to 120s - Start counting timeouts only after acquiring concurrency limiter permit - Cancellable permit acquiring - Download stream timeout or cancellation is communicated via an `std::io::Error` - Exit backoff::retry by marking cancellation errors permanent Fixes: #6096 Closes: #4781 Co-authored-by: arpad-m --- Cargo.lock | 2 + libs/remote_storage/Cargo.toml | 2 + libs/remote_storage/src/azure_blob.rs | 425 +++++++++++------- libs/remote_storage/src/error.rs | 181 ++++++++ libs/remote_storage/src/lib.rs | 329 ++++++++------ libs/remote_storage/src/local_fs.rs | 420 ++++++++++++----- libs/remote_storage/src/s3_bucket.rs | 273 +++++++---- libs/remote_storage/src/simulate_failures.rs | 55 ++- libs/remote_storage/src/support.rs | 136 ++++++ libs/remote_storage/tests/common/mod.rs | 21 +- libs/remote_storage/tests/common/tests.rs | 72 ++- libs/remote_storage/tests/test_real_azure.rs | 14 +- libs/remote_storage/tests/test_real_s3.rs | 215 ++++++++- pageserver/src/config.rs | 2 + pageserver/src/deletion_queue.rs | 12 +- pageserver/src/deletion_queue/deleter.rs | 7 +- pageserver/src/tenant.rs | 8 +- pageserver/src/tenant/delete.rs | 14 +- .../src/tenant/remote_timeline_client.rs | 55 +-- .../tenant/remote_timeline_client/download.rs | 98 ++-- .../tenant/remote_timeline_client/upload.rs | 35 +- pageserver/src/tenant/secondary/downloader.rs | 5 +- .../src/tenant/secondary/heatmap_uploader.rs | 11 +- proxy/src/context/parquet.rs | 17 +- safekeeper/src/wal_backup.rs | 29 +- 25 files changed, 1712 insertions(+), 726 deletions(-) create mode 100644 libs/remote_storage/src/error.rs diff --git a/Cargo.lock b/Cargo.lock index 45a313a72b..74cd2c8d2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4436,6 +4436,7 @@ dependencies = [ "futures", "futures-util", "http-types", + "humantime", "hyper", "itertools", "metrics", @@ -4447,6 +4448,7 @@ dependencies = [ "serde_json", "test-context", "tokio", + "tokio-stream", "tokio-util", "toml_edit", "tracing", diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 2cc59a947b..15f3cd3b80 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -15,11 +15,13 @@ aws-sdk-s3.workspace = true aws-credential-types.workspace = true bytes.workspace = true camino.workspace = true +humantime.workspace = true hyper = { workspace = true, features = ["stream"] } futures.workspace = true serde.workspace = true serde_json.workspace = true tokio = { workspace = true, features = ["sync", "fs", "io-util"] } +tokio-stream.workspace = true tokio-util = { workspace = true, features = ["compat"] } toml_edit.workspace = true tracing.workspace = true diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs index df6d45dde1..12ec680cb6 100644 --- a/libs/remote_storage/src/azure_blob.rs +++ b/libs/remote_storage/src/azure_blob.rs @@ -22,16 +22,15 @@ use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerCl use bytes::Bytes; use futures::stream::Stream; use futures_util::StreamExt; +use futures_util::TryStreamExt; use http_types::{StatusCode, Url}; -use tokio::time::Instant; use tokio_util::sync::CancellationToken; use tracing::debug; -use crate::s3_bucket::RequestKind; -use crate::TimeTravelError; use crate::{ - AzureConfig, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath, - RemoteStorage, StorageMetadata, + error::Cancelled, s3_bucket::RequestKind, AzureConfig, ConcurrencyLimiter, Download, + DownloadError, Listing, ListingMode, RemotePath, RemoteStorage, StorageMetadata, + TimeTravelError, TimeoutOrCancel, }; pub struct AzureBlobStorage { @@ -39,10 +38,12 @@ pub struct AzureBlobStorage { prefix_in_container: Option, max_keys_per_list_response: Option, concurrency_limiter: ConcurrencyLimiter, + // Per-request timeout. Accessible for tests. + pub timeout: Duration, } impl AzureBlobStorage { - pub fn new(azure_config: &AzureConfig) -> Result { + pub fn new(azure_config: &AzureConfig, timeout: Duration) -> Result { debug!( "Creating azure remote storage for azure container {}", azure_config.container_name @@ -79,6 +80,7 @@ impl AzureBlobStorage { prefix_in_container: azure_config.prefix_in_container.to_owned(), max_keys_per_list_response, concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()), + timeout, }) } @@ -121,8 +123,11 @@ impl AzureBlobStorage { async fn download_for_builder( &self, builder: GetBlobBuilder, + cancel: &CancellationToken, ) -> Result { - let mut response = builder.into_stream(); + let kind = RequestKind::Get; + + let _permit = self.permit(kind, cancel).await?; let mut etag = None; let mut last_modified = None; @@ -130,39 +135,70 @@ impl AzureBlobStorage { // TODO give proper streaming response instead of buffering into RAM // https://github.com/neondatabase/neon/issues/5563 - let mut bufs = Vec::new(); - while let Some(part) = response.next().await { - let part = part.map_err(to_download_error)?; - let etag_str: &str = part.blob.properties.etag.as_ref(); - if etag.is_none() { - etag = Some(etag.unwrap_or_else(|| etag_str.to_owned())); + let download = async { + let response = builder + // convert to concrete Pageable + .into_stream() + // convert to TryStream + .into_stream() + .map_err(to_download_error); + + // apply per request timeout + let response = tokio_stream::StreamExt::timeout(response, self.timeout); + + // flatten + let response = response.map(|res| match res { + Ok(res) => res, + Err(_elapsed) => Err(DownloadError::Timeout), + }); + + let mut response = std::pin::pin!(response); + + let mut bufs = Vec::new(); + while let Some(part) = response.next().await { + let part = part?; + let etag_str: &str = part.blob.properties.etag.as_ref(); + if etag.is_none() { + etag = Some(etag.unwrap_or_else(|| etag_str.to_owned())); + } + if last_modified.is_none() { + last_modified = Some(part.blob.properties.last_modified.into()); + } + if let Some(blob_meta) = part.blob.metadata { + metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned()))); + } + let data = part + .data + .collect() + .await + .map_err(|e| DownloadError::Other(e.into()))?; + bufs.push(data); } - if last_modified.is_none() { - last_modified = Some(part.blob.properties.last_modified.into()); - } - if let Some(blob_meta) = part.blob.metadata { - metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned()))); - } - let data = part - .data - .collect() - .await - .map_err(|e| DownloadError::Other(e.into()))?; - bufs.push(data); + Ok(Download { + download_stream: Box::pin(futures::stream::iter(bufs.into_iter().map(Ok))), + etag, + last_modified, + metadata: Some(StorageMetadata(metadata)), + }) + }; + + tokio::select! { + bufs = download => bufs, + _ = cancel.cancelled() => Err(DownloadError::Cancelled), } - Ok(Download { - download_stream: Box::pin(futures::stream::iter(bufs.into_iter().map(Ok))), - etag, - last_modified, - metadata: Some(StorageMetadata(metadata)), - }) } - async fn permit(&self, kind: RequestKind) -> tokio::sync::SemaphorePermit<'_> { - self.concurrency_limiter - .acquire(kind) - .await - .expect("semaphore is never closed") + async fn permit( + &self, + kind: RequestKind, + cancel: &CancellationToken, + ) -> Result, Cancelled> { + let acquire = self.concurrency_limiter.acquire(kind); + + tokio::select! { + permit = acquire => Ok(permit.expect("never closed")), + _ = cancel.cancelled() => Err(Cancelled), + } } } @@ -192,66 +228,87 @@ impl RemoteStorage for AzureBlobStorage { prefix: Option<&RemotePath>, mode: ListingMode, max_keys: Option, + cancel: &CancellationToken, ) -> anyhow::Result { - // get the passed prefix or if it is not set use prefix_in_bucket value - let list_prefix = prefix - .map(|p| self.relative_path_to_name(p)) - .or_else(|| self.prefix_in_container.clone()) - .map(|mut p| { - // required to end with a separator - // otherwise request will return only the entry of a prefix - if matches!(mode, ListingMode::WithDelimiter) - && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) - { - p.push(REMOTE_STORAGE_PREFIX_SEPARATOR); - } - p + let _permit = self.permit(RequestKind::List, cancel).await?; + + let op = async { + // get the passed prefix or if it is not set use prefix_in_bucket value + let list_prefix = prefix + .map(|p| self.relative_path_to_name(p)) + .or_else(|| self.prefix_in_container.clone()) + .map(|mut p| { + // required to end with a separator + // otherwise request will return only the entry of a prefix + if matches!(mode, ListingMode::WithDelimiter) + && !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) + { + p.push(REMOTE_STORAGE_PREFIX_SEPARATOR); + } + p + }); + + let mut builder = self.client.list_blobs(); + + if let ListingMode::WithDelimiter = mode { + builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string()); + } + + if let Some(prefix) = list_prefix { + builder = builder.prefix(Cow::from(prefix.to_owned())); + } + + if let Some(limit) = self.max_keys_per_list_response { + builder = builder.max_results(MaxResults::new(limit)); + } + + let response = builder.into_stream(); + let response = response.into_stream().map_err(to_download_error); + let response = tokio_stream::StreamExt::timeout(response, self.timeout); + let response = response.map(|res| match res { + Ok(res) => res, + Err(_elapsed) => Err(DownloadError::Timeout), }); - let mut builder = self.client.list_blobs(); + let mut response = std::pin::pin!(response); - if let ListingMode::WithDelimiter = mode { - builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string()); - } + let mut res = Listing::default(); - if let Some(prefix) = list_prefix { - builder = builder.prefix(Cow::from(prefix.to_owned())); - } + let mut max_keys = max_keys.map(|mk| mk.get()); + while let Some(entry) = response.next().await { + let entry = entry?; + let prefix_iter = entry + .blobs + .prefixes() + .map(|prefix| self.name_to_relative_path(&prefix.name)); + res.prefixes.extend(prefix_iter); - if let Some(limit) = self.max_keys_per_list_response { - builder = builder.max_results(MaxResults::new(limit)); - } + let blob_iter = entry + .blobs + .blobs() + .map(|k| self.name_to_relative_path(&k.name)); - let mut response = builder.into_stream(); - let mut res = Listing::default(); - // NonZeroU32 doesn't support subtraction apparently - let mut max_keys = max_keys.map(|mk| mk.get()); - while let Some(l) = response.next().await { - let entry = l.map_err(to_download_error)?; - let prefix_iter = entry - .blobs - .prefixes() - .map(|prefix| self.name_to_relative_path(&prefix.name)); - res.prefixes.extend(prefix_iter); + for key in blob_iter { + res.keys.push(key); - let blob_iter = entry - .blobs - .blobs() - .map(|k| self.name_to_relative_path(&k.name)); - - for key in blob_iter { - res.keys.push(key); - if let Some(mut mk) = max_keys { - assert!(mk > 0); - mk -= 1; - if mk == 0 { - return Ok(res); // limit reached + if let Some(mut mk) = max_keys { + assert!(mk > 0); + mk -= 1; + if mk == 0 { + return Ok(res); // limit reached + } + max_keys = Some(mk); } - max_keys = Some(mk); } } + + Ok(res) + }; + + tokio::select! { + res = op => res, + _ = cancel.cancelled() => Err(DownloadError::Cancelled), } - Ok(res) } async fn upload( @@ -260,35 +317,52 @@ impl RemoteStorage for AzureBlobStorage { data_size_bytes: usize, to: &RemotePath, metadata: Option, + cancel: &CancellationToken, ) -> anyhow::Result<()> { - let _permit = self.permit(RequestKind::Put).await; - let blob_client = self.client.blob_client(self.relative_path_to_name(to)); + let _permit = self.permit(RequestKind::Put, cancel).await?; - let from: Pin> + Send + Sync + 'static>> = - Box::pin(from); + let op = async { + let blob_client = self.client.blob_client(self.relative_path_to_name(to)); - let from = NonSeekableStream::new(from, data_size_bytes); + let from: Pin> + Send + Sync + 'static>> = + Box::pin(from); - let body = azure_core::Body::SeekableStream(Box::new(from)); + let from = NonSeekableStream::new(from, data_size_bytes); - let mut builder = blob_client.put_block_blob(body); + let body = azure_core::Body::SeekableStream(Box::new(from)); - if let Some(metadata) = metadata { - builder = builder.metadata(to_azure_metadata(metadata)); + let mut builder = blob_client.put_block_blob(body); + + if let Some(metadata) = metadata { + builder = builder.metadata(to_azure_metadata(metadata)); + } + + let fut = builder.into_future(); + let fut = tokio::time::timeout(self.timeout, fut); + + match fut.await { + Ok(Ok(_response)) => Ok(()), + Ok(Err(azure)) => Err(azure.into()), + Err(_timeout) => Err(TimeoutOrCancel::Cancel.into()), + } + }; + + tokio::select! { + res = op => res, + _ = cancel.cancelled() => Err(TimeoutOrCancel::Cancel.into()), } - - let _response = builder.into_future().await?; - - Ok(()) } - async fn download(&self, from: &RemotePath) -> Result { - let _permit = self.permit(RequestKind::Get).await; + async fn download( + &self, + from: &RemotePath, + cancel: &CancellationToken, + ) -> Result { let blob_client = self.client.blob_client(self.relative_path_to_name(from)); let builder = blob_client.get(); - self.download_for_builder(builder).await + self.download_for_builder(builder, cancel).await } async fn download_byte_range( @@ -296,8 +370,8 @@ impl RemoteStorage for AzureBlobStorage { from: &RemotePath, start_inclusive: u64, end_exclusive: Option, + cancel: &CancellationToken, ) -> Result { - let _permit = self.permit(RequestKind::Get).await; let blob_client = self.client.blob_client(self.relative_path_to_name(from)); let mut builder = blob_client.get(); @@ -309,82 +383,113 @@ impl RemoteStorage for AzureBlobStorage { }; builder = builder.range(range); - self.download_for_builder(builder).await + self.download_for_builder(builder, cancel).await } - async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> { - let _permit = self.permit(RequestKind::Delete).await; - let blob_client = self.client.blob_client(self.relative_path_to_name(path)); + async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> { + self.delete_objects(std::array::from_ref(path), cancel) + .await + } - let builder = blob_client.delete(); + async fn delete_objects<'a>( + &self, + paths: &'a [RemotePath], + cancel: &CancellationToken, + ) -> anyhow::Result<()> { + let _permit = self.permit(RequestKind::Delete, cancel).await?; - match builder.into_future().await { - Ok(_response) => Ok(()), - Err(e) => { - if let Some(http_err) = e.as_http_error() { - if http_err.status() == StatusCode::NotFound { - return Ok(()); + let op = async { + // TODO batch requests are also not supported by the SDK + // https://github.com/Azure/azure-sdk-for-rust/issues/1068 + // https://github.com/Azure/azure-sdk-for-rust/issues/1249 + for path in paths { + let blob_client = self.client.blob_client(self.relative_path_to_name(path)); + + let request = blob_client.delete().into_future(); + + let res = tokio::time::timeout(self.timeout, request).await; + + match res { + Ok(Ok(_response)) => continue, + Ok(Err(e)) => { + if let Some(http_err) = e.as_http_error() { + if http_err.status() == StatusCode::NotFound { + continue; + } + } + return Err(e.into()); } + Err(_elapsed) => return Err(TimeoutOrCancel::Timeout.into()), } - Err(anyhow::Error::new(e)) } + + Ok(()) + }; + + tokio::select! { + res = op => res, + _ = cancel.cancelled() => Err(TimeoutOrCancel::Cancel.into()), } } - async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> { - // Permit is already obtained by inner delete function + async fn copy( + &self, + from: &RemotePath, + to: &RemotePath, + cancel: &CancellationToken, + ) -> anyhow::Result<()> { + let _permit = self.permit(RequestKind::Copy, cancel).await?; - // TODO batch requests are also not supported by the SDK - // https://github.com/Azure/azure-sdk-for-rust/issues/1068 - // https://github.com/Azure/azure-sdk-for-rust/issues/1249 - for path in paths { - self.delete(path).await?; - } - Ok(()) - } + let timeout = tokio::time::sleep(self.timeout); - async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> { - let _permit = self.permit(RequestKind::Copy).await; - let blob_client = self.client.blob_client(self.relative_path_to_name(to)); + let mut copy_status = None; - let source_url = format!( - "{}/{}", - self.client.url()?, - self.relative_path_to_name(from) - ); - let builder = blob_client.copy(Url::from_str(&source_url)?); + let op = async { + let blob_client = self.client.blob_client(self.relative_path_to_name(to)); - let result = builder.into_future().await?; + let source_url = format!( + "{}/{}", + self.client.url()?, + self.relative_path_to_name(from) + ); - let mut copy_status = result.copy_status; - let start_time = Instant::now(); - const MAX_WAIT_TIME: Duration = Duration::from_secs(60); - loop { - match copy_status { - CopyStatus::Aborted => { - anyhow::bail!("Received abort for copy from {from} to {to}."); + let builder = blob_client.copy(Url::from_str(&source_url)?); + let copy = builder.into_future(); + + let result = copy.await?; + + copy_status = Some(result.copy_status); + loop { + match copy_status.as_ref().expect("we always set it to Some") { + CopyStatus::Aborted => { + anyhow::bail!("Received abort for copy from {from} to {to}."); + } + CopyStatus::Failed => { + anyhow::bail!("Received failure response for copy from {from} to {to}."); + } + CopyStatus::Success => return Ok(()), + CopyStatus::Pending => (), } - CopyStatus::Failed => { - anyhow::bail!("Received failure response for copy from {from} to {to}."); - } - CopyStatus::Success => return Ok(()), - CopyStatus::Pending => (), + // The copy is taking longer. Waiting a second and then re-trying. + // TODO estimate time based on copy_progress and adjust time based on that + tokio::time::sleep(Duration::from_millis(1000)).await; + let properties = blob_client.get_properties().into_future().await?; + let Some(status) = properties.blob.properties.copy_status else { + tracing::warn!("copy_status for copy is None!, from={from}, to={to}"); + return Ok(()); + }; + copy_status = Some(status); } - // The copy is taking longer. Waiting a second and then re-trying. - // TODO estimate time based on copy_progress and adjust time based on that - tokio::time::sleep(Duration::from_millis(1000)).await; - let properties = blob_client.get_properties().into_future().await?; - let Some(status) = properties.blob.properties.copy_status else { - tracing::warn!("copy_status for copy is None!, from={from}, to={to}"); - return Ok(()); - }; - if start_time.elapsed() > MAX_WAIT_TIME { - anyhow::bail!("Copy from from {from} to {to} took longer than limit MAX_WAIT_TIME={}s. copy_pogress={:?}.", - MAX_WAIT_TIME.as_secs_f32(), - properties.blob.properties.copy_progress, - ); - } - copy_status = status; + }; + + tokio::select! { + res = op => res, + _ = cancel.cancelled() => Err(anyhow::Error::new(TimeoutOrCancel::Cancel)), + _ = timeout => { + let e = anyhow::Error::new(TimeoutOrCancel::Timeout); + let e = e.context(format!("Timeout, last status: {copy_status:?}")); + Err(e) + }, } } diff --git a/libs/remote_storage/src/error.rs b/libs/remote_storage/src/error.rs new file mode 100644 index 0000000000..96f044e087 --- /dev/null +++ b/libs/remote_storage/src/error.rs @@ -0,0 +1,181 @@ +/// Reasons for downloads or listings to fail. +#[derive(Debug)] +pub enum DownloadError { + /// Validation or other error happened due to user input. + BadInput(anyhow::Error), + /// The file was not found in the remote storage. + NotFound, + /// A cancellation token aborted the download, typically during + /// tenant detach or process shutdown. + Cancelled, + /// A timeout happened while executing the request. Possible reasons: + /// - stuck tcp connection + /// + /// Concurrency control is not timed within timeout. + Timeout, + /// The file was found in the remote storage, but the download failed. + Other(anyhow::Error), +} + +impl std::fmt::Display for DownloadError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DownloadError::BadInput(e) => { + write!(f, "Failed to download a remote file due to user input: {e}") + } + DownloadError::NotFound => write!(f, "No file found for the remote object id given"), + DownloadError::Cancelled => write!(f, "Cancelled, shutting down"), + DownloadError::Timeout => write!(f, "timeout"), + DownloadError::Other(e) => write!(f, "Failed to download a remote file: {e:?}"), + } + } +} + +impl std::error::Error for DownloadError {} + +impl DownloadError { + /// Returns true if the error should not be retried with backoff + pub fn is_permanent(&self) -> bool { + use DownloadError::*; + match self { + BadInput(_) | NotFound | Cancelled => true, + Timeout | Other(_) => false, + } + } +} + +#[derive(Debug)] +pub enum TimeTravelError { + /// Validation or other error happened due to user input. + BadInput(anyhow::Error), + /// The used remote storage does not have time travel recovery implemented + Unimplemented, + /// The number of versions/deletion markers is above our limit. + TooManyVersions, + /// A cancellation token aborted the process, typically during + /// request closure or process shutdown. + Cancelled, + /// Other errors + Other(anyhow::Error), +} + +impl std::fmt::Display for TimeTravelError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TimeTravelError::BadInput(e) => { + write!( + f, + "Failed to time travel recover a prefix due to user input: {e}" + ) + } + TimeTravelError::Unimplemented => write!( + f, + "time travel recovery is not implemented for the current storage backend" + ), + TimeTravelError::Cancelled => write!(f, "Cancelled, shutting down"), + TimeTravelError::TooManyVersions => { + write!(f, "Number of versions/delete markers above limit") + } + TimeTravelError::Other(e) => write!(f, "Failed to time travel recover a prefix: {e:?}"), + } + } +} + +impl std::error::Error for TimeTravelError {} + +/// Plain cancelled error. +/// +/// By design this type does not not implement `std::error::Error` so it cannot be put as the root +/// cause of `std::io::Error` or `anyhow::Error`. It should never need to be exposed out of this +/// crate. +/// +/// It exists to implement permit acquiring in `{Download,TimeTravel}Error` and `anyhow::Error` returning +/// operations and ensuring that those get converted to proper versions with just `?`. +#[derive(Debug)] +pub(crate) struct Cancelled; + +impl From for anyhow::Error { + fn from(_: Cancelled) -> Self { + anyhow::Error::new(TimeoutOrCancel::Cancel) + } +} + +impl From for TimeTravelError { + fn from(_: Cancelled) -> Self { + TimeTravelError::Cancelled + } +} + +impl From for TimeoutOrCancel { + fn from(_: Cancelled) -> Self { + TimeoutOrCancel::Cancel + } +} + +impl From for DownloadError { + fn from(_: Cancelled) -> Self { + DownloadError::Cancelled + } +} + +/// This type is used at as the root cause for timeouts and cancellations with `anyhow::Error` returning +/// RemoteStorage methods. +/// +/// For use with `utils::backoff::retry` and `anyhow::Error` returning operations there is +/// `TimeoutOrCancel::caused_by_cancel` method to query "proper form" errors. +#[derive(Debug)] +pub enum TimeoutOrCancel { + Timeout, + Cancel, +} + +impl std::fmt::Display for TimeoutOrCancel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use TimeoutOrCancel::*; + match self { + Timeout => write!(f, "timeout"), + Cancel => write!(f, "cancel"), + } + } +} + +impl std::error::Error for TimeoutOrCancel {} + +impl TimeoutOrCancel { + pub fn caused(error: &anyhow::Error) -> Option<&Self> { + error.root_cause().downcast_ref() + } + + /// Returns true if the error was caused by [`TimeoutOrCancel::Cancel`]. + pub fn caused_by_cancel(error: &anyhow::Error) -> bool { + Self::caused(error).is_some_and(Self::is_cancel) + } + + pub fn is_cancel(&self) -> bool { + matches!(self, TimeoutOrCancel::Cancel) + } + + pub fn is_timeout(&self) -> bool { + matches!(self, TimeoutOrCancel::Timeout) + } +} + +/// This conversion is used when [`crate::support::DownloadStream`] notices a cancellation or +/// timeout to wrap it in an `std::io::Error`. +impl From for std::io::Error { + fn from(value: TimeoutOrCancel) -> Self { + let e = DownloadError::from(value); + std::io::Error::other(e) + } +} + +impl From for DownloadError { + fn from(value: TimeoutOrCancel) -> Self { + use TimeoutOrCancel::*; + + match value { + Timeout => DownloadError::Timeout, + Cancel => DownloadError::Cancelled, + } + } +} diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 5a0b74e406..b0b69f9155 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -10,6 +10,7 @@ #![deny(clippy::undocumented_unsafe_blocks)] mod azure_blob; +mod error; mod local_fs; mod s3_bucket; mod simulate_failures; @@ -21,7 +22,7 @@ use std::{ num::{NonZeroU32, NonZeroUsize}, pin::Pin, sync::Arc, - time::SystemTime, + time::{Duration, SystemTime}, }; use anyhow::{bail, Context}; @@ -41,6 +42,8 @@ pub use self::{ }; use s3_bucket::RequestKind; +pub use error::{DownloadError, TimeTravelError, TimeoutOrCancel}; + /// Currently, sync happens with AWS S3, that has two limits on requests per second: /// ~200 RPS for IAM services /// @@ -158,9 +161,10 @@ pub trait RemoteStorage: Send + Sync + 'static { async fn list_prefixes( &self, prefix: Option<&RemotePath>, + cancel: &CancellationToken, ) -> Result, DownloadError> { let result = self - .list(prefix, ListingMode::WithDelimiter, None) + .list(prefix, ListingMode::WithDelimiter, None, cancel) .await? .prefixes; Ok(result) @@ -182,9 +186,10 @@ pub trait RemoteStorage: Send + Sync + 'static { &self, prefix: Option<&RemotePath>, max_keys: Option, + cancel: &CancellationToken, ) -> Result, DownloadError> { let result = self - .list(prefix, ListingMode::NoDelimiter, max_keys) + .list(prefix, ListingMode::NoDelimiter, max_keys, cancel) .await? .keys; Ok(result) @@ -195,9 +200,13 @@ pub trait RemoteStorage: Send + Sync + 'static { prefix: Option<&RemotePath>, _mode: ListingMode, max_keys: Option, + cancel: &CancellationToken, ) -> Result; /// Streams the local file contents into remote into the remote storage entry. + /// + /// If the operation fails because of timeout or cancellation, the root cause of the error will be + /// set to `TimeoutOrCancel`. async fn upload( &self, from: impl Stream> + Send + Sync + 'static, @@ -206,27 +215,61 @@ pub trait RemoteStorage: Send + Sync + 'static { data_size_bytes: usize, to: &RemotePath, metadata: Option, + cancel: &CancellationToken, ) -> anyhow::Result<()>; - /// Streams the remote storage entry contents into the buffered writer given, returns the filled writer. + /// Streams the remote storage entry contents. + /// + /// The returned download stream will obey initial timeout and cancellation signal by erroring + /// on whichever happens first. Only one of the reasons will fail the stream, which is usually + /// enough for `tokio::io::copy_buf` usage. If needed the error can be filtered out. + /// /// Returns the metadata, if any was stored with the file previously. - async fn download(&self, from: &RemotePath) -> Result; + async fn download( + &self, + from: &RemotePath, + cancel: &CancellationToken, + ) -> Result; - /// Streams a given byte range of the remote storage entry contents into the buffered writer given, returns the filled writer. + /// Streams a given byte range of the remote storage entry contents. + /// + /// The returned download stream will obey initial timeout and cancellation signal by erroring + /// on whichever happens first. Only one of the reasons will fail the stream, which is usually + /// enough for `tokio::io::copy_buf` usage. If needed the error can be filtered out. + /// /// Returns the metadata, if any was stored with the file previously. async fn download_byte_range( &self, from: &RemotePath, start_inclusive: u64, end_exclusive: Option, + cancel: &CancellationToken, ) -> Result; - async fn delete(&self, path: &RemotePath) -> anyhow::Result<()>; + /// Delete a single path from remote storage. + /// + /// If the operation fails because of timeout or cancellation, the root cause of the error will be + /// set to `TimeoutOrCancel`. In such situation it is unknown if the deletion went through. + async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()>; - async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()>; + /// Delete a multiple paths from remote storage. + /// + /// If the operation fails because of timeout or cancellation, the root cause of the error will be + /// set to `TimeoutOrCancel`. In such situation it is unknown which deletions, if any, went + /// through. + async fn delete_objects<'a>( + &self, + paths: &'a [RemotePath], + cancel: &CancellationToken, + ) -> anyhow::Result<()>; /// Copy a remote object inside a bucket from one path to another. - async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()>; + async fn copy( + &self, + from: &RemotePath, + to: &RemotePath, + cancel: &CancellationToken, + ) -> anyhow::Result<()>; /// Resets the content of everything with the given prefix to the given state async fn time_travel_recover( @@ -238,7 +281,13 @@ pub trait RemoteStorage: Send + Sync + 'static { ) -> Result<(), TimeTravelError>; } -pub type DownloadStream = Pin> + Unpin + Send + Sync>>; +/// DownloadStream is sensitive to the timeout and cancellation used with the original +/// [`RemoteStorage::download`] request. The type yields `std::io::Result` to be compatible +/// with `tokio::io::copy_buf`. +// This has 'static because safekeepers do not use cancellation tokens (yet) +pub type DownloadStream = + Pin> + Send + Sync + 'static>>; + pub struct Download { pub download_stream: DownloadStream, /// The last time the file was modified (`last-modified` HTTP header) @@ -257,86 +306,6 @@ impl Debug for Download { } } -#[derive(Debug)] -pub enum DownloadError { - /// Validation or other error happened due to user input. - BadInput(anyhow::Error), - /// The file was not found in the remote storage. - NotFound, - /// A cancellation token aborted the download, typically during - /// tenant detach or process shutdown. - Cancelled, - /// The file was found in the remote storage, but the download failed. - Other(anyhow::Error), -} - -impl std::fmt::Display for DownloadError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - DownloadError::BadInput(e) => { - write!(f, "Failed to download a remote file due to user input: {e}") - } - DownloadError::Cancelled => write!(f, "Cancelled, shutting down"), - DownloadError::NotFound => write!(f, "No file found for the remote object id given"), - DownloadError::Other(e) => write!(f, "Failed to download a remote file: {e:?}"), - } - } -} - -impl std::error::Error for DownloadError {} - -impl DownloadError { - /// Returns true if the error should not be retried with backoff - pub fn is_permanent(&self) -> bool { - use DownloadError::*; - match self { - BadInput(_) => true, - NotFound => true, - Cancelled => true, - Other(_) => false, - } - } -} - -#[derive(Debug)] -pub enum TimeTravelError { - /// Validation or other error happened due to user input. - BadInput(anyhow::Error), - /// The used remote storage does not have time travel recovery implemented - Unimplemented, - /// The number of versions/deletion markers is above our limit. - TooManyVersions, - /// A cancellation token aborted the process, typically during - /// request closure or process shutdown. - Cancelled, - /// Other errors - Other(anyhow::Error), -} - -impl std::fmt::Display for TimeTravelError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TimeTravelError::BadInput(e) => { - write!( - f, - "Failed to time travel recover a prefix due to user input: {e}" - ) - } - TimeTravelError::Unimplemented => write!( - f, - "time travel recovery is not implemented for the current storage backend" - ), - TimeTravelError::Cancelled => write!(f, "Cancelled, shutting down"), - TimeTravelError::TooManyVersions => { - write!(f, "Number of versions/delete markers above limit") - } - TimeTravelError::Other(e) => write!(f, "Failed to time travel recover a prefix: {e:?}"), - } - } -} - -impl std::error::Error for TimeTravelError {} - /// Every storage, currently supported. /// Serves as a simple way to pass around the [`RemoteStorage`] without dealing with generics. #[derive(Clone)] @@ -354,12 +323,13 @@ impl GenericRemoteStorage> { prefix: Option<&RemotePath>, mode: ListingMode, max_keys: Option, + cancel: &CancellationToken, ) -> anyhow::Result { match self { - Self::LocalFs(s) => s.list(prefix, mode, max_keys).await, - Self::AwsS3(s) => s.list(prefix, mode, max_keys).await, - Self::AzureBlob(s) => s.list(prefix, mode, max_keys).await, - Self::Unreliable(s) => s.list(prefix, mode, max_keys).await, + Self::LocalFs(s) => s.list(prefix, mode, max_keys, cancel).await, + Self::AwsS3(s) => s.list(prefix, mode, max_keys, cancel).await, + Self::AzureBlob(s) => s.list(prefix, mode, max_keys, cancel).await, + Self::Unreliable(s) => s.list(prefix, mode, max_keys, cancel).await, } } @@ -372,12 +342,13 @@ impl GenericRemoteStorage> { &self, folder: Option<&RemotePath>, max_keys: Option, + cancel: &CancellationToken, ) -> Result, DownloadError> { match self { - Self::LocalFs(s) => s.list_files(folder, max_keys).await, - Self::AwsS3(s) => s.list_files(folder, max_keys).await, - Self::AzureBlob(s) => s.list_files(folder, max_keys).await, - Self::Unreliable(s) => s.list_files(folder, max_keys).await, + Self::LocalFs(s) => s.list_files(folder, max_keys, cancel).await, + Self::AwsS3(s) => s.list_files(folder, max_keys, cancel).await, + Self::AzureBlob(s) => s.list_files(folder, max_keys, cancel).await, + Self::Unreliable(s) => s.list_files(folder, max_keys, cancel).await, } } @@ -387,36 +358,43 @@ impl GenericRemoteStorage> { pub async fn list_prefixes( &self, prefix: Option<&RemotePath>, + cancel: &CancellationToken, ) -> Result, DownloadError> { match self { - Self::LocalFs(s) => s.list_prefixes(prefix).await, - Self::AwsS3(s) => s.list_prefixes(prefix).await, - Self::AzureBlob(s) => s.list_prefixes(prefix).await, - Self::Unreliable(s) => s.list_prefixes(prefix).await, + Self::LocalFs(s) => s.list_prefixes(prefix, cancel).await, + Self::AwsS3(s) => s.list_prefixes(prefix, cancel).await, + Self::AzureBlob(s) => s.list_prefixes(prefix, cancel).await, + Self::Unreliable(s) => s.list_prefixes(prefix, cancel).await, } } + /// See [`RemoteStorage::upload`] pub async fn upload( &self, from: impl Stream> + Send + Sync + 'static, data_size_bytes: usize, to: &RemotePath, metadata: Option, + cancel: &CancellationToken, ) -> anyhow::Result<()> { match self { - Self::LocalFs(s) => s.upload(from, data_size_bytes, to, metadata).await, - Self::AwsS3(s) => s.upload(from, data_size_bytes, to, metadata).await, - Self::AzureBlob(s) => s.upload(from, data_size_bytes, to, metadata).await, - Self::Unreliable(s) => s.upload(from, data_size_bytes, to, metadata).await, + Self::LocalFs(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await, + Self::AwsS3(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await, + Self::AzureBlob(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await, + Self::Unreliable(s) => s.upload(from, data_size_bytes, to, metadata, cancel).await, } } - pub async fn download(&self, from: &RemotePath) -> Result { + pub async fn download( + &self, + from: &RemotePath, + cancel: &CancellationToken, + ) -> Result { match self { - Self::LocalFs(s) => s.download(from).await, - Self::AwsS3(s) => s.download(from).await, - Self::AzureBlob(s) => s.download(from).await, - Self::Unreliable(s) => s.download(from).await, + Self::LocalFs(s) => s.download(from, cancel).await, + Self::AwsS3(s) => s.download(from, cancel).await, + Self::AzureBlob(s) => s.download(from, cancel).await, + Self::Unreliable(s) => s.download(from, cancel).await, } } @@ -425,54 +403,72 @@ impl GenericRemoteStorage> { from: &RemotePath, start_inclusive: u64, end_exclusive: Option, + cancel: &CancellationToken, ) -> Result { match self { Self::LocalFs(s) => { - s.download_byte_range(from, start_inclusive, end_exclusive) + s.download_byte_range(from, start_inclusive, end_exclusive, cancel) .await } Self::AwsS3(s) => { - s.download_byte_range(from, start_inclusive, end_exclusive) + s.download_byte_range(from, start_inclusive, end_exclusive, cancel) .await } Self::AzureBlob(s) => { - s.download_byte_range(from, start_inclusive, end_exclusive) + s.download_byte_range(from, start_inclusive, end_exclusive, cancel) .await } Self::Unreliable(s) => { - s.download_byte_range(from, start_inclusive, end_exclusive) + s.download_byte_range(from, start_inclusive, end_exclusive, cancel) .await } } } - pub async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> { + /// See [`RemoteStorage::delete`] + pub async fn delete( + &self, + path: &RemotePath, + cancel: &CancellationToken, + ) -> anyhow::Result<()> { match self { - Self::LocalFs(s) => s.delete(path).await, - Self::AwsS3(s) => s.delete(path).await, - Self::AzureBlob(s) => s.delete(path).await, - Self::Unreliable(s) => s.delete(path).await, + Self::LocalFs(s) => s.delete(path, cancel).await, + Self::AwsS3(s) => s.delete(path, cancel).await, + Self::AzureBlob(s) => s.delete(path, cancel).await, + Self::Unreliable(s) => s.delete(path, cancel).await, } } - pub async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> { + /// See [`RemoteStorage::delete_objects`] + pub async fn delete_objects( + &self, + paths: &[RemotePath], + cancel: &CancellationToken, + ) -> anyhow::Result<()> { match self { - Self::LocalFs(s) => s.delete_objects(paths).await, - Self::AwsS3(s) => s.delete_objects(paths).await, - Self::AzureBlob(s) => s.delete_objects(paths).await, - Self::Unreliable(s) => s.delete_objects(paths).await, + Self::LocalFs(s) => s.delete_objects(paths, cancel).await, + Self::AwsS3(s) => s.delete_objects(paths, cancel).await, + Self::AzureBlob(s) => s.delete_objects(paths, cancel).await, + Self::Unreliable(s) => s.delete_objects(paths, cancel).await, } } - pub async fn copy_object(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> { + /// See [`RemoteStorage::copy`] + pub async fn copy_object( + &self, + from: &RemotePath, + to: &RemotePath, + cancel: &CancellationToken, + ) -> anyhow::Result<()> { match self { - Self::LocalFs(s) => s.copy(from, to).await, - Self::AwsS3(s) => s.copy(from, to).await, - Self::AzureBlob(s) => s.copy(from, to).await, - Self::Unreliable(s) => s.copy(from, to).await, + Self::LocalFs(s) => s.copy(from, to, cancel).await, + Self::AwsS3(s) => s.copy(from, to, cancel).await, + Self::AzureBlob(s) => s.copy(from, to, cancel).await, + Self::Unreliable(s) => s.copy(from, to, cancel).await, } } + /// See [`RemoteStorage::time_travel_recover`]. pub async fn time_travel_recover( &self, prefix: Option<&RemotePath>, @@ -503,10 +499,11 @@ impl GenericRemoteStorage> { impl GenericRemoteStorage { pub fn from_config(storage_config: &RemoteStorageConfig) -> anyhow::Result { + let timeout = storage_config.timeout; Ok(match &storage_config.storage { - RemoteStorageKind::LocalFs(root) => { - info!("Using fs root '{root}' as a remote storage"); - Self::LocalFs(LocalFs::new(root.clone())?) + RemoteStorageKind::LocalFs(path) => { + info!("Using fs root '{path}' as a remote storage"); + Self::LocalFs(LocalFs::new(path.clone(), timeout)?) } RemoteStorageKind::AwsS3(s3_config) => { // The profile and access key id are only printed here for debugging purposes, @@ -516,12 +513,12 @@ impl GenericRemoteStorage { std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_else(|_| "".into()); info!("Using s3 bucket '{}' in region '{}' as a remote storage, prefix in bucket: '{:?}', bucket endpoint: '{:?}', profile: {profile}, access_key_id: {access_key_id}", s3_config.bucket_name, s3_config.bucket_region, s3_config.prefix_in_bucket, s3_config.endpoint); - Self::AwsS3(Arc::new(S3Bucket::new(s3_config)?)) + Self::AwsS3(Arc::new(S3Bucket::new(s3_config, timeout)?)) } RemoteStorageKind::AzureContainer(azure_config) => { info!("Using azure container '{}' in region '{}' as a remote storage, prefix in container: '{:?}'", azure_config.container_name, azure_config.container_region, azure_config.prefix_in_container); - Self::AzureBlob(Arc::new(AzureBlobStorage::new(azure_config)?)) + Self::AzureBlob(Arc::new(AzureBlobStorage::new(azure_config, timeout)?)) } }) } @@ -530,18 +527,15 @@ impl GenericRemoteStorage { Self::Unreliable(Arc::new(UnreliableWrapper::new(s, fail_first))) } - /// Takes storage object contents and its size and uploads to remote storage, - /// mapping `from_path` to the corresponding remote object id in the storage. - /// - /// The storage object does not have to be present on the `from_path`, - /// this path is used for the remote object id conversion only. + /// See [`RemoteStorage::upload`], which this method calls with `None` as metadata. pub async fn upload_storage_object( &self, from: impl Stream> + Send + Sync + 'static, from_size_bytes: usize, to: &RemotePath, + cancel: &CancellationToken, ) -> anyhow::Result<()> { - self.upload(from, from_size_bytes, to, None) + self.upload(from, from_size_bytes, to, None, cancel) .await .with_context(|| { format!("Failed to upload data of length {from_size_bytes} to storage path {to:?}") @@ -554,10 +548,11 @@ impl GenericRemoteStorage { &self, byte_range: Option<(u64, Option)>, from: &RemotePath, + cancel: &CancellationToken, ) -> Result { match byte_range { - Some((start, end)) => self.download_byte_range(from, start, end).await, - None => self.download(from).await, + Some((start, end)) => self.download_byte_range(from, start, end, cancel).await, + None => self.download(from, cancel).await, } } } @@ -572,6 +567,9 @@ pub struct StorageMetadata(HashMap); pub struct RemoteStorageConfig { /// The storage connection configuration. pub storage: RemoteStorageKind, + /// A common timeout enforced for all requests after concurrency limiter permit has been + /// acquired. + pub timeout: Duration, } /// A kind of a remote storage to connect to, with its connection configuration. @@ -656,6 +654,8 @@ impl Debug for AzureConfig { } impl RemoteStorageConfig { + pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120); + pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result> { let local_path = toml.get("local_path"); let bucket_name = toml.get("bucket_name"); @@ -685,6 +685,27 @@ impl RemoteStorageConfig { .map(|endpoint| parse_toml_string("endpoint", endpoint)) .transpose()?; + let timeout = toml + .get("timeout") + .map(|timeout| { + timeout + .as_str() + .ok_or_else(|| anyhow::Error::msg("timeout was not a string")) + }) + .transpose() + .and_then(|timeout| { + timeout + .map(humantime::parse_duration) + .transpose() + .map_err(anyhow::Error::new) + }) + .context("parse timeout")? + .unwrap_or(Self::DEFAULT_TIMEOUT); + + if timeout < Duration::from_secs(1) { + bail!("timeout was specified as {timeout:?} which is too low"); + } + let storage = match ( local_path, bucket_name, @@ -746,7 +767,7 @@ impl RemoteStorageConfig { } }; - Ok(Some(RemoteStorageConfig { storage })) + Ok(Some(RemoteStorageConfig { storage, timeout })) } } @@ -842,4 +863,24 @@ mod tests { let err = RemotePath::new(Utf8Path::new("/")).expect_err("Should fail on absolute paths"); assert_eq!(err.to_string(), "Path \"/\" is not relative"); } + + #[test] + fn parse_localfs_config_with_timeout() { + let input = "local_path = '.' +timeout = '5s'"; + + let toml = input.parse::().unwrap(); + + let config = RemoteStorageConfig::from_toml(toml.as_item()) + .unwrap() + .expect("it exists"); + + assert_eq!( + config, + RemoteStorageConfig { + storage: RemoteStorageKind::LocalFs(Utf8PathBuf::from(".")), + timeout: Duration::from_secs(5) + } + ); + } } diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index e88111e8e2..6f847cf9d7 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -5,7 +5,12 @@ //! volume is mounted to the local FS. use std::{ - borrow::Cow, future::Future, io::ErrorKind, num::NonZeroU32, pin::Pin, time::SystemTime, + borrow::Cow, + future::Future, + io::ErrorKind, + num::NonZeroU32, + pin::Pin, + time::{Duration, SystemTime}, }; use anyhow::{bail, ensure, Context}; @@ -20,7 +25,9 @@ use tokio_util::{io::ReaderStream, sync::CancellationToken}; use tracing::*; use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty}; -use crate::{Download, DownloadError, Listing, ListingMode, RemotePath, TimeTravelError}; +use crate::{ + Download, DownloadError, Listing, ListingMode, RemotePath, TimeTravelError, TimeoutOrCancel, +}; use super::{RemoteStorage, StorageMetadata}; @@ -29,12 +36,13 @@ const LOCAL_FS_TEMP_FILE_SUFFIX: &str = "___temp"; #[derive(Debug, Clone)] pub struct LocalFs { storage_root: Utf8PathBuf, + timeout: Duration, } impl LocalFs { /// Attempts to create local FS storage, along with its root directory. /// Storage root will be created (if does not exist) and transformed into an absolute path (if passed as relative). - pub fn new(mut storage_root: Utf8PathBuf) -> anyhow::Result { + pub fn new(mut storage_root: Utf8PathBuf, timeout: Duration) -> anyhow::Result { if !storage_root.exists() { std::fs::create_dir_all(&storage_root).with_context(|| { format!("Failed to create all directories in the given root path {storage_root:?}") @@ -46,7 +54,10 @@ impl LocalFs { })?; } - Ok(Self { storage_root }) + Ok(Self { + storage_root, + timeout, + }) } // mirrors S3Bucket::s3_object_to_relative_path @@ -157,80 +168,14 @@ impl LocalFs { Ok(files) } -} -impl RemoteStorage for LocalFs { - async fn list( - &self, - prefix: Option<&RemotePath>, - mode: ListingMode, - max_keys: Option, - ) -> Result { - let mut result = Listing::default(); - - if let ListingMode::NoDelimiter = mode { - let keys = self - .list_recursive(prefix) - .await - .map_err(DownloadError::Other)?; - - result.keys = keys - .into_iter() - .filter(|k| { - let path = k.with_base(&self.storage_root); - !path.is_dir() - }) - .collect(); - if let Some(max_keys) = max_keys { - result.keys.truncate(max_keys.get() as usize); - } - - return Ok(result); - } - - let path = match prefix { - Some(prefix) => Cow::Owned(prefix.with_base(&self.storage_root)), - None => Cow::Borrowed(&self.storage_root), - }; - - let prefixes_to_filter = get_all_files(path.as_ref(), false) - .await - .map_err(DownloadError::Other)?; - - // filter out empty directories to mirror s3 behavior. - for prefix in prefixes_to_filter { - if prefix.is_dir() - && is_directory_empty(&prefix) - .await - .map_err(DownloadError::Other)? - { - continue; - } - - let stripped = prefix - .strip_prefix(&self.storage_root) - .context("Failed to strip prefix") - .and_then(RemotePath::new) - .expect( - "We list files for storage root, hence should be able to remote the prefix", - ); - - if prefix.is_dir() { - result.prefixes.push(stripped); - } else { - result.keys.push(stripped); - } - } - - Ok(result) - } - - async fn upload( + async fn upload0( &self, data: impl Stream> + Send + Sync, data_size_bytes: usize, to: &RemotePath, metadata: Option, + cancel: &CancellationToken, ) -> anyhow::Result<()> { let target_file_path = to.with_base(&self.storage_root); create_target_directory(&target_file_path).await?; @@ -265,9 +210,26 @@ impl RemoteStorage for LocalFs { let mut buffer_to_read = data.take(from_size_bytes); // alternatively we could just write the bytes to a file, but local_fs is a testing utility - let bytes_read = io::copy_buf(&mut buffer_to_read, &mut destination) - .await - .with_context(|| { + let copy = io::copy_buf(&mut buffer_to_read, &mut destination); + + let bytes_read = tokio::select! { + biased; + _ = cancel.cancelled() => { + let file = destination.into_inner(); + // wait for the inflight operation(s) to complete so that there could be a next + // attempt right away and our writes are not directed to their file. + file.into_std().await; + + // TODO: leave the temp or not? leaving is probably less racy. enabled truncate at + // least. + fs::remove_file(temp_file_path).await.context("remove temp_file_path after cancellation or timeout")?; + return Err(TimeoutOrCancel::Cancel.into()); + } + read = copy => read, + }; + + let bytes_read = + bytes_read.with_context(|| { format!( "Failed to upload file (write temp) to the local storage at '{temp_file_path}'", ) @@ -299,6 +261,9 @@ impl RemoteStorage for LocalFs { })?; if let Some(storage_metadata) = metadata { + // FIXME: we must not be using metadata much, since this would forget the old metadata + // for new writes? or perhaps metadata is sticky; could consider removing if it's never + // used. let storage_metadata_path = storage_metadata_path(&target_file_path); fs::write( &storage_metadata_path, @@ -315,8 +280,131 @@ impl RemoteStorage for LocalFs { Ok(()) } +} - async fn download(&self, from: &RemotePath) -> Result { +impl RemoteStorage for LocalFs { + async fn list( + &self, + prefix: Option<&RemotePath>, + mode: ListingMode, + max_keys: Option, + cancel: &CancellationToken, + ) -> Result { + let op = async { + let mut result = Listing::default(); + + if let ListingMode::NoDelimiter = mode { + let keys = self + .list_recursive(prefix) + .await + .map_err(DownloadError::Other)?; + + result.keys = keys + .into_iter() + .filter(|k| { + let path = k.with_base(&self.storage_root); + !path.is_dir() + }) + .collect(); + + if let Some(max_keys) = max_keys { + result.keys.truncate(max_keys.get() as usize); + } + + return Ok(result); + } + + let path = match prefix { + Some(prefix) => Cow::Owned(prefix.with_base(&self.storage_root)), + None => Cow::Borrowed(&self.storage_root), + }; + + let prefixes_to_filter = get_all_files(path.as_ref(), false) + .await + .map_err(DownloadError::Other)?; + + // filter out empty directories to mirror s3 behavior. + for prefix in prefixes_to_filter { + if prefix.is_dir() + && is_directory_empty(&prefix) + .await + .map_err(DownloadError::Other)? + { + continue; + } + + let stripped = prefix + .strip_prefix(&self.storage_root) + .context("Failed to strip prefix") + .and_then(RemotePath::new) + .expect( + "We list files for storage root, hence should be able to remote the prefix", + ); + + if prefix.is_dir() { + result.prefixes.push(stripped); + } else { + result.keys.push(stripped); + } + } + + Ok(result) + }; + + let timeout = async { + tokio::time::sleep(self.timeout).await; + Err(DownloadError::Timeout) + }; + + let cancelled = async { + cancel.cancelled().await; + Err(DownloadError::Cancelled) + }; + + tokio::select! { + res = op => res, + res = timeout => res, + res = cancelled => res, + } + } + + async fn upload( + &self, + data: impl Stream> + Send + Sync, + data_size_bytes: usize, + to: &RemotePath, + metadata: Option, + cancel: &CancellationToken, + ) -> anyhow::Result<()> { + let cancel = cancel.child_token(); + + let op = self.upload0(data, data_size_bytes, to, metadata, &cancel); + let mut op = std::pin::pin!(op); + + // race the upload0 to the timeout; if it goes over, do a graceful shutdown + let (res, timeout) = tokio::select! { + res = &mut op => (res, false), + _ = tokio::time::sleep(self.timeout) => { + cancel.cancel(); + (op.await, true) + } + }; + + match res { + Err(e) if timeout && TimeoutOrCancel::caused_by_cancel(&e) => { + // we caused this cancel (or they happened simultaneously) -- swap it out to + // Timeout + Err(TimeoutOrCancel::Timeout.into()) + } + res => res, + } + } + + async fn download( + &self, + from: &RemotePath, + cancel: &CancellationToken, + ) -> Result { let target_path = from.with_base(&self.storage_root); if file_exists(&target_path).map_err(DownloadError::BadInput)? { let source = ReaderStream::new( @@ -334,6 +422,10 @@ impl RemoteStorage for LocalFs { .read_storage_metadata(&target_path) .await .map_err(DownloadError::Other)?; + + let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone()); + let source = crate::support::DownloadStream::new(cancel_or_timeout, source); + Ok(Download { metadata, last_modified: None, @@ -350,6 +442,7 @@ impl RemoteStorage for LocalFs { from: &RemotePath, start_inclusive: u64, end_exclusive: Option, + cancel: &CancellationToken, ) -> Result { if let Some(end_exclusive) = end_exclusive { if end_exclusive <= start_inclusive { @@ -391,6 +484,9 @@ impl RemoteStorage for LocalFs { let source = source.take(end_exclusive.unwrap_or(len) - start_inclusive); let source = ReaderStream::new(source); + let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone()); + let source = crate::support::DownloadStream::new(cancel_or_timeout, source); + Ok(Download { metadata, last_modified: None, @@ -402,7 +498,7 @@ impl RemoteStorage for LocalFs { } } - async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> { + async fn delete(&self, path: &RemotePath, _cancel: &CancellationToken) -> anyhow::Result<()> { let file_path = path.with_base(&self.storage_root); match fs::remove_file(&file_path).await { Ok(()) => Ok(()), @@ -414,14 +510,23 @@ impl RemoteStorage for LocalFs { } } - async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> { + async fn delete_objects<'a>( + &self, + paths: &'a [RemotePath], + cancel: &CancellationToken, + ) -> anyhow::Result<()> { for path in paths { - self.delete(path).await? + self.delete(path, cancel).await? } Ok(()) } - async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> { + async fn copy( + &self, + from: &RemotePath, + to: &RemotePath, + _cancel: &CancellationToken, + ) -> anyhow::Result<()> { let from_path = from.with_base(&self.storage_root); let to_path = to.with_base(&self.storage_root); create_target_directory(&to_path).await?; @@ -528,8 +633,9 @@ mod fs_tests { remote_storage_path: &RemotePath, expected_metadata: Option<&StorageMetadata>, ) -> anyhow::Result { + let cancel = CancellationToken::new(); let download = storage - .download(remote_storage_path) + .download(remote_storage_path, &cancel) .await .map_err(|e| anyhow::anyhow!("Download failed: {e}"))?; ensure!( @@ -544,16 +650,16 @@ mod fs_tests { #[tokio::test] async fn upload_file() -> anyhow::Result<()> { - let storage = create_storage()?; + let (storage, cancel) = create_storage()?; - let target_path_1 = upload_dummy_file(&storage, "upload_1", None).await?; + let target_path_1 = upload_dummy_file(&storage, "upload_1", None, &cancel).await?; assert_eq!( storage.list_all().await?, vec![target_path_1.clone()], "Should list a single file after first upload" ); - let target_path_2 = upload_dummy_file(&storage, "upload_2", None).await?; + let target_path_2 = upload_dummy_file(&storage, "upload_2", None, &cancel).await?; assert_eq!( list_files_sorted(&storage).await?, vec![target_path_1.clone(), target_path_2.clone()], @@ -565,7 +671,7 @@ mod fs_tests { #[tokio::test] async fn upload_file_negatives() -> anyhow::Result<()> { - let storage = create_storage()?; + let (storage, cancel) = create_storage()?; let id = RemotePath::new(Utf8Path::new("dummy"))?; let content = Bytes::from_static(b"12345"); @@ -574,34 +680,34 @@ mod fs_tests { // Check that you get an error if the size parameter doesn't match the actual // size of the stream. storage - .upload(content(), 0, &id, None) + .upload(content(), 0, &id, None, &cancel) .await .expect_err("upload with zero size succeeded"); storage - .upload(content(), 4, &id, None) + .upload(content(), 4, &id, None, &cancel) .await .expect_err("upload with too short size succeeded"); storage - .upload(content(), 6, &id, None) + .upload(content(), 6, &id, None, &cancel) .await .expect_err("upload with too large size succeeded"); // Correct size is 5, this should succeed. - storage.upload(content(), 5, &id, None).await?; + storage.upload(content(), 5, &id, None, &cancel).await?; Ok(()) } - fn create_storage() -> anyhow::Result { + fn create_storage() -> anyhow::Result<(LocalFs, CancellationToken)> { let storage_root = tempdir()?.path().to_path_buf(); - LocalFs::new(storage_root) + LocalFs::new(storage_root, Duration::from_secs(120)).map(|s| (s, CancellationToken::new())) } #[tokio::test] async fn download_file() -> anyhow::Result<()> { - let storage = create_storage()?; + let (storage, cancel) = create_storage()?; let upload_name = "upload_1"; - let upload_target = upload_dummy_file(&storage, upload_name, None).await?; + let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel).await?; let contents = read_and_check_metadata(&storage, &upload_target, None).await?; assert_eq!( @@ -611,7 +717,7 @@ mod fs_tests { ); let non_existing_path = "somewhere/else"; - match storage.download(&RemotePath::new(Utf8Path::new(non_existing_path))?).await { + match storage.download(&RemotePath::new(Utf8Path::new(non_existing_path))?, &cancel).await { Err(DownloadError::NotFound) => {} // Should get NotFound for non existing keys other => panic!("Should get a NotFound error when downloading non-existing storage files, but got: {other:?}"), } @@ -620,9 +726,9 @@ mod fs_tests { #[tokio::test] async fn download_file_range_positive() -> anyhow::Result<()> { - let storage = create_storage()?; + let (storage, cancel) = create_storage()?; let upload_name = "upload_1"; - let upload_target = upload_dummy_file(&storage, upload_name, None).await?; + let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel).await?; let full_range_download_contents = read_and_check_metadata(&storage, &upload_target, None).await?; @@ -636,7 +742,12 @@ mod fs_tests { let (first_part_local, second_part_local) = uploaded_bytes.split_at(3); let first_part_download = storage - .download_byte_range(&upload_target, 0, Some(first_part_local.len() as u64)) + .download_byte_range( + &upload_target, + 0, + Some(first_part_local.len() as u64), + &cancel, + ) .await?; assert!( first_part_download.metadata.is_none(), @@ -654,6 +765,7 @@ mod fs_tests { &upload_target, first_part_local.len() as u64, Some((first_part_local.len() + second_part_local.len()) as u64), + &cancel, ) .await?; assert!( @@ -668,7 +780,7 @@ mod fs_tests { ); let suffix_bytes = storage - .download_byte_range(&upload_target, 13, None) + .download_byte_range(&upload_target, 13, None, &cancel) .await? .download_stream; let suffix_bytes = aggregate(suffix_bytes).await?; @@ -676,7 +788,7 @@ mod fs_tests { assert_eq!(upload_name, suffix); let all_bytes = storage - .download_byte_range(&upload_target, 0, None) + .download_byte_range(&upload_target, 0, None, &cancel) .await? .download_stream; let all_bytes = aggregate(all_bytes).await?; @@ -688,9 +800,9 @@ mod fs_tests { #[tokio::test] async fn download_file_range_negative() -> anyhow::Result<()> { - let storage = create_storage()?; + let (storage, cancel) = create_storage()?; let upload_name = "upload_1"; - let upload_target = upload_dummy_file(&storage, upload_name, None).await?; + let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel).await?; let start = 1_000_000_000; let end = start + 1; @@ -699,6 +811,7 @@ mod fs_tests { &upload_target, start, Some(end), // exclusive end + &cancel, ) .await { @@ -715,7 +828,7 @@ mod fs_tests { let end = 234; assert!(start > end, "Should test an incorrect range"); match storage - .download_byte_range(&upload_target, start, Some(end)) + .download_byte_range(&upload_target, start, Some(end), &cancel) .await { Ok(_) => panic!("Should not allow downloading wrong ranges"), @@ -732,15 +845,15 @@ mod fs_tests { #[tokio::test] async fn delete_file() -> anyhow::Result<()> { - let storage = create_storage()?; + let (storage, cancel) = create_storage()?; let upload_name = "upload_1"; - let upload_target = upload_dummy_file(&storage, upload_name, None).await?; + let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel).await?; - storage.delete(&upload_target).await?; + storage.delete(&upload_target, &cancel).await?; assert!(storage.list_all().await?.is_empty()); storage - .delete(&upload_target) + .delete(&upload_target, &cancel) .await .expect("Should allow deleting non-existing storage files"); @@ -749,14 +862,14 @@ mod fs_tests { #[tokio::test] async fn file_with_metadata() -> anyhow::Result<()> { - let storage = create_storage()?; + let (storage, cancel) = create_storage()?; let upload_name = "upload_1"; let metadata = StorageMetadata(HashMap::from([ ("one".to_string(), "1".to_string()), ("two".to_string(), "2".to_string()), ])); let upload_target = - upload_dummy_file(&storage, upload_name, Some(metadata.clone())).await?; + upload_dummy_file(&storage, upload_name, Some(metadata.clone()), &cancel).await?; let full_range_download_contents = read_and_check_metadata(&storage, &upload_target, Some(&metadata)).await?; @@ -770,7 +883,12 @@ mod fs_tests { let (first_part_local, _) = uploaded_bytes.split_at(3); let partial_download_with_metadata = storage - .download_byte_range(&upload_target, 0, Some(first_part_local.len() as u64)) + .download_byte_range( + &upload_target, + 0, + Some(first_part_local.len() as u64), + &cancel, + ) .await?; let first_part_remote = aggregate(partial_download_with_metadata.download_stream).await?; assert_eq!( @@ -791,16 +909,20 @@ mod fs_tests { #[tokio::test] async fn list() -> anyhow::Result<()> { // No delimiter: should recursively list everything - let storage = create_storage()?; - let child = upload_dummy_file(&storage, "grandparent/parent/child", None).await?; - let uncle = upload_dummy_file(&storage, "grandparent/uncle", None).await?; + let (storage, cancel) = create_storage()?; + let child = upload_dummy_file(&storage, "grandparent/parent/child", None, &cancel).await?; + let uncle = upload_dummy_file(&storage, "grandparent/uncle", None, &cancel).await?; - let listing = storage.list(None, ListingMode::NoDelimiter, None).await?; + let listing = storage + .list(None, ListingMode::NoDelimiter, None, &cancel) + .await?; assert!(listing.prefixes.is_empty()); assert_eq!(listing.keys, [uncle.clone(), child.clone()].to_vec()); // Delimiter: should only go one deep - let listing = storage.list(None, ListingMode::WithDelimiter, None).await?; + let listing = storage + .list(None, ListingMode::WithDelimiter, None, &cancel) + .await?; assert_eq!( listing.prefixes, @@ -814,6 +936,7 @@ mod fs_tests { Some(&RemotePath::from_string("timelines/some_timeline/grandparent").unwrap()), ListingMode::WithDelimiter, None, + &cancel, ) .await?; assert_eq!( @@ -826,10 +949,75 @@ mod fs_tests { Ok(()) } + #[tokio::test] + async fn overwrite_shorter_file() -> anyhow::Result<()> { + let (storage, cancel) = create_storage()?; + + let path = RemotePath::new("does/not/matter/file".into())?; + + let body = Bytes::from_static(b"long file contents is long"); + { + let len = body.len(); + let body = + futures::stream::once(futures::future::ready(std::io::Result::Ok(body.clone()))); + storage.upload(body, len, &path, None, &cancel).await?; + } + + let read = aggregate(storage.download(&path, &cancel).await?.download_stream).await?; + assert_eq!(body, read); + + let shorter = Bytes::from_static(b"shorter body"); + { + let len = shorter.len(); + let body = + futures::stream::once(futures::future::ready(std::io::Result::Ok(shorter.clone()))); + storage.upload(body, len, &path, None, &cancel).await?; + } + + let read = aggregate(storage.download(&path, &cancel).await?.download_stream).await?; + assert_eq!(shorter, read); + Ok(()) + } + + #[tokio::test] + async fn cancelled_upload_can_later_be_retried() -> anyhow::Result<()> { + let (storage, cancel) = create_storage()?; + + let path = RemotePath::new("does/not/matter/file".into())?; + + let body = Bytes::from_static(b"long file contents is long"); + { + let len = body.len(); + let body = + futures::stream::once(futures::future::ready(std::io::Result::Ok(body.clone()))); + let cancel = cancel.child_token(); + cancel.cancel(); + let e = storage + .upload(body, len, &path, None, &cancel) + .await + .unwrap_err(); + + assert!(TimeoutOrCancel::caused_by_cancel(&e)); + } + + { + let len = body.len(); + let body = + futures::stream::once(futures::future::ready(std::io::Result::Ok(body.clone()))); + storage.upload(body, len, &path, None, &cancel).await?; + } + + let read = aggregate(storage.download(&path, &cancel).await?.download_stream).await?; + assert_eq!(body, read); + + Ok(()) + } + async fn upload_dummy_file( storage: &LocalFs, name: &str, metadata: Option, + cancel: &CancellationToken, ) -> anyhow::Result { let from_path = storage .storage_root @@ -851,7 +1039,9 @@ mod fs_tests { let file = tokio_util::io::ReaderStream::new(file); - storage.upload(file, size, &relative_path, metadata).await?; + storage + .upload(file, size, &relative_path, metadata, cancel) + .await?; Ok(relative_path) } diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs index dee5750cac..af70dc7ca2 100644 --- a/libs/remote_storage/src/s3_bucket.rs +++ b/libs/remote_storage/src/s3_bucket.rs @@ -11,7 +11,7 @@ use std::{ pin::Pin, sync::Arc, task::{Context, Poll}, - time::SystemTime, + time::{Duration, SystemTime}, }; use anyhow::{anyhow, Context as _}; @@ -46,9 +46,9 @@ use utils::backoff; use super::StorageMetadata; use crate::{ - support::PermitCarrying, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, - RemotePath, RemoteStorage, S3Config, TimeTravelError, MAX_KEYS_PER_DELETE, - REMOTE_STORAGE_PREFIX_SEPARATOR, + error::Cancelled, support::PermitCarrying, ConcurrencyLimiter, Download, DownloadError, + Listing, ListingMode, RemotePath, RemoteStorage, S3Config, TimeTravelError, TimeoutOrCancel, + MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR, }; pub(super) mod metrics; @@ -63,6 +63,8 @@ pub struct S3Bucket { prefix_in_bucket: Option, max_keys_per_list_response: Option, concurrency_limiter: ConcurrencyLimiter, + // Per-request timeout. Accessible for tests. + pub timeout: Duration, } struct GetObjectRequest { @@ -72,7 +74,7 @@ struct GetObjectRequest { } impl S3Bucket { /// Creates the S3 storage, errors if incorrect AWS S3 configuration provided. - pub fn new(aws_config: &S3Config) -> anyhow::Result { + pub fn new(aws_config: &S3Config, timeout: Duration) -> anyhow::Result { tracing::debug!( "Creating s3 remote storage for S3 bucket {}", aws_config.bucket_name @@ -152,6 +154,7 @@ impl S3Bucket { max_keys_per_list_response: aws_config.max_keys_per_list_response, prefix_in_bucket, concurrency_limiter: ConcurrencyLimiter::new(aws_config.concurrency_limit.get()), + timeout, }) } @@ -185,40 +188,55 @@ impl S3Bucket { } } - async fn permit(&self, kind: RequestKind) -> tokio::sync::SemaphorePermit<'_> { + async fn permit( + &self, + kind: RequestKind, + cancel: &CancellationToken, + ) -> Result, Cancelled> { let started_at = start_counting_cancelled_wait(kind); - let permit = self - .concurrency_limiter - .acquire(kind) - .await - .expect("semaphore is never closed"); + let acquire = self.concurrency_limiter.acquire(kind); + + let permit = tokio::select! { + permit = acquire => permit.expect("semaphore is never closed"), + _ = cancel.cancelled() => return Err(Cancelled), + }; let started_at = ScopeGuard::into_inner(started_at); metrics::BUCKET_METRICS .wait_seconds .observe_elapsed(kind, started_at); - permit + Ok(permit) } - async fn owned_permit(&self, kind: RequestKind) -> tokio::sync::OwnedSemaphorePermit { + async fn owned_permit( + &self, + kind: RequestKind, + cancel: &CancellationToken, + ) -> Result { let started_at = start_counting_cancelled_wait(kind); - let permit = self - .concurrency_limiter - .acquire_owned(kind) - .await - .expect("semaphore is never closed"); + let acquire = self.concurrency_limiter.acquire_owned(kind); + + let permit = tokio::select! { + permit = acquire => permit.expect("semaphore is never closed"), + _ = cancel.cancelled() => return Err(Cancelled), + }; let started_at = ScopeGuard::into_inner(started_at); metrics::BUCKET_METRICS .wait_seconds .observe_elapsed(kind, started_at); - permit + Ok(permit) } - async fn download_object(&self, request: GetObjectRequest) -> Result { + async fn download_object( + &self, + request: GetObjectRequest, + cancel: &CancellationToken, + ) -> Result { let kind = RequestKind::Get; - let permit = self.owned_permit(kind).await; + + let permit = self.owned_permit(kind, cancel).await?; let started_at = start_measuring_requests(kind); @@ -228,8 +246,13 @@ impl S3Bucket { .bucket(request.bucket) .key(request.key) .set_range(request.range) - .send() - .await; + .send(); + + let get_object = tokio::select! { + res = get_object => res, + _ = tokio::time::sleep(self.timeout) => return Err(DownloadError::Timeout), + _ = cancel.cancelled() => return Err(DownloadError::Cancelled), + }; let started_at = ScopeGuard::into_inner(started_at); @@ -259,6 +282,10 @@ impl S3Bucket { } }; + // even if we would have no timeout left, continue anyways. the caller can decide to ignore + // the errors considering timeouts and cancellation. + let remaining = self.timeout.saturating_sub(started_at.elapsed()); + let metadata = object_output.metadata().cloned().map(StorageMetadata); let etag = object_output.e_tag; let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok()); @@ -268,6 +295,9 @@ impl S3Bucket { let body = PermitCarrying::new(permit, body); let body = TimedDownload::new(started_at, body); + let cancel_or_timeout = crate::support::cancel_or_timeout(remaining, cancel.clone()); + let body = crate::support::DownloadStream::new(cancel_or_timeout, body); + Ok(Download { metadata, etag, @@ -278,33 +308,44 @@ impl S3Bucket { async fn delete_oids( &self, - kind: RequestKind, + _permit: &tokio::sync::SemaphorePermit<'_>, delete_objects: &[ObjectIdentifier], + cancel: &CancellationToken, ) -> anyhow::Result<()> { + let kind = RequestKind::Delete; + let mut cancel = std::pin::pin!(cancel.cancelled()); + for chunk in delete_objects.chunks(MAX_KEYS_PER_DELETE) { let started_at = start_measuring_requests(kind); - let resp = self + let req = self .client .delete_objects() .bucket(self.bucket_name.clone()) .delete( Delete::builder() .set_objects(Some(chunk.to_vec())) - .build()?, + .build() + .context("build request")?, ) - .send() - .await; + .send(); + + let resp = tokio::select! { + resp = req => resp, + _ = tokio::time::sleep(self.timeout) => return Err(TimeoutOrCancel::Timeout.into()), + _ = &mut cancel => return Err(TimeoutOrCancel::Cancel.into()), + }; let started_at = ScopeGuard::into_inner(started_at); metrics::BUCKET_METRICS .req_seconds .observe_elapsed(kind, &resp, started_at); - let resp = resp?; + let resp = resp.context("request deletion")?; metrics::BUCKET_METRICS .deleted_objects_total .inc_by(chunk.len() as u64); + if let Some(errors) = resp.errors { // Log a bounded number of the errors within the response: // these requests can carry 1000 keys so logging each one @@ -320,9 +361,10 @@ impl S3Bucket { ); } - return Err(anyhow::format_err!( - "Failed to delete {} objects", - errors.len() + return Err(anyhow::anyhow!( + "Failed to delete {}/{} objects", + errors.len(), + chunk.len(), )); } } @@ -410,6 +452,7 @@ impl RemoteStorage for S3Bucket { prefix: Option<&RemotePath>, mode: ListingMode, max_keys: Option, + cancel: &CancellationToken, ) -> Result { let kind = RequestKind::List; // s3 sdk wants i32 @@ -431,10 +474,11 @@ impl RemoteStorage for S3Bucket { p }); + let _permit = self.permit(kind, cancel).await?; + let mut continuation_token = None; loop { - let _guard = self.permit(kind).await; let started_at = start_measuring_requests(kind); // min of two Options, returning Some if one is value and another is @@ -456,9 +500,15 @@ impl RemoteStorage for S3Bucket { request = request.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string()); } - let response = request - .send() - .await + let request = request.send(); + + let response = tokio::select! { + res = request => res, + _ = tokio::time::sleep(self.timeout) => return Err(DownloadError::Timeout), + _ = cancel.cancelled() => return Err(DownloadError::Cancelled), + }; + + let response = response .context("Failed to list S3 prefixes") .map_err(DownloadError::Other); @@ -511,16 +561,17 @@ impl RemoteStorage for S3Bucket { from_size_bytes: usize, to: &RemotePath, metadata: Option, + cancel: &CancellationToken, ) -> anyhow::Result<()> { let kind = RequestKind::Put; - let _guard = self.permit(kind).await; + let _permit = self.permit(kind, cancel).await?; let started_at = start_measuring_requests(kind); let body = Body::wrap_stream(from); let bytes_stream = ByteStream::new(SdkBody::from_body_0_4(body)); - let res = self + let upload = self .client .put_object() .bucket(self.bucket_name.clone()) @@ -528,22 +579,40 @@ impl RemoteStorage for S3Bucket { .set_metadata(metadata.map(|m| m.0)) .content_length(from_size_bytes.try_into()?) .body(bytes_stream) - .send() - .await; + .send(); - let started_at = ScopeGuard::into_inner(started_at); - metrics::BUCKET_METRICS - .req_seconds - .observe_elapsed(kind, &res, started_at); + let upload = tokio::time::timeout(self.timeout, upload); - res?; + let res = tokio::select! { + res = upload => res, + _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()), + }; - Ok(()) + if let Ok(inner) = &res { + // do not incl. timeouts as errors in metrics but cancellations + let started_at = ScopeGuard::into_inner(started_at); + metrics::BUCKET_METRICS + .req_seconds + .observe_elapsed(kind, inner, started_at); + } + + match res { + Ok(Ok(_put)) => Ok(()), + Ok(Err(sdk)) => Err(sdk.into()), + Err(_timeout) => Err(TimeoutOrCancel::Timeout.into()), + } } - async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> { + async fn copy( + &self, + from: &RemotePath, + to: &RemotePath, + cancel: &CancellationToken, + ) -> anyhow::Result<()> { let kind = RequestKind::Copy; - let _guard = self.permit(kind).await; + let _permit = self.permit(kind, cancel).await?; + + let timeout = tokio::time::sleep(self.timeout); let started_at = start_measuring_requests(kind); @@ -554,14 +623,19 @@ impl RemoteStorage for S3Bucket { self.relative_path_to_s3_object(from) ); - let res = self + let op = self .client .copy_object() .bucket(self.bucket_name.clone()) .key(self.relative_path_to_s3_object(to)) .copy_source(copy_source) - .send() - .await; + .send(); + + let res = tokio::select! { + res = op => res, + _ = timeout => return Err(TimeoutOrCancel::Timeout.into()), + _ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()), + }; let started_at = ScopeGuard::into_inner(started_at); metrics::BUCKET_METRICS @@ -573,14 +647,21 @@ impl RemoteStorage for S3Bucket { Ok(()) } - async fn download(&self, from: &RemotePath) -> Result { + async fn download( + &self, + from: &RemotePath, + cancel: &CancellationToken, + ) -> Result { // if prefix is not none then download file `prefix/from` // if prefix is none then download file `from` - self.download_object(GetObjectRequest { - bucket: self.bucket_name.clone(), - key: self.relative_path_to_s3_object(from), - range: None, - }) + self.download_object( + GetObjectRequest { + bucket: self.bucket_name.clone(), + key: self.relative_path_to_s3_object(from), + range: None, + }, + cancel, + ) .await } @@ -589,6 +670,7 @@ impl RemoteStorage for S3Bucket { from: &RemotePath, start_inclusive: u64, end_exclusive: Option, + cancel: &CancellationToken, ) -> Result { // S3 accepts ranges as https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 // and needs both ends to be exclusive @@ -598,31 +680,39 @@ impl RemoteStorage for S3Bucket { None => format!("bytes={start_inclusive}-"), }); - self.download_object(GetObjectRequest { - bucket: self.bucket_name.clone(), - key: self.relative_path_to_s3_object(from), - range, - }) + self.download_object( + GetObjectRequest { + bucket: self.bucket_name.clone(), + key: self.relative_path_to_s3_object(from), + range, + }, + cancel, + ) .await } - async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> { - let kind = RequestKind::Delete; - let _guard = self.permit(kind).await; + async fn delete_objects<'a>( + &self, + paths: &'a [RemotePath], + cancel: &CancellationToken, + ) -> anyhow::Result<()> { + let kind = RequestKind::Delete; + let permit = self.permit(kind, cancel).await?; let mut delete_objects = Vec::with_capacity(paths.len()); for path in paths { let obj_id = ObjectIdentifier::builder() .set_key(Some(self.relative_path_to_s3_object(path))) - .build()?; + .build() + .context("convert path to oid")?; delete_objects.push(obj_id); } - self.delete_oids(kind, &delete_objects).await + self.delete_oids(&permit, &delete_objects, cancel).await } - async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> { + async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> { let paths = std::array::from_ref(path); - self.delete_objects(paths).await + self.delete_objects(paths, cancel).await } async fn time_travel_recover( @@ -633,7 +723,7 @@ impl RemoteStorage for S3Bucket { cancel: &CancellationToken, ) -> Result<(), TimeTravelError> { let kind = RequestKind::TimeTravel; - let _guard = self.permit(kind).await; + let permit = self.permit(kind, cancel).await?; let timestamp = DateTime::from(timestamp); let done_if_after = DateTime::from(done_if_after); @@ -647,7 +737,7 @@ impl RemoteStorage for S3Bucket { let warn_threshold = 3; let max_retries = 10; - let is_permanent = |_e: &_| false; + let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled); let mut key_marker = None; let mut version_id_marker = None; @@ -656,15 +746,19 @@ impl RemoteStorage for S3Bucket { loop { let response = backoff::retry( || async { - self.client + let op = self + .client .list_object_versions() .bucket(self.bucket_name.clone()) .set_prefix(prefix.clone()) .set_key_marker(key_marker.clone()) .set_version_id_marker(version_id_marker.clone()) - .send() - .await - .map_err(|e| TimeTravelError::Other(e.into())) + .send(); + + tokio::select! { + res = op => res.map_err(|e| TimeTravelError::Other(e.into())), + _ = cancel.cancelled() => Err(TimeTravelError::Cancelled), + } }, is_permanent, warn_threshold, @@ -786,14 +880,18 @@ impl RemoteStorage for S3Bucket { backoff::retry( || async { - self.client + let op = self + .client .copy_object() .bucket(self.bucket_name.clone()) .key(key) .copy_source(&source_id) - .send() - .await - .map_err(|e| TimeTravelError::Other(e.into())) + .send(); + + tokio::select! { + res = op => res.map_err(|e| TimeTravelError::Other(e.into())), + _ = cancel.cancelled() => Err(TimeTravelError::Cancelled), + } }, is_permanent, warn_threshold, @@ -824,10 +922,18 @@ impl RemoteStorage for S3Bucket { let oid = ObjectIdentifier::builder() .key(key.to_owned()) .build() - .map_err(|e| TimeTravelError::Other(anyhow::Error::new(e)))?; - self.delete_oids(kind, &[oid]) + .map_err(|e| TimeTravelError::Other(e.into()))?; + + self.delete_oids(&permit, &[oid], cancel) .await - .map_err(TimeTravelError::Other)?; + .map_err(|e| { + // delete_oid0 will use TimeoutOrCancel + if TimeoutOrCancel::caused_by_cancel(&e) { + TimeTravelError::Cancelled + } else { + TimeTravelError::Other(e) + } + })?; } } } @@ -963,7 +1069,8 @@ mod tests { concurrency_limit: NonZeroUsize::new(100).unwrap(), max_keys_per_list_response: Some(5), }; - let storage = S3Bucket::new(&config).expect("remote storage init"); + let storage = + S3Bucket::new(&config, std::time::Duration::ZERO).expect("remote storage init"); for (test_path_idx, test_path) in all_paths.iter().enumerate() { let result = storage.relative_path_to_s3_object(test_path); let expected = expected_outputs[prefix_idx][test_path_idx]; diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index 3dfa16b64e..f5344d3ae2 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -90,11 +90,16 @@ impl UnreliableWrapper { } } - async fn delete_inner(&self, path: &RemotePath, attempt: bool) -> anyhow::Result<()> { + async fn delete_inner( + &self, + path: &RemotePath, + attempt: bool, + cancel: &CancellationToken, + ) -> anyhow::Result<()> { if attempt { self.attempt(RemoteOp::Delete(path.clone()))?; } - self.inner.delete(path).await + self.inner.delete(path, cancel).await } } @@ -105,20 +110,22 @@ impl RemoteStorage for UnreliableWrapper { async fn list_prefixes( &self, prefix: Option<&RemotePath>, + cancel: &CancellationToken, ) -> Result, DownloadError> { self.attempt(RemoteOp::ListPrefixes(prefix.cloned())) .map_err(DownloadError::Other)?; - self.inner.list_prefixes(prefix).await + self.inner.list_prefixes(prefix, cancel).await } async fn list_files( &self, folder: Option<&RemotePath>, max_keys: Option, + cancel: &CancellationToken, ) -> Result, DownloadError> { self.attempt(RemoteOp::ListPrefixes(folder.cloned())) .map_err(DownloadError::Other)?; - self.inner.list_files(folder, max_keys).await + self.inner.list_files(folder, max_keys, cancel).await } async fn list( @@ -126,10 +133,11 @@ impl RemoteStorage for UnreliableWrapper { prefix: Option<&RemotePath>, mode: ListingMode, max_keys: Option, + cancel: &CancellationToken, ) -> Result { self.attempt(RemoteOp::ListPrefixes(prefix.cloned())) .map_err(DownloadError::Other)?; - self.inner.list(prefix, mode, max_keys).await + self.inner.list(prefix, mode, max_keys, cancel).await } async fn upload( @@ -140,15 +148,22 @@ impl RemoteStorage for UnreliableWrapper { data_size_bytes: usize, to: &RemotePath, metadata: Option, + cancel: &CancellationToken, ) -> anyhow::Result<()> { self.attempt(RemoteOp::Upload(to.clone()))?; - self.inner.upload(data, data_size_bytes, to, metadata).await + self.inner + .upload(data, data_size_bytes, to, metadata, cancel) + .await } - async fn download(&self, from: &RemotePath) -> Result { + async fn download( + &self, + from: &RemotePath, + cancel: &CancellationToken, + ) -> Result { self.attempt(RemoteOp::Download(from.clone())) .map_err(DownloadError::Other)?; - self.inner.download(from).await + self.inner.download(from, cancel).await } async fn download_byte_range( @@ -156,6 +171,7 @@ impl RemoteStorage for UnreliableWrapper { from: &RemotePath, start_inclusive: u64, end_exclusive: Option, + cancel: &CancellationToken, ) -> Result { // Note: We treat any download_byte_range as an "attempt" of the same // operation. We don't pay attention to the ranges. That's good enough @@ -163,20 +179,24 @@ impl RemoteStorage for UnreliableWrapper { self.attempt(RemoteOp::Download(from.clone())) .map_err(DownloadError::Other)?; self.inner - .download_byte_range(from, start_inclusive, end_exclusive) + .download_byte_range(from, start_inclusive, end_exclusive, cancel) .await } - async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> { - self.delete_inner(path, true).await + async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> { + self.delete_inner(path, true, cancel).await } - async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> { + async fn delete_objects<'a>( + &self, + paths: &'a [RemotePath], + cancel: &CancellationToken, + ) -> anyhow::Result<()> { self.attempt(RemoteOp::DeleteObjects(paths.to_vec()))?; let mut error_counter = 0; for path in paths { // Dont record attempt because it was already recorded above - if (self.delete_inner(path, false).await).is_err() { + if (self.delete_inner(path, false, cancel).await).is_err() { error_counter += 1; } } @@ -189,11 +209,16 @@ impl RemoteStorage for UnreliableWrapper { Ok(()) } - async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> { + async fn copy( + &self, + from: &RemotePath, + to: &RemotePath, + cancel: &CancellationToken, + ) -> anyhow::Result<()> { // copy is equivalent to download + upload self.attempt(RemoteOp::Download(from.clone()))?; self.attempt(RemoteOp::Upload(to.clone()))?; - self.inner.copy_object(from, to).await + self.inner.copy_object(from, to, cancel).await } async fn time_travel_recover( diff --git a/libs/remote_storage/src/support.rs b/libs/remote_storage/src/support.rs index 4688a484a5..20f193c6c8 100644 --- a/libs/remote_storage/src/support.rs +++ b/libs/remote_storage/src/support.rs @@ -1,9 +1,15 @@ use std::{ + future::Future, pin::Pin, task::{Context, Poll}, + time::Duration, }; +use bytes::Bytes; use futures_util::Stream; +use tokio_util::sync::CancellationToken; + +use crate::TimeoutOrCancel; pin_project_lite::pin_project! { /// An `AsyncRead` adapter which carries a permit for the lifetime of the value. @@ -31,3 +37,133 @@ impl Stream for PermitCarrying { self.inner.size_hint() } } + +pin_project_lite::pin_project! { + pub(crate) struct DownloadStream { + hit: bool, + #[pin] + cancellation: F, + #[pin] + inner: S, + } +} + +impl DownloadStream { + pub(crate) fn new(cancellation: F, inner: S) -> Self { + Self { + cancellation, + hit: false, + inner, + } + } +} + +/// See documentation on [`crate::DownloadStream`] on rationale why `std::io::Error` is used. +impl Stream for DownloadStream +where + std::io::Error: From, + F: Future, + S: Stream>, +{ + type Item = ::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + if !*this.hit { + if let Poll::Ready(e) = this.cancellation.poll(cx) { + *this.hit = true; + let e = Err(std::io::Error::from(e)); + return Poll::Ready(Some(e)); + } + } + + this.inner.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +/// Fires only on the first cancel or timeout, not on both. +pub(crate) async fn cancel_or_timeout( + timeout: Duration, + cancel: CancellationToken, +) -> TimeoutOrCancel { + tokio::select! { + _ = tokio::time::sleep(timeout) => TimeoutOrCancel::Timeout, + _ = cancel.cancelled() => TimeoutOrCancel::Cancel, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::DownloadError; + use futures::stream::StreamExt; + + #[tokio::test(start_paused = true)] + async fn cancelled_download_stream() { + let inner = futures::stream::pending(); + let timeout = Duration::from_secs(120); + let cancel = CancellationToken::new(); + + let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner); + let mut stream = std::pin::pin!(stream); + + let mut first = stream.next(); + + tokio::select! { + _ = &mut first => unreachable!("we haven't yet cancelled nor is timeout passed"), + _ = tokio::time::sleep(Duration::from_secs(1)) => {}, + } + + cancel.cancel(); + + let e = first.await.expect("there must be some").unwrap_err(); + assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}"); + let inner = e.get_ref().expect("inner should be set"); + assert!( + inner + .downcast_ref::() + .is_some_and(|e| matches!(e, DownloadError::Cancelled)), + "{inner:?}" + ); + + tokio::select! { + _ = stream.next() => unreachable!("no timeout ever happens as we were already cancelled"), + _ = tokio::time::sleep(Duration::from_secs(121)) => {}, + } + } + + #[tokio::test(start_paused = true)] + async fn timeouted_download_stream() { + let inner = futures::stream::pending(); + let timeout = Duration::from_secs(120); + let cancel = CancellationToken::new(); + + let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner); + let mut stream = std::pin::pin!(stream); + + // because the stream uses 120s timeout we are paused, we advance to 120s right away. + let first = stream.next(); + + let e = first.await.expect("there must be some").unwrap_err(); + assert!(matches!(e.kind(), std::io::ErrorKind::Other), "{e:?}"); + let inner = e.get_ref().expect("inner should be set"); + assert!( + inner + .downcast_ref::() + .is_some_and(|e| matches!(e, DownloadError::Timeout)), + "{inner:?}" + ); + + cancel.cancel(); + + tokio::select! { + _ = stream.next() => unreachable!("no cancellation ever happens because we already timed out"), + _ = tokio::time::sleep(Duration::from_secs(121)) => {}, + } + } +} diff --git a/libs/remote_storage/tests/common/mod.rs b/libs/remote_storage/tests/common/mod.rs index bca117ed1a..da9dc08d8d 100644 --- a/libs/remote_storage/tests/common/mod.rs +++ b/libs/remote_storage/tests/common/mod.rs @@ -10,6 +10,7 @@ use futures::stream::Stream; use once_cell::sync::OnceCell; use remote_storage::{Download, GenericRemoteStorage, RemotePath}; use tokio::task::JoinSet; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, info}; static LOGGING_DONE: OnceCell<()> = OnceCell::new(); @@ -58,8 +59,12 @@ pub(crate) async fn upload_simple_remote_data( ) -> ControlFlow, HashSet> { info!("Creating {upload_tasks_count} remote files"); let mut upload_tasks = JoinSet::new(); + let cancel = CancellationToken::new(); + for i in 1..upload_tasks_count + 1 { let task_client = Arc::clone(client); + let cancel = cancel.clone(); + upload_tasks.spawn(async move { let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i)); let blob_path = RemotePath::new( @@ -69,7 +74,9 @@ pub(crate) async fn upload_simple_remote_data( debug!("Creating remote item {i} at path {blob_path:?}"); let (data, len) = upload_stream(format!("remote blob data {i}").into_bytes().into()); - task_client.upload(data, len, &blob_path, None).await?; + task_client + .upload(data, len, &blob_path, None, &cancel) + .await?; Ok::<_, anyhow::Error>(blob_path) }); @@ -107,13 +114,15 @@ pub(crate) async fn cleanup( "Removing {} objects from the remote storage during cleanup", objects_to_delete.len() ); + let cancel = CancellationToken::new(); let mut delete_tasks = JoinSet::new(); for object_to_delete in objects_to_delete { let task_client = Arc::clone(client); + let cancel = cancel.clone(); delete_tasks.spawn(async move { debug!("Deleting remote item at path {object_to_delete:?}"); task_client - .delete(&object_to_delete) + .delete(&object_to_delete, &cancel) .await .with_context(|| format!("{object_to_delete:?} removal")) }); @@ -141,8 +150,12 @@ pub(crate) async fn upload_remote_data( ) -> ControlFlow { info!("Creating {upload_tasks_count} remote files"); let mut upload_tasks = JoinSet::new(); + let cancel = CancellationToken::new(); + for i in 1..upload_tasks_count + 1 { let task_client = Arc::clone(client); + let cancel = cancel.clone(); + upload_tasks.spawn(async move { let prefix = format!("{base_prefix_str}/sub_prefix_{i}/"); let blob_prefix = RemotePath::new(Utf8Path::new(&prefix)) @@ -152,7 +165,9 @@ pub(crate) async fn upload_remote_data( let (data, data_len) = upload_stream(format!("remote blob data {i}").into_bytes().into()); - task_client.upload(data, data_len, &blob_path, None).await?; + task_client + .upload(data, data_len, &blob_path, None, &cancel) + .await?; Ok::<_, anyhow::Error>((blob_prefix, blob_path)) }); diff --git a/libs/remote_storage/tests/common/tests.rs b/libs/remote_storage/tests/common/tests.rs index 6d062f3898..72f6f956e0 100644 --- a/libs/remote_storage/tests/common/tests.rs +++ b/libs/remote_storage/tests/common/tests.rs @@ -4,6 +4,7 @@ use remote_storage::RemotePath; use std::sync::Arc; use std::{collections::HashSet, num::NonZeroU32}; use test_context::test_context; +use tokio_util::sync::CancellationToken; use tracing::debug; use crate::common::{download_to_vec, upload_stream, wrap_stream}; @@ -45,13 +46,15 @@ async fn pagination_should_work(ctx: &mut MaybeEnabledStorageWithTestBlobs) -> a } }; + let cancel = CancellationToken::new(); + let test_client = Arc::clone(&ctx.enabled.client); let expected_remote_prefixes = ctx.remote_prefixes.clone(); let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix)) .context("common_prefix construction")?; let root_remote_prefixes = test_client - .list_prefixes(None) + .list_prefixes(None, &cancel) .await .context("client list root prefixes failure")? .into_iter() @@ -62,7 +65,7 @@ async fn pagination_should_work(ctx: &mut MaybeEnabledStorageWithTestBlobs) -> a ); let nested_remote_prefixes = test_client - .list_prefixes(Some(&base_prefix)) + .list_prefixes(Some(&base_prefix), &cancel) .await .context("client list nested prefixes failure")? .into_iter() @@ -99,11 +102,12 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a anyhow::bail!("S3 init failed: {e:?}") } }; + let cancel = CancellationToken::new(); let test_client = Arc::clone(&ctx.enabled.client); let base_prefix = RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?; let root_files = test_client - .list_files(None, None) + .list_files(None, None, &cancel) .await .context("client list root files failure")? .into_iter() @@ -117,13 +121,13 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a // Test that max_keys limit works. In total there are about 21 files (see // upload_simple_remote_data call in test_real_s3.rs). let limited_root_files = test_client - .list_files(None, Some(NonZeroU32::new(2).unwrap())) + .list_files(None, Some(NonZeroU32::new(2).unwrap()), &cancel) .await .context("client list root files failure")?; assert_eq!(limited_root_files.len(), 2); let nested_remote_files = test_client - .list_files(Some(&base_prefix), None) + .list_files(Some(&base_prefix), None, &cancel) .await .context("client list nested files failure")? .into_iter() @@ -150,12 +154,17 @@ async fn delete_non_exising_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Resu MaybeEnabledStorage::Disabled => return Ok(()), }; + let cancel = CancellationToken::new(); + let path = RemotePath::new(Utf8Path::new( format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(), )) .with_context(|| "RemotePath conversion")?; - ctx.client.delete(&path).await.expect("should succeed"); + ctx.client + .delete(&path, &cancel) + .await + .expect("should succeed"); Ok(()) } @@ -168,6 +177,8 @@ async fn delete_objects_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<( MaybeEnabledStorage::Disabled => return Ok(()), }; + let cancel = CancellationToken::new(); + let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str())) .with_context(|| "RemotePath conversion")?; @@ -178,21 +189,21 @@ async fn delete_objects_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<( .with_context(|| "RemotePath conversion")?; let (data, len) = upload_stream("remote blob data1".as_bytes().into()); - ctx.client.upload(data, len, &path1, None).await?; + ctx.client.upload(data, len, &path1, None, &cancel).await?; let (data, len) = upload_stream("remote blob data2".as_bytes().into()); - ctx.client.upload(data, len, &path2, None).await?; + ctx.client.upload(data, len, &path2, None, &cancel).await?; let (data, len) = upload_stream("remote blob data3".as_bytes().into()); - ctx.client.upload(data, len, &path3, None).await?; + ctx.client.upload(data, len, &path3, None, &cancel).await?; - ctx.client.delete_objects(&[path1, path2]).await?; + ctx.client.delete_objects(&[path1, path2], &cancel).await?; - let prefixes = ctx.client.list_prefixes(None).await?; + let prefixes = ctx.client.list_prefixes(None, &cancel).await?; assert_eq!(prefixes.len(), 1); - ctx.client.delete_objects(&[path3]).await?; + ctx.client.delete_objects(&[path3], &cancel).await?; Ok(()) } @@ -204,6 +215,8 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result< return Ok(()); }; + let cancel = CancellationToken::new(); + let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str())) .with_context(|| "RemotePath conversion")?; @@ -211,47 +224,56 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result< let (data, len) = wrap_stream(orig.clone()); - ctx.client.upload(data, len, &path, None).await?; + ctx.client.upload(data, len, &path, None, &cancel).await?; // Normal download request - let dl = ctx.client.download(&path).await?; + let dl = ctx.client.download(&path, &cancel).await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig); // Full range (end specified) let dl = ctx .client - .download_byte_range(&path, 0, Some(len as u64)) + .download_byte_range(&path, 0, Some(len as u64), &cancel) .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig); // partial range (end specified) - let dl = ctx.client.download_byte_range(&path, 4, Some(10)).await?; + let dl = ctx + .client + .download_byte_range(&path, 4, Some(10), &cancel) + .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig[4..10]); // partial range (end beyond real end) let dl = ctx .client - .download_byte_range(&path, 8, Some(len as u64 * 100)) + .download_byte_range(&path, 8, Some(len as u64 * 100), &cancel) .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig[8..]); // Partial range (end unspecified) - let dl = ctx.client.download_byte_range(&path, 4, None).await?; + let dl = ctx + .client + .download_byte_range(&path, 4, None, &cancel) + .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig[4..]); // Full range (end unspecified) - let dl = ctx.client.download_byte_range(&path, 0, None).await?; + let dl = ctx + .client + .download_byte_range(&path, 0, None, &cancel) + .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig); debug!("Cleanup: deleting file at path {path:?}"); ctx.client - .delete(&path) + .delete(&path, &cancel) .await .with_context(|| format!("{path:?} removal"))?; @@ -265,6 +287,8 @@ async fn copy_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> { return Ok(()); }; + let cancel = CancellationToken::new(); + let path = RemotePath::new(Utf8Path::new( format!("{}/file_to_copy", ctx.base_prefix).as_str(), )) @@ -278,18 +302,18 @@ async fn copy_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> { let (data, len) = wrap_stream(orig.clone()); - ctx.client.upload(data, len, &path, None).await?; + ctx.client.upload(data, len, &path, None, &cancel).await?; // Normal download request - ctx.client.copy_object(&path, &path_dest).await?; + ctx.client.copy_object(&path, &path_dest, &cancel).await?; - let dl = ctx.client.download(&path_dest).await?; + let dl = ctx.client.download(&path_dest, &cancel).await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig); debug!("Cleanup: deleting file at path {path:?}"); ctx.client - .delete_objects(&[path.clone(), path_dest.clone()]) + .delete_objects(&[path.clone(), path_dest.clone()], &cancel) .await .with_context(|| format!("{path:?} removal"))?; diff --git a/libs/remote_storage/tests/test_real_azure.rs b/libs/remote_storage/tests/test_real_azure.rs index 6f9a1ec6f7..6adddf52a9 100644 --- a/libs/remote_storage/tests/test_real_azure.rs +++ b/libs/remote_storage/tests/test_real_azure.rs @@ -1,9 +1,9 @@ -use std::collections::HashSet; use std::env; use std::num::NonZeroUsize; use std::ops::ControlFlow; use std::sync::Arc; use std::time::UNIX_EPOCH; +use std::{collections::HashSet, time::Duration}; use anyhow::Context; use remote_storage::{ @@ -39,6 +39,17 @@ impl EnabledAzure { base_prefix: BASE_PREFIX, } } + + #[allow(unused)] // this will be needed when moving the timeout integration tests back + fn configure_request_timeout(&mut self, timeout: Duration) { + match Arc::get_mut(&mut self.client).expect("outer Arc::get_mut") { + GenericRemoteStorage::AzureBlob(azure) => { + let azure = Arc::get_mut(azure).expect("inner Arc::get_mut"); + azure.timeout = timeout; + } + _ => unreachable!(), + } + } } enum MaybeEnabledStorage { @@ -213,6 +224,7 @@ fn create_azure_client( concurrency_limit: NonZeroUsize::new(100).unwrap(), max_keys_per_list_response, }), + timeout: Duration::from_secs(120), }; Ok(Arc::new( GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?, diff --git a/libs/remote_storage/tests/test_real_s3.rs b/libs/remote_storage/tests/test_real_s3.rs index 3dc8347c83..e927b40e80 100644 --- a/libs/remote_storage/tests/test_real_s3.rs +++ b/libs/remote_storage/tests/test_real_s3.rs @@ -1,5 +1,6 @@ use std::env; use std::fmt::{Debug, Display}; +use std::future::Future; use std::num::NonZeroUsize; use std::ops::ControlFlow; use std::sync::Arc; @@ -9,9 +10,10 @@ use std::{collections::HashSet, time::SystemTime}; use crate::common::{download_to_vec, upload_stream}; use anyhow::Context; use camino::Utf8Path; -use futures_util::Future; +use futures_util::StreamExt; use remote_storage::{ - GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind, S3Config, + DownloadError, GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind, + S3Config, }; use test_context::test_context; use test_context::AsyncTestContext; @@ -27,7 +29,6 @@ use common::{cleanup, ensure_logging_ready, upload_remote_data, upload_simple_re use utils::backoff; const ENABLE_REAL_S3_REMOTE_STORAGE_ENV_VAR_NAME: &str = "ENABLE_REAL_S3_REMOTE_STORAGE"; - const BASE_PREFIX: &str = "test"; #[test_context(MaybeEnabledStorage)] @@ -69,8 +70,11 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: ret } - async fn list_files(client: &Arc) -> anyhow::Result> { - Ok(retry(|| client.list_files(None, None)) + async fn list_files( + client: &Arc, + cancel: &CancellationToken, + ) -> anyhow::Result> { + Ok(retry(|| client.list_files(None, None, cancel)) .await .context("list root files failure")? .into_iter() @@ -90,11 +94,11 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: retry(|| { let (data, len) = upload_stream("remote blob data1".as_bytes().into()); - ctx.client.upload(data, len, &path1, None) + ctx.client.upload(data, len, &path1, None, &cancel) }) .await?; - let t0_files = list_files(&ctx.client).await?; + let t0_files = list_files(&ctx.client, &cancel).await?; let t0 = time_point().await; println!("at t0: {t0_files:?}"); @@ -102,17 +106,17 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: retry(|| { let (data, len) = upload_stream(old_data.as_bytes().into()); - ctx.client.upload(data, len, &path2, None) + ctx.client.upload(data, len, &path2, None, &cancel) }) .await?; - let t1_files = list_files(&ctx.client).await?; + let t1_files = list_files(&ctx.client, &cancel).await?; let t1 = time_point().await; println!("at t1: {t1_files:?}"); // A little check to ensure that our clock is not too far off from the S3 clock { - let dl = retry(|| ctx.client.download(&path2)).await?; + let dl = retry(|| ctx.client.download(&path2, &cancel)).await?; let last_modified = dl.last_modified.unwrap(); let half_wt = WAIT_TIME.mul_f32(0.5); let t0_hwt = t0 + half_wt; @@ -125,7 +129,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: retry(|| { let (data, len) = upload_stream("remote blob data3".as_bytes().into()); - ctx.client.upload(data, len, &path3, None) + ctx.client.upload(data, len, &path3, None, &cancel) }) .await?; @@ -133,12 +137,12 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: retry(|| { let (data, len) = upload_stream(new_data.as_bytes().into()); - ctx.client.upload(data, len, &path2, None) + ctx.client.upload(data, len, &path2, None, &cancel) }) .await?; - retry(|| ctx.client.delete(&path1)).await?; - let t2_files = list_files(&ctx.client).await?; + retry(|| ctx.client.delete(&path1, &cancel)).await?; + let t2_files = list_files(&ctx.client, &cancel).await?; let t2 = time_point().await; println!("at t2: {t2_files:?}"); @@ -147,10 +151,10 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: ctx.client .time_travel_recover(None, t2, t_final, &cancel) .await?; - let t2_files_recovered = list_files(&ctx.client).await?; + let t2_files_recovered = list_files(&ctx.client, &cancel).await?; println!("after recovery to t2: {t2_files_recovered:?}"); assert_eq!(t2_files, t2_files_recovered); - let path2_recovered_t2 = download_to_vec(ctx.client.download(&path2).await?).await?; + let path2_recovered_t2 = download_to_vec(ctx.client.download(&path2, &cancel).await?).await?; assert_eq!(path2_recovered_t2, new_data.as_bytes()); // after recovery to t1: path1 is back, path2 has the old content @@ -158,10 +162,10 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: ctx.client .time_travel_recover(None, t1, t_final, &cancel) .await?; - let t1_files_recovered = list_files(&ctx.client).await?; + let t1_files_recovered = list_files(&ctx.client, &cancel).await?; println!("after recovery to t1: {t1_files_recovered:?}"); assert_eq!(t1_files, t1_files_recovered); - let path2_recovered_t1 = download_to_vec(ctx.client.download(&path2).await?).await?; + let path2_recovered_t1 = download_to_vec(ctx.client.download(&path2, &cancel).await?).await?; assert_eq!(path2_recovered_t1, old_data.as_bytes()); // after recovery to t0: everything is gone except for path1 @@ -169,14 +173,14 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: ctx.client .time_travel_recover(None, t0, t_final, &cancel) .await?; - let t0_files_recovered = list_files(&ctx.client).await?; + let t0_files_recovered = list_files(&ctx.client, &cancel).await?; println!("after recovery to t0: {t0_files_recovered:?}"); assert_eq!(t0_files, t0_files_recovered); // cleanup let paths = &[path1, path2, path3]; - retry(|| ctx.client.delete_objects(paths)).await?; + retry(|| ctx.client.delete_objects(paths, &cancel)).await?; Ok(()) } @@ -197,6 +201,16 @@ impl EnabledS3 { base_prefix: BASE_PREFIX, } } + + fn configure_request_timeout(&mut self, timeout: Duration) { + match Arc::get_mut(&mut self.client).expect("outer Arc::get_mut") { + GenericRemoteStorage::AwsS3(s3) => { + let s3 = Arc::get_mut(s3).expect("inner Arc::get_mut"); + s3.timeout = timeout; + } + _ => unreachable!(), + } + } } enum MaybeEnabledStorage { @@ -370,8 +384,169 @@ fn create_s3_client( concurrency_limit: NonZeroUsize::new(100).unwrap(), max_keys_per_list_response, }), + timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, }; Ok(Arc::new( GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?, )) } + +#[test_context(MaybeEnabledStorage)] +#[tokio::test] +async fn download_is_timeouted(ctx: &mut MaybeEnabledStorage) { + let MaybeEnabledStorage::Enabled(ctx) = ctx else { + return; + }; + + let cancel = CancellationToken::new(); + + let path = RemotePath::new(Utf8Path::new( + format!("{}/file_to_copy", ctx.base_prefix).as_str(), + )) + .unwrap(); + + let len = upload_large_enough_file(&ctx.client, &path, &cancel).await; + + let timeout = std::time::Duration::from_secs(5); + + ctx.configure_request_timeout(timeout); + + let started_at = std::time::Instant::now(); + let mut stream = ctx + .client + .download(&path, &cancel) + .await + .expect("download succeeds") + .download_stream; + + if started_at.elapsed().mul_f32(0.9) >= timeout { + tracing::warn!( + elapsed_ms = started_at.elapsed().as_millis(), + "timeout might be too low, consumed most of it during headers" + ); + } + + let first = stream + .next() + .await + .expect("should have the first blob") + .expect("should have succeeded"); + + tracing::info!(len = first.len(), "downloaded first chunk"); + + assert!( + first.len() < len, + "uploaded file is too small, we downloaded all on first chunk" + ); + + tokio::time::sleep(timeout).await; + + { + let started_at = std::time::Instant::now(); + let next = stream + .next() + .await + .expect("stream should not have ended yet"); + + tracing::info!( + next.is_err = next.is_err(), + elapsed_ms = started_at.elapsed().as_millis(), + "received item after timeout" + ); + + let e = next.expect_err("expected an error, but got a chunk?"); + + let inner = e.get_ref().expect("std::io::Error::inner should be set"); + assert!( + inner + .downcast_ref::() + .is_some_and(|e| matches!(e, DownloadError::Timeout)), + "{inner:?}" + ); + } + + ctx.configure_request_timeout(RemoteStorageConfig::DEFAULT_TIMEOUT); + + ctx.client.delete_objects(&[path], &cancel).await.unwrap() +} + +#[test_context(MaybeEnabledStorage)] +#[tokio::test] +async fn download_is_cancelled(ctx: &mut MaybeEnabledStorage) { + let MaybeEnabledStorage::Enabled(ctx) = ctx else { + return; + }; + + let cancel = CancellationToken::new(); + + let path = RemotePath::new(Utf8Path::new( + format!("{}/file_to_copy", ctx.base_prefix).as_str(), + )) + .unwrap(); + + let len = upload_large_enough_file(&ctx.client, &path, &cancel).await; + + { + let mut stream = ctx + .client + .download(&path, &cancel) + .await + .expect("download succeeds") + .download_stream; + + let first = stream + .next() + .await + .expect("should have the first blob") + .expect("should have succeeded"); + + tracing::info!(len = first.len(), "downloaded first chunk"); + + assert!( + first.len() < len, + "uploaded file is too small, we downloaded all on first chunk" + ); + + cancel.cancel(); + + let next = stream.next().await.expect("stream should have more"); + + let e = next.expect_err("expected an error, but got a chunk?"); + + let inner = e.get_ref().expect("std::io::Error::inner should be set"); + assert!( + inner + .downcast_ref::() + .is_some_and(|e| matches!(e, DownloadError::Cancelled)), + "{inner:?}" + ); + } + + let cancel = CancellationToken::new(); + + ctx.client.delete_objects(&[path], &cancel).await.unwrap(); +} + +/// Upload a long enough file so that we cannot download it in single chunk +/// +/// For s3 the first chunk seems to be less than 10kB, so this has a bit of a safety margin +async fn upload_large_enough_file( + client: &GenericRemoteStorage, + path: &RemotePath, + cancel: &CancellationToken, +) -> usize { + let header = bytes::Bytes::from_static("remote blob data content".as_bytes()); + let body = bytes::Bytes::from(vec![0u8; 1024]); + let contents = std::iter::once(header).chain(std::iter::repeat(body).take(128)); + + let len = contents.clone().fold(0, |acc, next| acc + next.len()); + + let contents = futures::stream::iter(contents.map(std::io::Result::Ok)); + + client + .upload(contents, len, path, None, cancel) + .await + .expect("upload succeeds"); + + len +} diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 1989bef817..6d71ff1dd4 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -1359,6 +1359,7 @@ broker_endpoint = '{broker_endpoint}' parsed_remote_storage_config, RemoteStorageConfig { storage: RemoteStorageKind::LocalFs(local_storage_path.clone()), + timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, }, "Remote storage config should correctly parse the local FS config and fill other storage defaults" ); @@ -1426,6 +1427,7 @@ broker_endpoint = '{broker_endpoint}' concurrency_limit: s3_concurrency_limit, max_keys_per_list_response: None, }), + timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, }, "Remote storage config should correctly parse the S3 config" ); diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index 81938b14b3..62ba702db7 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -867,6 +867,7 @@ mod test { let remote_fs_dir = harness.conf.workdir.join("remote_fs").canonicalize_utf8()?; let storage_config = RemoteStorageConfig { storage: RemoteStorageKind::LocalFs(remote_fs_dir.clone()), + timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, }; let storage = GenericRemoteStorage::from_config(&storage_config).unwrap(); @@ -1170,6 +1171,7 @@ pub(crate) mod mock { pub struct ConsumerState { rx: tokio::sync::mpsc::UnboundedReceiver, executor_rx: tokio::sync::mpsc::Receiver, + cancel: CancellationToken, } impl ConsumerState { @@ -1183,7 +1185,7 @@ pub(crate) mod mock { match msg { DeleterMessage::Delete(objects) => { for path in objects { - match remote_storage.delete(&path).await { + match remote_storage.delete(&path, &self.cancel).await { Ok(_) => { debug!("Deleted {path}"); } @@ -1216,7 +1218,7 @@ pub(crate) mod mock { for path in objects { info!("Executing deletion {path}"); - match remote_storage.delete(&path).await { + match remote_storage.delete(&path, &self.cancel).await { Ok(_) => { debug!("Deleted {path}"); } @@ -1266,7 +1268,11 @@ pub(crate) mod mock { executor_tx, executed, remote_storage, - consumer: std::sync::Mutex::new(ConsumerState { rx, executor_rx }), + consumer: std::sync::Mutex::new(ConsumerState { + rx, + executor_rx, + cancel: CancellationToken::new(), + }), lsn_table: Arc::new(std::sync::RwLock::new(VisibleLsnUpdates::new())), } } diff --git a/pageserver/src/deletion_queue/deleter.rs b/pageserver/src/deletion_queue/deleter.rs index a75c73f2b1..1f04bc0410 100644 --- a/pageserver/src/deletion_queue/deleter.rs +++ b/pageserver/src/deletion_queue/deleter.rs @@ -8,6 +8,7 @@ use remote_storage::GenericRemoteStorage; use remote_storage::RemotePath; +use remote_storage::TimeoutOrCancel; use remote_storage::MAX_KEYS_PER_DELETE; use std::time::Duration; use tokio_util::sync::CancellationToken; @@ -71,9 +72,11 @@ impl Deleter { Err(anyhow::anyhow!("failpoint: deletion-queue-before-execute")) }); - self.remote_storage.delete_objects(&self.accumulator).await + self.remote_storage + .delete_objects(&self.accumulator, &self.cancel) + .await }, - |_| false, + TimeoutOrCancel::caused_by_cancel, 3, 10, "executing deletion batch", diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 88f4ae7086..e500a6123c 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -25,6 +25,7 @@ use pageserver_api::shard::ShardIdentity; use pageserver_api::shard::TenantShardId; use remote_storage::DownloadError; use remote_storage::GenericRemoteStorage; +use remote_storage::TimeoutOrCancel; use std::fmt; use storage_broker::BrokerClientChannel; use tokio::io::BufReader; @@ -3339,7 +3340,7 @@ impl Tenant { &self.cancel, ) .await - .ok_or_else(|| anyhow::anyhow!("Cancelled")) + .ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel)) .and_then(|x| x) } @@ -3389,8 +3390,10 @@ impl Tenant { ); let dest_path = &remote_initdb_archive_path(&self.tenant_shard_id.tenant_id, &timeline_id); + + // if this fails, it will get retried by retried control plane requests storage - .copy_object(source_path, dest_path) + .copy_object(source_path, dest_path, &self.cancel) .await .context("copy initdb tar")?; } @@ -4031,6 +4034,7 @@ pub(crate) mod harness { std::fs::create_dir_all(&remote_fs_dir).unwrap(); let config = RemoteStorageConfig { storage: RemoteStorageKind::LocalFs(remote_fs_dir.clone()), + timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, }; let remote_storage = GenericRemoteStorage::from_config(&config).unwrap(); let deletion_queue = MockDeletionQueue::new(Some(remote_storage.clone())); diff --git a/pageserver/src/tenant/delete.rs b/pageserver/src/tenant/delete.rs index 0e192b577c..b64be8dcc5 100644 --- a/pageserver/src/tenant/delete.rs +++ b/pageserver/src/tenant/delete.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::Context; use camino::{Utf8Path, Utf8PathBuf}; use pageserver_api::{models::TenantState, shard::TenantShardId}; -use remote_storage::{GenericRemoteStorage, RemotePath}; +use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel}; use tokio::sync::OwnedMutexGuard; use tokio_util::sync::CancellationToken; use tracing::{error, instrument, Instrument}; @@ -84,17 +84,17 @@ async fn create_remote_delete_mark( let data = bytes::Bytes::from_static(data); let stream = futures::stream::once(futures::future::ready(Ok(data))); remote_storage - .upload(stream, 0, &remote_mark_path, None) + .upload(stream, 0, &remote_mark_path, None, cancel) .await }, - |_e| false, + TimeoutOrCancel::caused_by_cancel, FAILED_UPLOAD_WARN_THRESHOLD, FAILED_REMOTE_OP_RETRIES, "mark_upload", cancel, ) .await - .ok_or_else(|| anyhow::anyhow!("Cancelled")) + .ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel)) .and_then(|x| x) .context("mark_upload")?; @@ -184,15 +184,15 @@ async fn remove_tenant_remote_delete_mark( if let Some(remote_storage) = remote_storage { let path = remote_tenant_delete_mark_path(conf, tenant_shard_id)?; backoff::retry( - || async { remote_storage.delete(&path).await }, - |_e| false, + || async { remote_storage.delete(&path, cancel).await }, + TimeoutOrCancel::caused_by_cancel, FAILED_UPLOAD_WARN_THRESHOLD, FAILED_REMOTE_OP_RETRIES, "remove_tenant_remote_delete_mark", cancel, ) .await - .ok_or_else(|| anyhow::anyhow!("Cancelled")) + .ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel)) .and_then(|x| x) .context("remove_tenant_remote_delete_mark")?; } diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index 483f53d5c8..91e1179e53 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -196,14 +196,12 @@ pub(crate) use upload::upload_initdb_dir; use utils::backoff::{ self, exponential_backoff, DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS, }; -use utils::timeout::{timeout_cancellable, TimeoutCancellableError}; use std::collections::{HashMap, VecDeque}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; -use std::time::Duration; -use remote_storage::{DownloadError, GenericRemoteStorage, RemotePath}; +use remote_storage::{DownloadError, GenericRemoteStorage, RemotePath, TimeoutOrCancel}; use std::ops::DerefMut; use tracing::{debug, error, info, instrument, warn}; use tracing::{info_span, Instrument}; @@ -263,11 +261,6 @@ pub(crate) const INITDB_PRESERVED_PATH: &str = "initdb-preserved.tar.zst"; /// Default buffer size when interfacing with [`tokio::fs::File`]. pub(crate) const BUFFER_SIZE: usize = 32 * 1024; -/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not -/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that. -pub(crate) const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120); -pub(crate) const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120); - pub enum MaybeDeletedIndexPart { IndexPart(IndexPart), Deleted(IndexPart), @@ -331,40 +324,6 @@ pub struct RemoteTimelineClient { cancel: CancellationToken, } -/// Wrapper for timeout_cancellable that flattens result and converts TimeoutCancellableError to anyhow. -/// -/// This is a convenience for the various upload functions. In future -/// the anyhow::Error result should be replaced with a more structured type that -/// enables callers to avoid handling shutdown as an error. -async fn upload_cancellable(cancel: &CancellationToken, future: F) -> anyhow::Result<()> -where - F: std::future::Future>, -{ - match timeout_cancellable(UPLOAD_TIMEOUT, cancel, future).await { - Ok(Ok(())) => Ok(()), - Ok(Err(e)) => Err(e), - Err(TimeoutCancellableError::Timeout) => Err(anyhow::anyhow!("Timeout")), - Err(TimeoutCancellableError::Cancelled) => Err(anyhow::anyhow!("Shutting down")), - } -} -/// Wrapper for timeout_cancellable that flattens result and converts TimeoutCancellableError to DownloaDError. -async fn download_cancellable( - cancel: &CancellationToken, - future: F, -) -> Result -where - F: std::future::Future>, -{ - match timeout_cancellable(DOWNLOAD_TIMEOUT, cancel, future).await { - Ok(Ok(r)) => Ok(r), - Ok(Err(e)) => Err(e), - Err(TimeoutCancellableError::Timeout) => { - Err(DownloadError::Other(anyhow::anyhow!("Timed out"))) - } - Err(TimeoutCancellableError::Cancelled) => Err(DownloadError::Cancelled), - } -} - impl RemoteTimelineClient { /// /// Create a remote storage client for given timeline @@ -1050,7 +1009,7 @@ impl RemoteTimelineClient { &self.cancel, ) .await - .ok_or_else(|| anyhow::anyhow!("Cancelled")) + .ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel)) .and_then(|x| x)?; // all good, disarm the guard and mark as success @@ -1082,14 +1041,14 @@ impl RemoteTimelineClient { upload::preserve_initdb_archive(&self.storage_impl, tenant_id, timeline_id, cancel) .await }, - |_e| false, + TimeoutOrCancel::caused_by_cancel, FAILED_DOWNLOAD_WARN_THRESHOLD, FAILED_REMOTE_OP_RETRIES, "preserve_initdb_tar_zst", &cancel.clone(), ) .await - .ok_or_else(|| anyhow::anyhow!("Cancellled")) + .ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel)) .and_then(|x| x) .context("backing up initdb archive")?; Ok(()) @@ -1151,7 +1110,7 @@ impl RemoteTimelineClient { let remaining = download_retry( || async { self.storage_impl - .list_files(Some(&timeline_storage_path), None) + .list_files(Some(&timeline_storage_path), None, &cancel) .await }, "list remaining files", @@ -1445,6 +1404,10 @@ impl RemoteTimelineClient { Ok(()) => { break; } + Err(e) if TimeoutOrCancel::caused_by_cancel(&e) => { + // loop around to do the proper stopping + continue; + } Err(e) => { let retries = task.retries.fetch_add(1, Ordering::SeqCst); diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index e755cd08f3..43f5e6c182 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -11,16 +11,14 @@ use camino::{Utf8Path, Utf8PathBuf}; use pageserver_api::shard::TenantShardId; use tokio::fs::{self, File, OpenOptions}; use tokio::io::{AsyncSeekExt, AsyncWriteExt}; +use tokio_util::io::StreamReader; use tokio_util::sync::CancellationToken; use tracing::warn; -use utils::timeout::timeout_cancellable; use utils::{backoff, crashsafe}; use crate::config::PageServerConf; use crate::span::debug_assert_current_span_has_tenant_and_timeline_id; -use crate::tenant::remote_timeline_client::{ - download_cancellable, remote_layer_path, remote_timelines_path, DOWNLOAD_TIMEOUT, -}; +use crate::tenant::remote_timeline_client::{remote_layer_path, remote_timelines_path}; use crate::tenant::storage_layer::LayerFileName; use crate::tenant::Generation; use crate::virtual_file::on_fatal_io_error; @@ -83,15 +81,13 @@ pub async fn download_layer_file<'a>( .with_context(|| format!("create a destination file for layer '{temp_file_path}'")) .map_err(DownloadError::Other)?; - // Cancellation safety: it is safe to cancel this future, because it isn't writing to a local - // file: the write to local file doesn't start until after the request header is returned - // and we start draining the body stream below - let download = download_cancellable(cancel, storage.download(&remote_path)) + let download = storage + .download(&remote_path, cancel) .await .with_context(|| { format!( - "open a download stream for layer with remote storage path '{remote_path:?}'" - ) + "open a download stream for layer with remote storage path '{remote_path:?}'" + ) }) .map_err(DownloadError::Other)?; @@ -100,43 +96,26 @@ pub async fn download_layer_file<'a>( let mut reader = tokio_util::io::StreamReader::new(download.download_stream); - // Cancellation safety: it is safe to cancel this future because it is writing into a temporary file, - // and we will unlink the temporary file if there is an error. This unlink is important because we - // are in a retry loop, and we wouldn't want to leave behind a rogue write I/O to a file that - // we will imminiently try and write to again. - let bytes_amount: u64 = match timeout_cancellable( - DOWNLOAD_TIMEOUT, - cancel, - tokio::io::copy_buf(&mut reader, &mut destination_file), - ) - .await - .with_context(|| { - format!( + let bytes_amount = tokio::io::copy_buf(&mut reader, &mut destination_file) + .await + .with_context(|| format!( "download layer at remote path '{remote_path:?}' into file {temp_file_path:?}" - ) - }) - .map_err(DownloadError::Other)? - { - Ok(b) => Ok(b), + )) + .map_err(DownloadError::Other); + + match bytes_amount { + Ok(bytes_amount) => { + let destination_file = destination_file.into_inner(); + Ok((destination_file, bytes_amount)) + } Err(e) => { - // Remove incomplete files: on restart Timeline would do this anyway, but we must - // do it here for the retry case. if let Err(e) = tokio::fs::remove_file(&temp_file_path).await { on_fatal_io_error(&e, &format!("Removing temporary file {temp_file_path}")); } + Err(e) } } - .with_context(|| { - format!( - "download layer at remote path '{remote_path:?}' into file {temp_file_path:?}" - ) - }) - .map_err(DownloadError::Other)?; - - let destination_file = destination_file.into_inner(); - - Ok((destination_file, bytes_amount)) }, &format!("download {remote_path:?}"), cancel, @@ -218,9 +197,11 @@ pub async fn list_remote_timelines( let listing = download_retry_forever( || { - download_cancellable( + storage.list( + Some(&remote_path), + ListingMode::WithDelimiter, + None, &cancel, - storage.list(Some(&remote_path), ListingMode::WithDelimiter, None), ) }, &format!("list timelines for {tenant_shard_id}"), @@ -259,26 +240,23 @@ async fn do_download_index_part( index_generation: Generation, cancel: &CancellationToken, ) -> Result { - use futures::stream::StreamExt; - let remote_path = remote_index_path(tenant_shard_id, timeline_id, index_generation); let index_part_bytes = download_retry_forever( || async { - // Cancellation: if is safe to cancel this future because we're just downloading into - // a memory buffer, not touching local disk. - let index_part_download = - download_cancellable(cancel, storage.download(&remote_path)).await?; + let download = storage.download(&remote_path, cancel).await?; - let mut index_part_bytes = Vec::new(); - let mut stream = std::pin::pin!(index_part_download.download_stream); - while let Some(chunk) = stream.next().await { - let chunk = chunk - .with_context(|| format!("download index part at {remote_path:?}")) - .map_err(DownloadError::Other)?; - index_part_bytes.extend_from_slice(&chunk[..]); - } - Ok(index_part_bytes) + let mut bytes = Vec::new(); + + let stream = download.download_stream; + let mut stream = StreamReader::new(stream); + + tokio::io::copy_buf(&mut stream, &mut bytes) + .await + .with_context(|| format!("download index part at {remote_path:?}")) + .map_err(DownloadError::Other)?; + + Ok(bytes) }, &format!("download {remote_path:?}"), cancel, @@ -373,7 +351,7 @@ pub(super) async fn download_index_part( let index_prefix = remote_index_path(tenant_shard_id, timeline_id, Generation::none()); let indices = download_retry( - || async { storage.list_files(Some(&index_prefix), None).await }, + || async { storage.list_files(Some(&index_prefix), None, cancel).await }, "list index_part files", cancel, ) @@ -446,11 +424,10 @@ pub(crate) async fn download_initdb_tar_zst( .with_context(|| format!("tempfile creation {temp_path}")) .map_err(DownloadError::Other)?; - let download = match download_cancellable(cancel, storage.download(&remote_path)).await - { + let download = match storage.download(&remote_path, cancel).await { Ok(dl) => dl, Err(DownloadError::NotFound) => { - download_cancellable(cancel, storage.download(&remote_preserved_path)).await? + storage.download(&remote_preserved_path, cancel).await? } Err(other) => Err(other)?, }; @@ -460,6 +437,7 @@ pub(crate) async fn download_initdb_tar_zst( // TODO: this consumption of the response body should be subject to timeout + cancellation, but // not without thinking carefully about how to recover safely from cancelling a write to // local storage (e.g. by writing into a temp file as we do in download_layer) + // FIXME: flip the weird error wrapping tokio::io::copy_buf(&mut download, &mut writer) .await .with_context(|| format!("download initdb.tar.zst at {remote_path:?}")) diff --git a/pageserver/src/tenant/remote_timeline_client/upload.rs b/pageserver/src/tenant/remote_timeline_client/upload.rs index c17e27b446..137fe48b73 100644 --- a/pageserver/src/tenant/remote_timeline_client/upload.rs +++ b/pageserver/src/tenant/remote_timeline_client/upload.rs @@ -16,7 +16,7 @@ use crate::{ config::PageServerConf, tenant::remote_timeline_client::{ index::IndexPart, remote_index_path, remote_initdb_archive_path, - remote_initdb_preserved_archive_path, remote_path, upload_cancellable, + remote_initdb_preserved_archive_path, remote_path, }, }; use remote_storage::{GenericRemoteStorage, TimeTravelError}; @@ -49,16 +49,15 @@ pub(crate) async fn upload_index_part<'a>( let index_part_bytes = bytes::Bytes::from(index_part_bytes); let remote_path = remote_index_path(tenant_shard_id, timeline_id, generation); - upload_cancellable( - cancel, - storage.upload_storage_object( + storage + .upload_storage_object( futures::stream::once(futures::future::ready(Ok(index_part_bytes))), index_part_size, &remote_path, - ), - ) - .await - .with_context(|| format!("upload index part for '{tenant_shard_id} / {timeline_id}'")) + cancel, + ) + .await + .with_context(|| format!("upload index part for '{tenant_shard_id} / {timeline_id}'")) } /// Attempts to upload given layer files. @@ -115,11 +114,10 @@ pub(super) async fn upload_timeline_layer<'a>( let reader = tokio_util::io::ReaderStream::with_capacity(source_file, super::BUFFER_SIZE); - upload_cancellable(cancel, storage.upload(reader, fs_size, &storage_path, None)) + storage + .upload(reader, fs_size, &storage_path, None, cancel) .await - .with_context(|| format!("upload layer from local path '{source_path}'"))?; - - Ok(()) + .with_context(|| format!("upload layer from local path '{source_path}'")) } /// Uploads the given `initdb` data to the remote storage. @@ -139,12 +137,10 @@ pub(crate) async fn upload_initdb_dir( let file = tokio_util::io::ReaderStream::with_capacity(initdb_tar_zst, super::BUFFER_SIZE); let remote_path = remote_initdb_archive_path(tenant_id, timeline_id); - upload_cancellable( - cancel, - storage.upload_storage_object(file, size as usize, &remote_path), - ) - .await - .with_context(|| format!("upload initdb dir for '{tenant_id} / {timeline_id}'")) + storage + .upload_storage_object(file, size as usize, &remote_path, cancel) + .await + .with_context(|| format!("upload initdb dir for '{tenant_id} / {timeline_id}'")) } pub(crate) async fn preserve_initdb_archive( @@ -155,7 +151,8 @@ pub(crate) async fn preserve_initdb_archive( ) -> anyhow::Result<()> { let source_path = remote_initdb_archive_path(tenant_id, timeline_id); let dest_path = remote_initdb_preserved_archive_path(tenant_id, timeline_id); - upload_cancellable(cancel, storage.copy_object(&source_path, &dest_path)) + storage + .copy_object(&source_path, &dest_path, cancel) .await .with_context(|| format!("backing up initdb archive for '{tenant_id} / {timeline_id}'")) } diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index c23416a7f0..6966cf7709 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -523,12 +523,13 @@ impl<'a> TenantDownloader<'a> { tracing::debug!("Downloading heatmap for secondary tenant",); let heatmap_path = remote_heatmap_path(tenant_shard_id); + let cancel = &self.secondary_state.cancel; let heatmap_bytes = backoff::retry( || async { let download = self .remote_storage - .download(&heatmap_path) + .download(&heatmap_path, cancel) .await .map_err(UpdateError::from)?; let mut heatmap_bytes = Vec::new(); @@ -540,7 +541,7 @@ impl<'a> TenantDownloader<'a> { FAILED_DOWNLOAD_WARN_THRESHOLD, FAILED_REMOTE_OP_RETRIES, "download heatmap", - &self.secondary_state.cancel, + cancel, ) .await .ok_or_else(|| UpdateError::Cancelled) diff --git a/pageserver/src/tenant/secondary/heatmap_uploader.rs b/pageserver/src/tenant/secondary/heatmap_uploader.rs index 806e3fb0e8..660459a733 100644 --- a/pageserver/src/tenant/secondary/heatmap_uploader.rs +++ b/pageserver/src/tenant/secondary/heatmap_uploader.rs @@ -21,18 +21,17 @@ use futures::Future; use md5; use pageserver_api::shard::TenantShardId; use rand::Rng; -use remote_storage::GenericRemoteStorage; +use remote_storage::{GenericRemoteStorage, TimeoutOrCancel}; use super::{ + heatmap::HeatMapTenant, scheduler::{self, JobGenerator, RunningJob, SchedulingResult, TenantBackgroundJobs}, - CommandRequest, + CommandRequest, UploadCommand, }; use tokio_util::sync::CancellationToken; use tracing::{info_span, instrument, Instrument}; use utils::{backoff, completion::Barrier, yielding_loop::yielding_loop}; -use super::{heatmap::HeatMapTenant, UploadCommand}; - pub(super) async fn heatmap_uploader_task( tenant_manager: Arc, remote_storage: GenericRemoteStorage, @@ -417,10 +416,10 @@ async fn upload_tenant_heatmap( || async { let bytes = futures::stream::once(futures::future::ready(Ok(bytes.clone()))); remote_storage - .upload_storage_object(bytes, size, &path) + .upload_storage_object(bytes, size, &path, cancel) .await }, - |_| false, + TimeoutOrCancel::caused_by_cancel, 3, u32::MAX, "Uploading heatmap", diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index ad22829183..d941445c2d 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -13,7 +13,7 @@ use parquet::{ }, record::RecordWriter, }; -use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig}; +use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel}; use tokio::{sync::mpsc, time}; use tokio_util::sync::CancellationToken; use tracing::{debug, info, Span}; @@ -314,20 +314,23 @@ async fn upload_parquet( let path = RemotePath::from_string(&format!( "{year:04}/{month:02}/{day:02}/{hour:02}/requests_{id}.parquet" ))?; + let cancel = CancellationToken::new(); backoff::retry( || async { let stream = futures::stream::once(futures::future::ready(Ok(data.clone()))); - storage.upload(stream, data.len(), &path, None).await + storage + .upload(stream, data.len(), &path, None, &cancel) + .await }, - |_e| false, + TimeoutOrCancel::caused_by_cancel, FAILED_UPLOAD_WARN_THRESHOLD, FAILED_UPLOAD_MAX_RETRIES, "request_data_upload", // we don't want cancellation to interrupt here, so we make a dummy cancel token - &CancellationToken::new(), + &cancel, ) .await - .ok_or_else(|| anyhow::anyhow!("Cancelled")) + .ok_or_else(|| anyhow::Error::new(TimeoutOrCancel::Cancel)) .and_then(|x| x) .context("request_data_upload")?; @@ -413,7 +416,8 @@ mod tests { ) .unwrap(), max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, - }) + }), + timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, }) ); assert_eq!(parquet_upload.parquet_upload_row_group_size, 100); @@ -466,6 +470,7 @@ mod tests { ) -> Vec<(u64, usize, i64)> { let remote_storage_config = RemoteStorageConfig { storage: RemoteStorageKind::LocalFs(tmpdir.to_path_buf()), + timeout: std::time::Duration::from_secs(120), }; let storage = GenericRemoteStorage::from_config(&remote_storage_config).unwrap(); diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index dbdc742d26..944d80f777 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -511,7 +511,11 @@ async fn backup_object( let file = tokio_util::io::ReaderStream::with_capacity(file, BUFFER_SIZE); - storage.upload_storage_object(file, size, target_file).await + let cancel = CancellationToken::new(); + + storage + .upload_storage_object(file, size, target_file, &cancel) + .await } pub async fn read_object( @@ -526,8 +530,10 @@ pub async fn read_object( info!("segment download about to start from remote path {file_path:?} at offset {offset}"); + let cancel = CancellationToken::new(); + let download = storage - .download_storage_object(Some((offset, None)), file_path) + .download_storage_object(Some((offset, None)), file_path, &cancel) .await .with_context(|| { format!("Failed to open WAL segment download stream for remote path {file_path:?}") @@ -559,7 +565,8 @@ pub async fn delete_timeline(ttid: &TenantTimelineId) -> Result<()> { // Note: listing segments might take a long time if there are many of them. // We don't currently have http requests timeout cancellation, but if/once // we have listing should get streaming interface to make progress. - let token = CancellationToken::new(); // not really used + + let cancel = CancellationToken::new(); // not really used backoff::retry( || async { // Do list-delete in batch_size batches to make progress even if there a lot of files. @@ -567,7 +574,7 @@ pub async fn delete_timeline(ttid: &TenantTimelineId) -> Result<()> { // I'm not sure deleting while iterating is expected in s3. loop { let files = storage - .list_files(Some(&remote_path), Some(batch_size)) + .list_files(Some(&remote_path), Some(batch_size), &cancel) .await?; if files.is_empty() { return Ok(()); // done @@ -580,14 +587,15 @@ pub async fn delete_timeline(ttid: &TenantTimelineId) -> Result<()> { files.first().unwrap().object_name().unwrap_or(""), files.last().unwrap().object_name().unwrap_or("") ); - storage.delete_objects(&files).await?; + storage.delete_objects(&files, &cancel).await?; } }, + // consider TimeoutOrCancel::caused_by_cancel when using cancellation |_| false, 3, 10, "executing WAL segments deletion batch", - &token, + &cancel, ) .await .ok_or_else(|| anyhow::anyhow!("canceled")) @@ -617,7 +625,12 @@ pub async fn copy_s3_segments( let remote_path = RemotePath::new(&relative_dst_path)?; - let files = storage.list_files(Some(&remote_path), None).await?; + let cancel = CancellationToken::new(); + + let files = storage + .list_files(Some(&remote_path), None, &cancel) + .await?; + let uploaded_segments = &files .iter() .filter_map(|file| file.object_name().map(ToOwned::to_owned)) @@ -645,7 +658,7 @@ pub async fn copy_s3_segments( let from = RemotePath::new(&relative_src_path.join(&segment_name))?; let to = RemotePath::new(&relative_dst_path.join(&segment_name))?; - storage.copy_object(&from, &to).await?; + storage.copy_object(&from, &to, &cancel).await?; } info!( From 5fa747e493bbbcc6878c03742c5a63622ec31165 Mon Sep 17 00:00:00 2001 From: John Spray Date: Thu, 15 Feb 2024 08:21:53 +0000 Subject: [PATCH 75/81] pageserver: shard splitting refinements (parent deletion, hard linking) (#6725) ## Problem - We weren't deleting parent shard contents once the split was done - Re-downloading layers into child shards is wasteful ## Summary of changes - Hard-link layers into child chart local storage during split - Delete parent shards content at the end --------- Co-authored-by: Joonas Koivunen --- pageserver/src/tenant/mgr.rs | 154 ++++++++++++++++++++++++++- test_runner/regress/test_sharding.py | 15 +++ 2 files changed, 165 insertions(+), 4 deletions(-) diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 9aee39bd35..7260080720 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -2,6 +2,7 @@ //! page server. use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf}; +use futures::stream::StreamExt; use itertools::Itertools; use pageserver_api::key::Key; use pageserver_api::models::ShardParameters; @@ -1439,8 +1440,10 @@ impl TenantManager { } }; - // TODO: hardlink layers from the parent into the child shard directories so that they don't immediately re-download - // TODO: erase the dentries from the parent + // Optimization: hardlink layers from the parent into the children, so that they don't have to + // re-download & duplicate the data referenced in their initial IndexPart + self.shard_split_hardlink(parent, child_shards.clone()) + .await?; // Take a snapshot of where the parent's WAL ingest had got to: we will wait for // child shards to reach this point. @@ -1479,10 +1482,11 @@ impl TenantManager { // Phase 4: wait for child chards WAL ingest to catch up to target LSN for child_shard_id in &child_shards { + let child_shard_id = *child_shard_id; let child_shard = { let locked = TENANTS.read().unwrap(); let peek_slot = - tenant_map_peek_slot(&locked, child_shard_id, TenantSlotPeekMode::Read)?; + tenant_map_peek_slot(&locked, &child_shard_id, TenantSlotPeekMode::Read)?; peek_slot.and_then(|s| s.get_attached()).cloned() }; if let Some(t) = child_shard { @@ -1517,7 +1521,7 @@ impl TenantManager { } } - // Phase 5: Shut down the parent shard. + // Phase 5: Shut down the parent shard, and erase it from disk let (_guard, progress) = completion::channel(); match parent.shutdown(progress, false).await { Ok(()) => {} @@ -1525,6 +1529,24 @@ impl TenantManager { other.wait().await; } } + let local_tenant_directory = self.conf.tenant_path(&tenant_shard_id); + let tmp_path = safe_rename_tenant_dir(&local_tenant_directory) + .await + .with_context(|| format!("local tenant directory {local_tenant_directory:?} rename"))?; + task_mgr::spawn( + task_mgr::BACKGROUND_RUNTIME.handle(), + TaskKind::MgmtRequest, + None, + None, + "tenant_files_delete", + false, + async move { + fs::remove_dir_all(tmp_path.as_path()) + .await + .with_context(|| format!("tenant directory {:?} deletion", tmp_path)) + }, + ); + parent_slot_guard.drop_old_value()?; // Phase 6: Release the InProgress on the parent shard @@ -1532,6 +1554,130 @@ impl TenantManager { Ok(child_shards) } + + /// Part of [`Self::shard_split`]: hard link parent shard layers into child shards, as an optimization + /// to avoid the children downloading them again. + /// + /// For each resident layer in the parent shard, we will hard link it into all of the child shards. + async fn shard_split_hardlink( + &self, + parent_shard: &Tenant, + child_shards: Vec, + ) -> anyhow::Result<()> { + debug_assert_current_span_has_tenant_id(); + + let parent_path = self.conf.tenant_path(parent_shard.get_tenant_shard_id()); + let (parent_timelines, parent_layers) = { + let mut parent_layers = Vec::new(); + let timelines = parent_shard.timelines.lock().unwrap().clone(); + let parent_timelines = timelines.keys().cloned().collect::>(); + for timeline in timelines.values() { + let timeline_layers = timeline + .layers + .read() + .await + .resident_layers() + .collect::>() + .await; + for layer in timeline_layers { + let relative_path = layer + .local_path() + .strip_prefix(&parent_path) + .context("Removing prefix from parent layer path")?; + parent_layers.push(relative_path.to_owned()); + } + } + debug_assert!( + !parent_layers.is_empty(), + "shutdown cannot empty the layermap" + ); + (parent_timelines, parent_layers) + }; + + let mut child_prefixes = Vec::new(); + let mut create_dirs = Vec::new(); + + for child in child_shards { + let child_prefix = self.conf.tenant_path(&child); + create_dirs.push(child_prefix.clone()); + create_dirs.extend( + parent_timelines + .iter() + .map(|t| self.conf.timeline_path(&child, t)), + ); + + child_prefixes.push(child_prefix); + } + + // Since we will do a large number of small filesystem metadata operations, batch them into + // spawn_blocking calls rather than doing each one as a tokio::fs round-trip. + let jh = tokio::task::spawn_blocking(move || -> anyhow::Result { + for dir in &create_dirs { + if let Err(e) = std::fs::create_dir_all(dir) { + // Ignore AlreadyExists errors, drop out on all other errors + match e.kind() { + std::io::ErrorKind::AlreadyExists => {} + _ => { + return Err(anyhow::anyhow!(e).context(format!("Creating {dir}"))); + } + } + } + } + + for child_prefix in child_prefixes { + for relative_layer in &parent_layers { + let parent_path = parent_path.join(relative_layer); + let child_path = child_prefix.join(relative_layer); + if let Err(e) = std::fs::hard_link(&parent_path, &child_path) { + match e.kind() { + std::io::ErrorKind::AlreadyExists => {} + std::io::ErrorKind::NotFound => { + tracing::info!( + "Layer {} not found during hard-linking, evicted during split?", + relative_layer + ); + } + _ => { + return Err(anyhow::anyhow!(e).context(format!( + "Hard linking {relative_layer} into {child_prefix}" + ))) + } + } + } + } + } + + // Durability is not required for correctness, but if we crashed during split and + // then came restarted with empty timeline dirs, it would be very inefficient to + // re-populate from remote storage. + for dir in create_dirs { + if let Err(e) = crashsafe::fsync(&dir) { + // Something removed a newly created timeline dir out from underneath us? Extremely + // unexpected, but not worth panic'ing over as this whole function is just an + // optimization. + tracing::warn!("Failed to fsync directory {dir}: {e}") + } + } + + Ok(parent_layers.len()) + }); + + match jh.await { + Ok(Ok(layer_count)) => { + tracing::info!(count = layer_count, "Hard linked layers into child shards"); + } + Ok(Err(e)) => { + // This is an optimization, so we tolerate failure. + tracing::warn!("Error hard-linking layers, proceeding anyway: {e}") + } + Err(e) => { + // This is something totally unexpected like a panic, so bail out. + anyhow::bail!("Error joining hard linking task: {e}"); + } + } + + Ok(()) + } } #[derive(Debug, thiserror::Error)] diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index fa40219d0e..fcf4b9f72a 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -194,6 +194,18 @@ def test_sharding_split_smoke( assert len(pre_split_pageserver_ids) == 4 + def shards_on_disk(shard_ids): + for pageserver in env.pageservers: + for shard_id in shard_ids: + if pageserver.tenant_dir(shard_id).exists(): + return True + + return False + + old_shard_ids = [TenantShardId(tenant_id, i, shard_count) for i in range(0, shard_count)] + # Before split, old shards exist + assert shards_on_disk(old_shard_ids) + env.attachment_service.tenant_shard_split(tenant_id, shard_count=split_shard_count) post_split_pageserver_ids = [loc["node_id"] for loc in env.attachment_service.locate(tenant_id)] @@ -202,6 +214,9 @@ def test_sharding_split_smoke( assert len(set(post_split_pageserver_ids)) == shard_count assert set(post_split_pageserver_ids) == set(pre_split_pageserver_ids) + # The old parent shards should no longer exist on disk + assert not shards_on_disk(old_shard_ids) + workload.validate() workload.churn_rows(256) From 1af047dd3ee9eed0de955b61c295142a95a3fde4 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Thu, 15 Feb 2024 14:34:19 +0200 Subject: [PATCH 76/81] Fix typo in CI message (#6749) --- .github/workflows/build_and_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 6e4020a1b8..c53cbada7d 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -253,7 +253,7 @@ jobs: done if [ "${FAILED}" = "true" ]; then - echo >&2 "Please update vendors/revisions.json if these changes are intentional" + echo >&2 "Please update vendor/revisions.json if these changes are intentional" exit 1 fi From 936f2ee2a59af86a76df29f0fd6693d1a61da0f7 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Thu, 15 Feb 2024 15:48:44 +0200 Subject: [PATCH 77/81] fix: accidential wide span in tests (#6772) introduced in a PR without other #[tracing::instrument] changes. --- pageserver/src/tenant.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index e500a6123c..fdf04244c3 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -3276,7 +3276,7 @@ impl Tenant { /// For unit tests, make this visible so that other modules can directly create timelines #[cfg(test)] - #[tracing::instrument(fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), %timeline_id))] + #[tracing::instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), %timeline_id))] pub(crate) async fn bootstrap_timeline_test( &self, timeline_id: TimelineId, From 9ad940086cebd02041142117a76914bc5120c060 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Thu, 15 Feb 2024 09:59:13 -0500 Subject: [PATCH 78/81] fix superuser permission check for extensions (#6733) close https://github.com/neondatabase/neon/issues/6236 This pull request bumps neon postgres dependencies. The corresponding postgres commits fix the checks for superuser permission when creating an extension. Also, for creating native functinos, it now allows neon_superuser only in the extension creation process. --------- Signed-off-by: Alex Chi Z Co-authored-by: Heikki Linnakangas --- vendor/postgres-v14 | 2 +- vendor/postgres-v15 | 2 +- vendor/postgres-v16 | 2 +- vendor/revisions.json | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 9dd9956c55..b4bae26a0f 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 9dd9956c55ffbbd9abe77d10382453757fedfcf5 +Subproject commit b4bae26a0f09c69e979e6cb55780398e3102e022 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index ca2def9993..9eef016e18 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit ca2def999368d9df098a637234ad5a9003189463 +Subproject commit 9eef016e18bf61753e3cbaa755f705db6a4f7b1d diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 9c37a49884..f7b63d8cf9 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 9c37a4988463a97d9cacb321acf3828b09823269 +Subproject commit f7b63d8cf9ae040f6907c3c13ef25fcf15a36161 diff --git a/vendor/revisions.json b/vendor/revisions.json index 72bc0d7e0d..37ca812c4a 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,5 +1,5 @@ { - "postgres-v16": "9c37a4988463a97d9cacb321acf3828b09823269", - "postgres-v15": "ca2def999368d9df098a637234ad5a9003189463", - "postgres-v14": "9dd9956c55ffbbd9abe77d10382453757fedfcf5" + "postgres-v16": "f7b63d8cf9ae040f6907c3c13ef25fcf15a36161", + "postgres-v15": "9eef016e18bf61753e3cbaa755f705db6a4f7b1d", + "postgres-v14": "b4bae26a0f09c69e979e6cb55780398e3102e022" } From cd3e4ac18d1f6998325855d0f9b7b194a10676cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Thu, 15 Feb 2024 16:14:51 +0100 Subject: [PATCH 79/81] Rename TEST_IMG function to test_img (#6762) Latter follows the canonical way to naming functions in Rust. --- pageserver/src/tenant.rs | 64 ++++++++++++++++++------------------- pageserver/src/walingest.rs | 54 +++++++++++++++---------------- 2 files changed, 58 insertions(+), 60 deletions(-) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index fdf04244c3..ced4bb5af4 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -3933,8 +3933,7 @@ pub(crate) mod harness { TimelineId::from_array(hex!("AA223344556677881122334455667788")); /// Convenience function to create a page image with given string as the only content - #[allow(non_snake_case)] - pub fn TEST_IMG(s: &str) -> Bytes { + pub fn test_img(s: &str) -> Bytes { let mut buf = BytesMut::new(); buf.extend_from_slice(s.as_bytes()); buf.resize(64, 0); @@ -4179,7 +4178,6 @@ pub(crate) mod harness { _pg_version: u32, ) -> anyhow::Result { let records_neon = records.iter().all(|r| apply_neon::can_apply_in_neon(&r.1)); - if records_neon { // For Neon wal records, we can decode without spawning postgres, so do so. let base_img = base_img.expect("Neon WAL redo requires base image").1; @@ -4204,7 +4202,7 @@ pub(crate) mod harness { ); println!("{s}"); - Ok(TEST_IMG(&s)) + Ok(test_img(&s)) } } } @@ -4239,7 +4237,7 @@ mod tests { .put( *TEST_KEY, Lsn(0x10), - &Value::Image(TEST_IMG("foo at 0x10")), + &Value::Image(test_img("foo at 0x10")), &ctx, ) .await?; @@ -4251,7 +4249,7 @@ mod tests { .put( *TEST_KEY, Lsn(0x20), - &Value::Image(TEST_IMG("foo at 0x20")), + &Value::Image(test_img("foo at 0x20")), &ctx, ) .await?; @@ -4260,15 +4258,15 @@ mod tests { assert_eq!( tline.get(*TEST_KEY, Lsn(0x10), &ctx).await?, - TEST_IMG("foo at 0x10") + test_img("foo at 0x10") ); assert_eq!( tline.get(*TEST_KEY, Lsn(0x1f), &ctx).await?, - TEST_IMG("foo at 0x10") + test_img("foo at 0x10") ); assert_eq!( tline.get(*TEST_KEY, Lsn(0x20), &ctx).await?, - TEST_IMG("foo at 0x20") + test_img("foo at 0x20") ); Ok(()) @@ -4384,7 +4382,7 @@ mod tests { .put( *TEST_KEY, lsn, - &Value::Image(TEST_IMG(&format!("foo at {}", lsn))), + &Value::Image(test_img(&format!("foo at {}", lsn))), ctx, ) .await?; @@ -4394,7 +4392,7 @@ mod tests { .put( *TEST_KEY, lsn, - &Value::Image(TEST_IMG(&format!("foo at {}", lsn))), + &Value::Image(test_img(&format!("foo at {}", lsn))), ctx, ) .await?; @@ -4408,7 +4406,7 @@ mod tests { .put( *TEST_KEY, lsn, - &Value::Image(TEST_IMG(&format!("foo at {}", lsn))), + &Value::Image(test_img(&format!("foo at {}", lsn))), ctx, ) .await?; @@ -4418,7 +4416,7 @@ mod tests { .put( *TEST_KEY, lsn, - &Value::Image(TEST_IMG(&format!("foo at {}", lsn))), + &Value::Image(test_img(&format!("foo at {}", lsn))), ctx, ) .await?; @@ -4573,7 +4571,7 @@ mod tests { // Broken, as long as you don't need to access data from the parent. assert_eq!( newtline.get(*TEST_KEY, Lsn(0x70), &ctx).await?, - TEST_IMG(&format!("foo at {}", Lsn(0x70))) + test_img(&format!("foo at {}", Lsn(0x70))) ); // This needs to traverse to the parent, and fails. @@ -4650,7 +4648,7 @@ mod tests { // Check that the data is still accessible on the branch. assert_eq!( newtline.get(*TEST_KEY, Lsn(0x50), &ctx).await?, - TEST_IMG(&format!("foo at {}", Lsn(0x40))) + test_img(&format!("foo at {}", Lsn(0x40))) ); Ok(()) @@ -4825,7 +4823,7 @@ mod tests { .put( *TEST_KEY, Lsn(0x10), - &Value::Image(TEST_IMG("foo at 0x10")), + &Value::Image(test_img("foo at 0x10")), &ctx, ) .await?; @@ -4842,7 +4840,7 @@ mod tests { .put( *TEST_KEY, Lsn(0x20), - &Value::Image(TEST_IMG("foo at 0x20")), + &Value::Image(test_img("foo at 0x20")), &ctx, ) .await?; @@ -4859,7 +4857,7 @@ mod tests { .put( *TEST_KEY, Lsn(0x30), - &Value::Image(TEST_IMG("foo at 0x30")), + &Value::Image(test_img("foo at 0x30")), &ctx, ) .await?; @@ -4876,7 +4874,7 @@ mod tests { .put( *TEST_KEY, Lsn(0x40), - &Value::Image(TEST_IMG("foo at 0x40")), + &Value::Image(test_img("foo at 0x40")), &ctx, ) .await?; @@ -4890,23 +4888,23 @@ mod tests { assert_eq!( tline.get(*TEST_KEY, Lsn(0x10), &ctx).await?, - TEST_IMG("foo at 0x10") + test_img("foo at 0x10") ); assert_eq!( tline.get(*TEST_KEY, Lsn(0x1f), &ctx).await?, - TEST_IMG("foo at 0x10") + test_img("foo at 0x10") ); assert_eq!( tline.get(*TEST_KEY, Lsn(0x20), &ctx).await?, - TEST_IMG("foo at 0x20") + test_img("foo at 0x20") ); assert_eq!( tline.get(*TEST_KEY, Lsn(0x30), &ctx).await?, - TEST_IMG("foo at 0x30") + test_img("foo at 0x30") ); assert_eq!( tline.get(*TEST_KEY, Lsn(0x40), &ctx).await?, - TEST_IMG("foo at 0x40") + test_img("foo at 0x40") ); Ok(()) @@ -4938,7 +4936,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(TEST_IMG(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), &ctx, ) .await?; @@ -5000,7 +4998,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(TEST_IMG(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), &ctx, ) .await?; @@ -5021,7 +5019,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(TEST_IMG(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), &ctx, ) .await?; @@ -5035,7 +5033,7 @@ mod tests { test_key.field6 = blknum as u32; assert_eq!( tline.get(test_key, lsn, &ctx).await?, - TEST_IMG(&format!("{} at {}", blknum, last_lsn)) + test_img(&format!("{} at {}", blknum, last_lsn)) ); } @@ -5089,7 +5087,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(TEST_IMG(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), &ctx, ) .await?; @@ -5118,7 +5116,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(TEST_IMG(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), &ctx, ) .await?; @@ -5133,7 +5131,7 @@ mod tests { test_key.field6 = blknum as u32; assert_eq!( tline.get(test_key, lsn, &ctx).await?, - TEST_IMG(&format!("{} at {}", blknum, last_lsn)) + test_img(&format!("{} at {}", blknum, last_lsn)) ); } @@ -5195,7 +5193,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(TEST_IMG(&format!("{} {} at {}", idx, blknum, lsn))), + &Value::Image(test_img(&format!("{} {} at {}", idx, blknum, lsn))), &ctx, ) .await?; @@ -5217,7 +5215,7 @@ mod tests { test_key.field6 = blknum as u32; assert_eq!( tline.get(test_key, *lsn, &ctx).await?, - TEST_IMG(&format!("{idx} {blknum} at {lsn}")) + test_img(&format!("{idx} {blknum} at {lsn}")) ); } } diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index 12ceac0191..8df2f1713a 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -1695,22 +1695,22 @@ mod tests { let mut m = tline.begin_modification(Lsn(0x20)); walingest.put_rel_creation(&mut m, TESTREL_A, &ctx).await?; walingest - .put_rel_page_image(&mut m, TESTREL_A, 0, TEST_IMG("foo blk 0 at 2"), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, 0, test_img("foo blk 0 at 2"), &ctx) .await?; m.commit(&ctx).await?; let mut m = tline.begin_modification(Lsn(0x30)); walingest - .put_rel_page_image(&mut m, TESTREL_A, 0, TEST_IMG("foo blk 0 at 3"), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, 0, test_img("foo blk 0 at 3"), &ctx) .await?; m.commit(&ctx).await?; let mut m = tline.begin_modification(Lsn(0x40)); walingest - .put_rel_page_image(&mut m, TESTREL_A, 1, TEST_IMG("foo blk 1 at 4"), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, 1, test_img("foo blk 1 at 4"), &ctx) .await?; m.commit(&ctx).await?; let mut m = tline.begin_modification(Lsn(0x50)); walingest - .put_rel_page_image(&mut m, TESTREL_A, 2, TEST_IMG("foo blk 2 at 5"), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, 2, test_img("foo blk 2 at 5"), &ctx) .await?; m.commit(&ctx).await?; @@ -1751,46 +1751,46 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x20)), false, &ctx) .await?, - TEST_IMG("foo blk 0 at 2") + test_img("foo blk 0 at 2") ); assert_eq!( tline .get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x30)), false, &ctx) .await?, - TEST_IMG("foo blk 0 at 3") + test_img("foo blk 0 at 3") ); assert_eq!( tline .get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x40)), false, &ctx) .await?, - TEST_IMG("foo blk 0 at 3") + test_img("foo blk 0 at 3") ); assert_eq!( tline .get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x40)), false, &ctx) .await?, - TEST_IMG("foo blk 1 at 4") + test_img("foo blk 1 at 4") ); assert_eq!( tline .get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x50)), false, &ctx) .await?, - TEST_IMG("foo blk 0 at 3") + test_img("foo blk 0 at 3") ); assert_eq!( tline .get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x50)), false, &ctx) .await?, - TEST_IMG("foo blk 1 at 4") + test_img("foo blk 1 at 4") ); assert_eq!( tline .get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), false, &ctx) .await?, - TEST_IMG("foo blk 2 at 5") + test_img("foo blk 2 at 5") ); // Truncate last block @@ -1812,13 +1812,13 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x60)), false, &ctx) .await?, - TEST_IMG("foo blk 0 at 3") + test_img("foo blk 0 at 3") ); assert_eq!( tline .get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x60)), false, &ctx) .await?, - TEST_IMG("foo blk 1 at 4") + test_img("foo blk 1 at 4") ); // should still see the truncated block with older LSN @@ -1832,7 +1832,7 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), false, &ctx) .await?, - TEST_IMG("foo blk 2 at 5") + test_img("foo blk 2 at 5") ); // Truncate to zero length @@ -1851,7 +1851,7 @@ mod tests { // Extend from 0 to 2 blocks, leaving a gap let mut m = tline.begin_modification(Lsn(0x70)); walingest - .put_rel_page_image(&mut m, TESTREL_A, 1, TEST_IMG("foo blk 1"), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, 1, test_img("foo blk 1"), &ctx) .await?; m.commit(&ctx).await?; assert_eq!( @@ -1870,13 +1870,13 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x70)), false, &ctx) .await?, - TEST_IMG("foo blk 1") + test_img("foo blk 1") ); // Extend a lot more, leaving a big gap that spans across segments let mut m = tline.begin_modification(Lsn(0x80)); walingest - .put_rel_page_image(&mut m, TESTREL_A, 1500, TEST_IMG("foo blk 1500"), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, 1500, test_img("foo blk 1500"), &ctx) .await?; m.commit(&ctx).await?; assert_eq!( @@ -1897,7 +1897,7 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, 1500, Version::Lsn(Lsn(0x80)), false, &ctx) .await?, - TEST_IMG("foo blk 1500") + test_img("foo blk 1500") ); Ok(()) @@ -1915,7 +1915,7 @@ mod tests { let mut m = tline.begin_modification(Lsn(0x20)); walingest - .put_rel_page_image(&mut m, TESTREL_A, 0, TEST_IMG("foo blk 0 at 2"), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, 0, test_img("foo blk 0 at 2"), &ctx) .await?; m.commit(&ctx).await?; @@ -1952,7 +1952,7 @@ mod tests { // Re-create it let mut m = tline.begin_modification(Lsn(0x40)); walingest - .put_rel_page_image(&mut m, TESTREL_A, 0, TEST_IMG("foo blk 0 at 4"), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, 0, test_img("foo blk 0 at 4"), &ctx) .await?; m.commit(&ctx).await?; @@ -1990,7 +1990,7 @@ mod tests { for blkno in 0..relsize { let data = format!("foo blk {} at {}", blkno, Lsn(0x20)); walingest - .put_rel_page_image(&mut m, TESTREL_A, blkno, TEST_IMG(&data), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, blkno, test_img(&data), &ctx) .await?; } m.commit(&ctx).await?; @@ -2028,7 +2028,7 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(lsn), false, &ctx) .await?, - TEST_IMG(&data) + test_img(&data) ); } @@ -2055,7 +2055,7 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x60)), false, &ctx) .await?, - TEST_IMG(&data) + test_img(&data) ); } @@ -2073,7 +2073,7 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x50)), false, &ctx) .await?, - TEST_IMG(&data) + test_img(&data) ); } @@ -2084,7 +2084,7 @@ mod tests { for blkno in 0..relsize { let data = format!("foo blk {} at {}", blkno, lsn); walingest - .put_rel_page_image(&mut m, TESTREL_A, blkno, TEST_IMG(&data), &ctx) + .put_rel_page_image(&mut m, TESTREL_A, blkno, test_img(&data), &ctx) .await?; } m.commit(&ctx).await?; @@ -2109,7 +2109,7 @@ mod tests { tline .get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x80)), false, &ctx) .await?, - TEST_IMG(&data) + test_img(&data) ); } @@ -2130,7 +2130,7 @@ mod tests { for blknum in 0..RELSEG_SIZE + 1 { lsn += 0x10; let mut m = tline.begin_modification(Lsn(lsn)); - let img = TEST_IMG(&format!("foo blk {} at {}", blknum, Lsn(lsn))); + let img = test_img(&format!("foo blk {} at {}", blknum, Lsn(lsn))); walingest .put_rel_page_image(&mut m, TESTREL_A, blknum as BlockNumber, img, &ctx) .await?; From c72cb44213e1ffeccaa321d2d43a90c7fa9c8881 Mon Sep 17 00:00:00 2001 From: Alexander Bayandin Date: Thu, 15 Feb 2024 15:53:58 +0000 Subject: [PATCH 80/81] test_runner/performance: parametrize benchmarks (#6744) ## Problem Currently, we don't store `PLATFORM` for Nightly Benchmarks. It causes them to be merged as reruns in Allure report (because they have the same test name). ## Summary of changes - Parametrize benchmarks by - Postgres Version (14/15/16) - Build Type (debug/release/remote) - PLATFORM (neon-staging/github-actions-selfhosted/...) --------- Co-authored-by: Bodobolero --- test_runner/fixtures/parametrize.py | 51 +++++++++++++++-------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index d8ac92abb6..57ca1932b0 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -2,57 +2,58 @@ import os from typing import Optional import pytest -from _pytest.fixtures import FixtureRequest from _pytest.python import Metafunc from fixtures.pg_version import PgVersion """ -Dynamically parametrize tests by Postgres version, build type (debug/release/remote), and possibly by other parameters +Dynamically parametrize tests by different parameters """ @pytest.fixture(scope="function", autouse=True) -def pg_version(request: FixtureRequest) -> Optional[PgVersion]: - # Do not parametrize performance tests yet, we need to prepare grafana charts first - if "test_runner/performance" in str(request.node.path): - v = os.environ.get("DEFAULT_PG_VERSION") - return PgVersion(v) - +def pg_version() -> Optional[PgVersion]: return None @pytest.fixture(scope="function", autouse=True) -def build_type(request: FixtureRequest) -> Optional[str]: - # Do not parametrize performance tests yet, we need to prepare grafana charts first - if "test_runner/performance" in str(request.node.path): - return os.environ.get("BUILD_TYPE", "").lower() - +def build_type() -> Optional[str]: return None @pytest.fixture(scope="function", autouse=True) -def pageserver_virtual_file_io_engine(request: FixtureRequest) -> Optional[str]: +def platform() -> Optional[str]: + return None + + +@pytest.fixture(scope="function", autouse=True) +def pageserver_virtual_file_io_engine() -> Optional[str]: return None def pytest_generate_tests(metafunc: Metafunc): - if (v := os.environ.get("DEFAULT_PG_VERSION")) is None: - pg_versions = [version for version in PgVersion if version != PgVersion.NOT_SET] - else: - pg_versions = [PgVersion(v)] - - if (bt := os.environ.get("BUILD_TYPE")) is None: + if (bt := os.getenv("BUILD_TYPE")) is None: build_types = ["debug", "release"] else: build_types = [bt.lower()] - # Do not parametrize performance tests yet by Postgres version or build type, we need to prepare grafana charts first - if "test_runner/performance" not in metafunc.definition._nodeid: - metafunc.parametrize("build_type", build_types) - metafunc.parametrize("pg_version", pg_versions, ids=map(lambda v: f"pg{v}", pg_versions)) + metafunc.parametrize("build_type", build_types) + + if (v := os.getenv("DEFAULT_PG_VERSION")) is None: + pg_versions = [version for version in PgVersion if version != PgVersion.NOT_SET] + else: + pg_versions = [PgVersion(v)] + + metafunc.parametrize("pg_version", pg_versions, ids=map(lambda v: f"pg{v}", pg_versions)) # A hacky way to parametrize tests only for `pageserver_virtual_file_io_engine=tokio-epoll-uring` # And do not change test name for default `pageserver_virtual_file_io_engine=std-fs` to keep tests statistics - if (io_engine := os.environ.get("PAGESERVER_VIRTUAL_FILE_IO_ENGINE", "")) not in ("", "std-fs"): + if (io_engine := os.getenv("PAGESERVER_VIRTUAL_FILE_IO_ENGINE", "")) not in ("", "std-fs"): metafunc.parametrize("pageserver_virtual_file_io_engine", [io_engine]) + + # For performance tests, parametrize also by platform + if ( + "test_runner/performance" in metafunc.definition._nodeid + and (platform := os.getenv("PLATFORM")) is not None + ): + metafunc.parametrize("platform", [platform.lower()]) From 046d9c69e6734c8e60b6da91d3fb5dd4983001f2 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Thu, 15 Feb 2024 18:58:26 +0200 Subject: [PATCH 81/81] fix: require wider jwt for changing the io engine (#6770) io-engine should not be changeable with any JWT token, for example the tenant_id scoped token which computes have. --- pageserver/src/http/routes.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index ab546c873a..df3794f222 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -1951,6 +1951,7 @@ async fn put_io_engine_handler( mut r: Request, _cancel: CancellationToken, ) -> Result, ApiError> { + check_permission(&r, None)?; let kind: crate::virtual_file::IoEngineKind = json_request(&mut r).await?; crate::virtual_file::io_engine::set(kind); json_response(StatusCode::OK, ())