diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..c8fd1209de --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,6 @@ + +blank_issues_enabled: true +contact_links: + - name: Feature request + url: https://console.neon.tech/app/projects?modal=feedback + about: For feature requests in the Neon product, please submit via the feedback form on `https://console.neon.tech` diff --git a/.github/actions/run-python-test-set/action.yml b/.github/actions/run-python-test-set/action.yml index 6c2cee0971..4008cd0d36 100644 --- a/.github/actions/run-python-test-set/action.yml +++ b/.github/actions/run-python-test-set/action.yml @@ -71,7 +71,7 @@ runs: if: inputs.build_type != 'remote' uses: ./.github/actions/download with: - name: compatibility-snapshot-${{ inputs.build_type }}-pg${{ inputs.pg_version }} + name: compatibility-snapshot-${{ runner.arch }}-${{ inputs.build_type }}-pg${{ inputs.pg_version }} path: /tmp/compatibility_snapshot_pg${{ inputs.pg_version }} prefix: latest # The lack of compatibility snapshot (for example, for the new Postgres version) @@ -211,13 +211,13 @@ runs: fi - name: Upload compatibility snapshot - if: github.ref_name == 'release' + # Note, that we use `github.base_ref` which is a target branch for a PR + if: github.event_name == 'pull_request' && github.base_ref == 'release' uses: ./.github/actions/upload with: - name: compatibility-snapshot-${{ inputs.build_type }}-pg${{ inputs.pg_version }}-${{ github.run_id }} + name: compatibility-snapshot-${{ runner.arch }}-${{ inputs.build_type }}-pg${{ inputs.pg_version }} # Directory is created by test_compatibility.py::test_create_snapshot, keep the path in sync with the test path: /tmp/test_output/compatibility_snapshot_pg${{ inputs.pg_version }}/ - prefix: latest - name: Upload test results if: ${{ !cancelled() }} diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index 5e9fff0e6a..e18e6a1201 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -216,8 +216,14 @@ jobs: #nextest does not yet support running doctests ${cov_prefix} cargo test --doc $CARGO_FLAGS $CARGO_FEATURES + # run all non-pageserver tests + ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E '!package(pageserver)' + + # run pageserver tests with different settings for io_engine in std-fs tokio-epoll-uring ; do - NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES + for io_buffer_alignment in 0 1 512 ; do + NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT=$io_buffer_alignment ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)' + done done # Run separate tests for real S3 diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 1e7f3598c2..53d33b420f 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -1055,43 +1055,88 @@ jobs: generate_release_notes: true, }) + # The job runs on `release` branch and copies compatibility data and Neon artifact from the last *release PR* to the latest directory promote-compatibility-data: - needs: [ check-permissions, promote-images, tag, build-and-test-locally ] + needs: [ deploy ] if: github.ref_name == 'release' - runs-on: [ self-hosted, small ] - container: - image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned - options: --init + runs-on: ubuntu-22.04 steps: - - name: Promote compatibility snapshot for the release + - name: Fetch GITHUB_RUN_ID and COMMIT_SHA for the last merged release PR + id: fetch-last-release-pr-info + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + branch_name_and_pr_number=$(gh pr list \ + --repo "${GITHUB_REPOSITORY}" \ + --base release \ + --state merged \ + --limit 10 \ + --json mergeCommit,headRefName,number \ + --jq ".[] | select(.mergeCommit.oid==\"${GITHUB_SHA}\") | { branch_name: .headRefName, pr_number: .number }") + branch_name=$(echo "${branch_name_and_pr_number}" | jq -r '.branch_name') + pr_number=$(echo "${branch_name_and_pr_number}" | jq -r '.pr_number') + + run_id=$(gh run list \ + --repo "${GITHUB_REPOSITORY}" \ + --workflow build_and_test.yml \ + --branch "${branch_name}" \ + --json databaseId \ + --limit 1 \ + --jq '.[].databaseId') + + last_commit_sha=$(gh pr view "${pr_number}" \ + --repo "${GITHUB_REPOSITORY}" \ + --json commits \ + --jq '.commits[-1].oid') + + echo "run-id=${run_id}" | tee -a ${GITHUB_OUTPUT} + echo "commit-sha=${last_commit_sha}" | tee -a ${GITHUB_OUTPUT} + + - name: Promote compatibility snapshot and Neon artifact env: BUCKET: neon-github-public-dev - PREFIX: artifacts/latest - COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.sha }} + AWS_REGION: eu-central-1 + COMMIT_SHA: ${{ steps.fetch-last-release-pr-info.outputs.commit-sha }} + RUN_ID: ${{ steps.fetch-last-release-pr-info.outputs.run-id }} run: | - # Update compatibility snapshot for the release - for pg_version in v14 v15 v16; do - for build_type in debug release; do - OLD_FILENAME=compatibility-snapshot-${build_type}-pg${pg_version}-${GITHUB_RUN_ID}.tar.zst - NEW_FILENAME=compatibility-snapshot-${build_type}-pg${pg_version}.tar.zst + old_prefix="artifacts/${COMMIT_SHA}/${RUN_ID}" + new_prefix="artifacts/latest" - time aws s3 mv --only-show-errors s3://${BUCKET}/${PREFIX}/${OLD_FILENAME} s3://${BUCKET}/${PREFIX}/${NEW_FILENAME} + files_to_promote=() + files_on_s3=$(aws s3api list-objects-v2 --bucket ${BUCKET} --prefix ${old_prefix} | jq -r '.Contents[]?.Key' || true) + + for arch in X64 ARM64; do + for build_type in debug release; do + neon_artifact_filename="neon-Linux-${arch}-${build_type}-artifact.tar.zst" + s3_key=$(echo "${files_on_s3}" | grep ${neon_artifact_filename} | sort --version-sort | tail -1 || true) + if [ -z "${s3_key}" ]; then + echo >&2 "Neither s3://${BUCKET}/${old_prefix}/${neon_artifact_filename} nor its version from previous attempts exist" + exit 1 + fi + + files_to_promote+=("s3://${BUCKET}/${s3_key}") + + for pg_version in v14 v15 v16; do + # We run less tests for debug builds, so we don't need to promote them + if [ "${build_type}" == "debug" ] && { [ "${arch}" == "ARM64" ] || [ "${pg_version}" != "v16" ] ; }; then + continue + fi + + compatibility_data_filename="compatibility-snapshot-${arch}-${build_type}-pg${pg_version}.tar.zst" + s3_key=$(echo "${files_on_s3}" | grep ${compatibility_data_filename} | sort --version-sort | tail -1 || true) + if [ -z "${s3_key}" ]; then + echo >&2 "Neither s3://${BUCKET}/${old_prefix}/${compatibility_data_filename} nor its version from previous attempts exist" + exit 1 + fi + + files_to_promote+=("s3://${BUCKET}/${s3_key}") + done done done - # Update Neon artifact for the release (reuse already uploaded artifact) - for build_type in debug release; do - OLD_PREFIX=artifacts/${COMMIT_SHA}/${GITHUB_RUN_ID} - FILENAME=neon-${{ runner.os }}-${{ runner.arch }}-${build_type}-artifact.tar.zst - - S3_KEY=$(aws s3api list-objects-v2 --bucket ${BUCKET} --prefix ${OLD_PREFIX} | jq -r '.Contents[]?.Key' | grep ${FILENAME} | sort --version-sort | tail -1 || true) - if [ -z "${S3_KEY}" ]; then - echo >&2 "Neither s3://${BUCKET}/${OLD_PREFIX}/${FILENAME} nor its version from previous attempts exist" - exit 1 - fi - - time aws s3 cp --only-show-errors s3://${BUCKET}/${S3_KEY} s3://${BUCKET}/${PREFIX}/${FILENAME} + for f in "${files_to_promote[@]}"; do + time aws s3 cp --only-show-errors ${f} s3://${BUCKET}/${new_prefix}/ done pin-build-tools-image: diff --git a/Cargo.lock b/Cargo.lock index 441ca1ff86..5af3ef3804 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -936,6 +936,12 @@ dependencies = [ "which", ] +[[package]] +name = "bit_field" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61" + [[package]] name = "bitflags" version = "1.3.2" @@ -1327,7 +1333,6 @@ name = "control_plane" version = "0.1.0" dependencies = [ "anyhow", - "async-trait", "camino", "clap", "comfy-table", @@ -2944,17 +2949,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" -[[package]] -name = "leaky-bucket" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eb491abd89e9794d50f93c8db610a29509123e3fbbc9c8c67a528e9391cd853" -dependencies = [ - "parking_lot 0.12.1", - "tokio", - "tracing", -] - [[package]] name = "libc" version = "0.2.150" @@ -3683,6 +3677,7 @@ dependencies = [ "async-compression", "async-stream", "async-trait", + "bit_field", "byteorder", "bytes", "camino", @@ -3707,7 +3702,6 @@ dependencies = [ "humantime-serde", "hyper 0.14.26", "itertools 0.10.5", - "leaky-bucket", "md5", "metrics", "nix 0.27.1", @@ -3732,6 +3726,7 @@ dependencies = [ "reqwest 0.12.4", "rpds", "scopeguard", + "send-future", "serde", "serde_json", "serde_path_to_error", @@ -3794,7 +3789,6 @@ name = "pageserver_client" version = "0.1.0" dependencies = [ "anyhow", - "async-trait", "bytes", "futures", "pageserver_api", @@ -5455,6 +5449,12 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" +[[package]] +name = "send-future" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224e328af6e080cddbab3c770b1cf50f0351ba0577091ef2410c3951d835ff87" + [[package]] name = "sentry" version = "0.32.3" @@ -5950,7 +5950,6 @@ name = "storage_controller_client" version = "0.1.0" dependencies = [ "anyhow", - "async-trait", "bytes", "futures", "pageserver_api", @@ -6953,7 +6952,6 @@ dependencies = [ "anyhow", "arc-swap", "async-compression", - "async-trait", "bincode", "byteorder", "bytes", @@ -6969,7 +6967,6 @@ dependencies = [ "humantime", "hyper 0.14.26", "jsonwebtoken", - "leaky-bucket", "metrics", "nix 0.27.1", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index e038c0b4ff..fa949f9757 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,7 @@ axum = { version = "0.6.20", features = ["ws"] } base64 = "0.13.0" bincode = "1.3" bindgen = "0.65" +bit_field = "0.10.2" bstr = "1.0" byteorder = "1.4" bytes = "1.0" @@ -107,7 +108,6 @@ ipnet = "2.9.0" itertools = "0.10" jsonwebtoken = "9" lasso = "0.7" -leaky-bucket = "1.0.1" libc = "0.2" md5 = "0.7.0" measured = { version = "0.0.22", features=["lasso"] } @@ -145,6 +145,7 @@ rustls-split = "0.3" scopeguard = "1.1" sysinfo = "0.29.2" sd-notify = "0.4.1" +send-future = "0.1.0" sentry = { version = "0.32", default-features = false, features = ["backtrace", "contexts", "panic", "rustls", "reqwest" ] } serde = { version = "1.0", features = ["derive"] } serde_json = "1" diff --git a/Dockerfile.compute-node b/Dockerfile.compute-node index 7acaf2f2fd..b6c89cd71f 100644 --- a/Dockerfile.compute-node +++ b/Dockerfile.compute-node @@ -942,7 +942,7 @@ COPY --from=hll-pg-build /hll.tar.gz /ext-src COPY --from=plpgsql-check-pg-build /plpgsql_check.tar.gz /ext-src #COPY --from=timescaledb-pg-build /timescaledb.tar.gz /ext-src COPY --from=pg-hint-plan-pg-build /pg_hint_plan.tar.gz /ext-src -COPY patches/pg_hintplan.patch /ext-src +COPY patches/pg_hint_plan.patch /ext-src COPY --from=pg-cron-pg-build /pg_cron.tar.gz /ext-src COPY patches/pg_cron.patch /ext-src #COPY --from=pg-pgx-ulid-build /home/nonroot/pgx_ulid.tar.gz /ext-src @@ -964,7 +964,7 @@ RUN cd /ext-src/pgvector-src && patch -p1 <../pgvector.patch RUN cd /ext-src/rum-src && patch -p1 <../rum.patch # cmake is required for the h3 test RUN apt-get update && apt-get install -y cmake -RUN patch -p1 < /ext-src/pg_hintplan.patch +RUN cd /ext-src/pg_hint_plan-src && patch -p1 < /ext-src/pg_hint_plan.patch COPY --chmod=755 docker-compose/run-tests.sh /run-tests.sh RUN patch -p1 ) { + let (tenant_id, timeline_id, lsn) = { + let state = compute.state.lock().unwrap(); + let spec = state.pspec.as_ref().expect("Spec must be set"); + match spec.spec.mode { + ComputeMode::Static(lsn) => (spec.tenant_id, spec.timeline_id, lsn), + _ => return, + } + }; + let compute = compute.clone(); + + let span = tracing::info_span!("lsn_lease_bg_task", %tenant_id, %timeline_id, %lsn); + thread::spawn(move || { + let _entered = span.entered(); + if let Err(e) = lsn_lease_bg_task(compute, tenant_id, timeline_id, lsn) { + // TODO: might need stronger error feedback than logging an warning. + warn!("Exited with error: {e}"); + } + }); +} + +/// Renews lsn lease periodically so static compute are not affected by GC. +fn lsn_lease_bg_task( + compute: Arc, + tenant_id: TenantId, + timeline_id: TimelineId, + lsn: Lsn, +) -> Result<()> { + loop { + let valid_until = acquire_lsn_lease_with_retry(&compute, tenant_id, timeline_id, lsn)?; + let valid_duration = valid_until + .duration_since(SystemTime::now()) + .unwrap_or(Duration::ZERO); + + // Sleep for 60 seconds less than the valid duration but no more than half of the valid duration. + let sleep_duration = valid_duration + .saturating_sub(Duration::from_secs(60)) + .max(valid_duration / 2); + + info!( + "Succeeded, sleeping for {} seconds", + sleep_duration.as_secs() + ); + thread::sleep(sleep_duration); + } +} + +/// Acquires lsn lease in a retry loop. Returns the expiration time if a lease is granted. +/// Returns an error if a lease is explicitly not granted. Otherwise, we keep sending requests. +fn acquire_lsn_lease_with_retry( + compute: &Arc, + tenant_id: TenantId, + timeline_id: TimelineId, + lsn: Lsn, +) -> Result { + let mut attempts = 0usize; + let mut retry_period_ms: f64 = 500.0; + const MAX_RETRY_PERIOD_MS: f64 = 60.0 * 1000.0; + + loop { + // Note: List of pageservers is dynamic, need to re-read configs before each attempt. + let configs = { + let state = compute.state.lock().unwrap(); + + let spec = state.pspec.as_ref().expect("spec must be set"); + + let conn_strings = spec.pageserver_connstr.split(','); + + conn_strings + .map(|connstr| { + let mut config = postgres::Config::from_str(connstr).expect("Invalid connstr"); + if let Some(storage_auth_token) = &spec.storage_auth_token { + info!("Got storage auth token from spec file"); + config.password(storage_auth_token.clone()); + } else { + info!("Storage auth token not set"); + } + config + }) + .collect::>() + }; + + let result = try_acquire_lsn_lease(tenant_id, timeline_id, lsn, &configs); + match result { + Ok(Some(res)) => { + return Ok(res); + } + Ok(None) => { + bail!("Permanent error: lease could not be obtained, LSN is behind the GC cutoff"); + } + Err(e) => { + warn!("Failed to acquire lsn lease: {e} (attempt {attempts}"); + + thread::sleep(Duration::from_millis(retry_period_ms as u64)); + retry_period_ms *= 1.5; + retry_period_ms = retry_period_ms.min(MAX_RETRY_PERIOD_MS); + } + } + attempts += 1; + } +} + +/// Tries to acquire an LSN lease through PS page_service API. +fn try_acquire_lsn_lease( + tenant_id: TenantId, + timeline_id: TimelineId, + lsn: Lsn, + configs: &[postgres::Config], +) -> Result> { + fn get_valid_until( + config: &postgres::Config, + tenant_shard_id: TenantShardId, + timeline_id: TimelineId, + lsn: Lsn, + ) -> Result> { + let mut client = config.connect(NoTls)?; + let cmd = format!("lease lsn {} {} {} ", tenant_shard_id, timeline_id, lsn); + let res = client.simple_query(&cmd)?; + let msg = match res.first() { + Some(msg) => msg, + None => bail!("empty response"), + }; + let row = match msg { + SimpleQueryMessage::Row(row) => row, + _ => bail!("error parsing lsn lease response"), + }; + + // Note: this will be None if a lease is explicitly not granted. + let valid_until_str = row.get("valid_until"); + + let valid_until = valid_until_str.map(|s| { + SystemTime::UNIX_EPOCH + .checked_add(Duration::from_millis(u128::from_str(s).unwrap() as u64)) + .expect("Time larger than max SystemTime could handle") + }); + Ok(valid_until) + } + + let shard_count = configs.len(); + + let valid_until = if shard_count > 1 { + configs + .iter() + .enumerate() + .map(|(shard_number, config)| { + let tenant_shard_id = TenantShardId { + tenant_id, + shard_count: ShardCount::new(shard_count as u8), + shard_number: ShardNumber(shard_number as u8), + }; + get_valid_until(config, tenant_shard_id, timeline_id, lsn) + }) + .collect::>>>()? + .into_iter() + .min() + .unwrap() + } else { + get_valid_until( + &configs[0], + TenantShardId::unsharded(tenant_id), + timeline_id, + lsn, + )? + }; + + Ok(valid_until) +} diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index 487ac8f047..6fca59b368 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -6,7 +6,6 @@ license.workspace = true [dependencies] anyhow.workspace = true -async-trait.workspace = true camino.workspace = true clap.workspace = true comfy-table.workspace = true diff --git a/control_plane/src/safekeeper.rs b/control_plane/src/safekeeper.rs index a0a73f5609..573f1688d5 100644 --- a/control_plane/src/safekeeper.rs +++ b/control_plane/src/safekeeper.rs @@ -5,6 +5,7 @@ //! ```text //! .neon/safekeepers/ //! ``` +use std::future::Future; use std::io::Write; use std::path::PathBuf; use std::time::Duration; @@ -34,12 +35,10 @@ pub enum SafekeeperHttpError { type Result = result::Result; -#[async_trait::async_trait] -pub trait ResponseErrorMessageExt: Sized { - async fn error_from_body(self) -> Result; +pub(crate) trait ResponseErrorMessageExt: Sized { + fn error_from_body(self) -> impl Future> + Send; } -#[async_trait::async_trait] impl ResponseErrorMessageExt for reqwest::Response { async fn error_from_body(self) -> Result { let status = self.status(); diff --git a/control_plane/storcon_cli/src/main.rs b/control_plane/storcon_cli/src/main.rs index 35510ccbca..5cce6cf3ae 100644 --- a/control_plane/storcon_cli/src/main.rs +++ b/control_plane/storcon_cli/src/main.rs @@ -41,6 +41,8 @@ enum Command { listen_http_addr: String, #[arg(long)] listen_http_port: u16, + #[arg(long)] + availability_zone_id: String, }, /// Modify a node's configuration in the storage controller @@ -322,6 +324,7 @@ async fn main() -> anyhow::Result<()> { listen_pg_port, listen_http_addr, listen_http_port, + availability_zone_id, } => { storcon_client .dispatch::<_, ()>( @@ -333,6 +336,7 @@ async fn main() -> anyhow::Result<()> { listen_pg_port, listen_http_addr, listen_http_port, + availability_zone_id: Some(availability_zone_id), }), ) .await?; diff --git a/docker-compose/run-tests.sh b/docker-compose/run-tests.sh index 58b2581197..3fc0b90071 100644 --- a/docker-compose/run-tests.sh +++ b/docker-compose/run-tests.sh @@ -3,7 +3,7 @@ set -x cd /ext-src || exit 2 FAILED= -LIST=$( (echo "${SKIP//","/"\n"}"; ls -d -- *-src) | sort | uniq -u) +LIST=$( (echo -e "${SKIP//","/"\n"}"; ls -d -- *-src) | sort | uniq -u) for d in ${LIST} do [ -d "${d}" ] || continue diff --git a/docs/rfcs/037-storage-controller-restarts.md b/docs/rfcs/037-storage-controller-restarts.md new file mode 100644 index 0000000000..bad422344f --- /dev/null +++ b/docs/rfcs/037-storage-controller-restarts.md @@ -0,0 +1,259 @@ +# Rolling Storage Controller Restarts + +## Summary + +This RFC describes the issues around the current storage controller restart procedure +and describes an implementation which reduces downtime to a few milliseconds on the happy path. + +## Motivation + +Storage controller upgrades (restarts, more generally) can cause multi-second availability gaps. +While the storage controller does not sit on the main data path, it's generally not acceptable +to block management requests for extended periods of time (e.g. https://github.com/neondatabase/neon/issues/8034). + +### Current Implementation + +The storage controller runs in a Kubernetes Deployment configured for one replica and strategy set to [Recreate](https://kubernetes.io/docs/concepts/workloads/controllers/deployment/#recreate-deployment). +In non Kubernetes terms, during an upgrade, the currently running storage controller is stopped and, only after, +a new instance is created. + +At start-up, the storage controller calls into all the pageservers it manages (retrieved from DB) to learn the +latest locations of all tenant shards present on them. This is usually fast, but can push into tens of seconds +under unfavourable circumstances: pageservers are heavily loaded or unavailable. + +## Prior Art + +There's probably as many ways of handling restarts gracefully as there are distributed systems. Some examples include: +* Active/Standby architectures: Two or more instance of the same service run, but traffic is only routed to one of them. +For fail-over, traffic is routed to one of the standbys (which becomes active). +* Consensus Algorithms (Raft, Paxos and friends): The part of consensus we care about here is leader election: peers communicate to each other +and use a voting scheme that ensures the existence of a single leader (e.g. Raft epochs). + +## Requirements + +* Reduce storage controller unavailability during upgrades to milliseconds +* Minimize the interval in which it's possible for more than one storage controller +to issue reconciles. +* Have one uniform implementation for restarts and upgrades +* Fit in with the current Kubernetes deployment scheme + +## Non Goals + +* Implement our own consensus algorithm from scratch +* Completely eliminate downtime storage controller downtime. Instead we aim to reduce it to the point where it looks +like a transient error to the control plane + +## Impacted Components + +* storage controller +* deployment orchestration (i.e. Ansible) +* helm charts + +## Terminology + +* Observed State: in-memory mapping between tenant shards and their current pageserver locations - currently built up +at start-up by quering pageservers +* Deployment: Kubernetes [primitive](https://kubernetes.io/docs/concepts/workloads/controllers/deployment/) that models +a set of replicas + +## Implementation + +### High Level Flow + +At a very high level the proposed idea is to start a new storage controller instance while +the previous one is still running and cut-over to it when it becomes ready. The new instance, +should coordinate with the existing one and transition responsibility gracefully. While the controller +has built in safety against split-brain situations (via generation numbers), we'd like to avoid such +scenarios since they can lead to availability issues for tenants that underwent changes while two controllers +were operating at the same time and require operator intervention to remedy. + +### Kubernetes Deployment Configuration + +On the Kubernetes configuration side, the proposal is to update the storage controller `Deployment` +to use `spec.strategy.type = RollingUpdate`, `spec.strategy.rollingUpdate.maxSurge=1` and `spec.strategy.maxUnavailable=0`. +Under the hood, Kubernetes creates a new replica set and adds one pod to it (`maxSurge=1`). The old replica set does not +scale down until the new replica set has one replica in the ready state (`maxUnavailable=0`). + +The various possible failure scenarios are investigated in the [Handling Failures](#handling-failures) section. + +### Storage Controller Start-Up + +This section describes the primitives required on the storage controller side and the flow of the happy path. + +#### Database Table For Leader Synchronization + +A new table should be added to the storage controller database for leader synchronization during startup. +This table will always contain at most one row. The proposed name for the table is `leader` and the schema +contains two elements: +* `hostname`: represents the hostname for the current storage controller leader - should be addressible +from other pods in the deployment +* `start_timestamp`: holds the start timestamp for the current storage controller leader (UTC timezone) - only required +for failure case handling: see [Previous Leader Crashes Before New Leader Readiness](#previous-leader-crashes-before-new-leader-readiness) + +Storage controllers will read the leader row at start-up and then update it to mark themselves as the leader +at the end of the start-up sequence. We want compare-and-exchange semantics for the update: avoid the +situation where two concurrent updates succeed and overwrite each other. The default Postgres isolation +level is `READ COMMITTED`, which isn't strict enough here. This update transaction should use at least `REPEATABLE +READ` isolation level in order to [prevent lost updates](https://www.interdb.jp/pg/pgsql05/08.html). Currently, +the storage controller uses the stricter `SERIALIZABLE` isolation level for all transactions. This more than suits +our needs here. + +``` +START TRANSACTION ISOLATION LEVEL REPEATABLE READ +UPDATE leader SET hostname=, start_timestamp= +WHERE hostname=, start_timestampt=; +``` + +If the transaction fails or if no rows have been updated, then the compare-and-exchange is regarded as a failure. + +#### Step Down API + +A new HTTP endpoint should be added to the storage controller: `POST /control/v1/step_down`. Upon receiving this +request the leader cancels any pending reconciles and goes into a mode where it replies with 503 to all other APIs +and does not issue any location configurations to its pageservers. The successful HTTP response will return a serialized +snapshot of the observed state. + +If other step down requests come in after the initial one, the request is handled and the observed state is returned (required +for failure scenario handling - see [Handling Failures](#handling-failures)). + +#### Graceful Restart Happy Path + +At start-up, the first thing the storage controller does is retrieve the sole row from the new +`leader` table. If such an entry exists, send a `/step_down` PUT API call to the current leader. +This should be retried a few times with a short backoff (see [1]). The aspiring leader loads the +observed state into memory and the start-up sequence proceeds as usual, but *without* querying the +pageservers in order to build up the observed state. + +Before doing any reconciliations or persistence change, update the `leader` database table as described in the [Database Table For Leader Synchronization](database-table-for-leader-synchronization) +section. If this step fails, the storage controller process exits. + +Note that no row will exist in the `leaders` table for the first graceful restart. In that case, force update the `leader` table +(without the WHERE clause) and perform with the pre-existing start-up procedure (i.e. build observed state by querying pageservers). + +Summary of proposed new start-up sequence: +1. Call `/step_down` +2. Perform any pending database migrations +3. Load state from database +4. Load observed state returned in step (1) into memory +5. Do initial heartbeat round (may be moved after 5) +7. Mark self as leader by updating the database +8. Reschedule and reconcile everything + +Some things to note from the steps above: +* The storage controller makes no changes to the cluster state before step (5) (i.e. no location config +calls to the pageserver and no compute notifications) +* Ask the current leader to step down before loading state from database so we don't get a lost update +if the transactions overlap. +* Before loading the observed state at step (3), cross-validate against the database. If validation fails, +fall back to asking the pageservers about their current locations. +* Database migrations should only run **after** the previous instance steps down (or the step down times out). + + +[1] The API call might fail because there's no storage controller running (i.e. [restart](#storage-controller-crash-or-restart)), +so we don't want to extend the unavailability period by much. We still want to retry since that's not the common case. + +### Handling Failures + +#### Storage Controller Crash Or Restart + +The storage controller may crash or be restarted outside of roll-outs. When a new pod is created, its call to +`/step_down` will fail since the previous leader is no longer reachable. In this case perform the pre-existing +start-up procedure and update the leader table (with the WHERE clause). If the update fails, the storage controller +exists and consistency is maintained. + +#### Previous Leader Crashes Before New Leader Readiness + +When the previous leader (P1) crashes before the new leader (P2) passses the readiness check, Kubernetes will +reconcile the old replica set and create a new pod for it (P1'). The `/step_down` API call will fail for P1' +(see [2]). + +Now we have two cases to consider: +* P2 updates the `leader` table first: The database update from P1' will fail and P1' will exit, or be terminated +by Kubernetes depending on timings. +* P1' updates the `leader` table first: The `hostname` field of the `leader` row stays the same, but the `start_timestamp` field changes. +The database update from P2 will fail (since `start_timestamp` does not match). P2 will exit and Kubernetes will +create a new replacement pod for it (P2'). Now the entire dance starts again, but with P1' as the leader and P2' as the incumbent. + +[2] P1 and P1' may (more likely than not) be the same pod and have the same hostname. The implementation +should avoid this self reference and fail the API call at the client if the persisted hostname matches +the current one. + +#### Previous Leader Crashes After New Leader Readiness + +The deployment's replica sets already satisfy the deployment's replica count requirements and the +Kubernetes deployment rollout will just clean up the dead pod. + +#### New Leader Crashes Before Pasing Readiness Check + +The deployment controller scales up the new replica sets by creating a new pod. The entire procedure is repeated +with the new pod. + +#### Network Partition Between New Pod and Previous Leader + +This feels very unlikely, but should be considered in any case. P2 (the new aspiring leader) fails the `/step_down` +API call into P1 (the current leader). P2 proceeds with the pre-existing startup procedure and updates the `leader` table. +Kubernetes will terminate P1, but there may be a brief period where both storage controller can drive reconciles. + +### Dealing With Split Brain Scenarios + +As we've seen in the previous section, we can end up with two storage controller running at the same time. The split brain +duration is not bounded since the Kubernetes controller might become partitioned from the pods (unlikely though). While these +scenarios are not fatal, they can cause tenant unavailability, so we'd like to reduce the chances of this happening. +The rest of this section sketches some safety measure. It's likely overkill to implement all of them however. + +### Ensure Leadership Before Producing Side Effects + +The storage controller has two types of side effects: location config requests into pageservers and compute notifications into the control plane. +Before issuing either, the storage controller could check that it is indeed still the leader by querying the database. Side effects might still be +applied if they race with the database updatem, but the situation will eventually be detected. The storage controller process should terminate in these cases. + +### Leadership Lease + +Up until now, the leadership defined by this RFC is static. In order to bound the length of the split brain scenario, we could require the leadership +to be renewed periodically. Two new columns would be added to the leaders table: +1. `last_renewed` - timestamp indicating when the lease was last renewed +2. `lease_duration` - duration indicating the amount of time after which the lease expires + +The leader periodically attempts to renew the lease by checking that it is in fact still the legitimate leader and updating `last_renewed` in the +same transaction. If the update fails, the process exits. New storage controller instances wishing to become leaders must wait for the current lease +to expire before acquiring leadership if they have not succesfully received a response to the `/step_down` request. + +### Notify Pageserver Of Storage Controller Term + +Each time that leadership changes, we can bump a `term` integer column in the `leader` table. This term uniquely identifies a leader. +Location config requests and re-attach responses can include this term. On the pageserver side, keep the latest term in memory and refuse +anything which contains a stale term (i.e. smaller than the current one). + +### Observability + +* The storage controller should expose a metric which describes it's state (`Active | WarmingUp | SteppedDown`). +Per region alerts should be added on this metric which triggers when: + + no storage controller has been in the `Active` state for an extended period of time + + more than one storage controllers are in the `Active` state + +* An alert that periodically verifies that the `leader` table is in sync with the metric above would be very useful. +We'd have to expose the storage controller read only database to Grafana (perhaps it is already done). + +## Alternatives + +### Kubernetes Leases + +Kubernetes has a [lease primitive](https://kubernetes.io/docs/concepts/architecture/leases/) which can be used to implement leader election. +Only one instance may hold a lease at any given time. This lease needs to be periodically renewed and has an expiration period. + +In our case, it would work something like this: +* `/step_down` deletes the lease or stops it from renewing +* lease acquisition becomes part of the start-up procedure + +The kubert crate implements a [lightweight lease API](https://docs.rs/kubert/latest/kubert/lease/struct.LeaseManager.html), but it's still +not exactly trivial to implement. + +This approach has the benefit of baked in observability (`kubectl describe lease`), but: +* We offload the responsibility to Kubernetes which makes it harder to debug when things go wrong. +* More code surface than the simple "row in database" approach. Also, most of this code would be in +a dependency not subject to code review, etc. +* Hard to test. Our testing infra does not run the storage controller in Kubernetes and changing it do +so is not simple and complictes and the test set-up. + +To my mind, the "row in database" approach is straightforward enough that we don't have to offload this +to something external. diff --git a/libs/pageserver_api/src/controller_api.rs b/libs/pageserver_api/src/controller_api.rs index a9a57d77ce..345abd69b6 100644 --- a/libs/pageserver_api/src/controller_api.rs +++ b/libs/pageserver_api/src/controller_api.rs @@ -56,6 +56,8 @@ pub struct NodeRegisterRequest { pub listen_http_addr: String, pub listen_http_port: u16, + + pub availability_zone_id: Option, } #[derive(Serialize, Deserialize)] diff --git a/libs/pageserver_api/src/key.rs b/libs/pageserver_api/src/key.rs index 77da58d63e..77d744e4da 100644 --- a/libs/pageserver_api/src/key.rs +++ b/libs/pageserver_api/src/key.rs @@ -108,14 +108,41 @@ impl Key { } } + /// This function checks more extensively what keys we can take on the write path. + /// If a key beginning with 00 does not have a global/default tablespace OID, it + /// will be rejected on the write path. + #[allow(dead_code)] + pub fn is_valid_key_on_write_path_strong(&self) -> bool { + use postgres_ffi::pg_constants::{DEFAULTTABLESPACE_OID, GLOBALTABLESPACE_OID}; + if !self.is_i128_representable() { + return false; + } + if self.field1 == 0 + && !(self.field2 == GLOBALTABLESPACE_OID + || self.field2 == DEFAULTTABLESPACE_OID + || self.field2 == 0) + { + return false; // User defined tablespaces are not supported + } + true + } + + /// This is a weaker version of `is_valid_key_on_write_path_strong` that simply + /// checks if the key is i128 representable. Note that some keys can be successfully + /// ingested into the pageserver, but will cause errors on generating basebackup. + pub fn is_valid_key_on_write_path(&self) -> bool { + self.is_i128_representable() + } + + pub fn is_i128_representable(&self) -> bool { + self.field2 <= 0xFFFF || self.field2 == 0xFFFFFFFF || self.field2 == 0x22222222 + } + /// 'field2' is used to store tablespaceid for relations and small enum numbers for other relish. /// As long as Neon does not support tablespace (because of lack of access to local file system), /// we can assume that only some predefined namespace OIDs are used which can fit in u16 pub fn to_i128(&self) -> i128 { - assert!( - self.field2 <= 0xFFFF || self.field2 == 0xFFFFFFFF || self.field2 == 0x22222222, - "invalid key: {self}", - ); + assert!(self.is_i128_representable(), "invalid key: {self}"); (((self.field1 & 0x7F) as i128) << 120) | (((self.field2 & 0xFFFF) as i128) << 104) | ((self.field3 as i128) << 72) diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 4cab56771b..1d896863df 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -7,7 +7,7 @@ pub use utilization::PageserverUtilization; use std::{ collections::HashMap, io::{BufRead, Read}, - num::{NonZeroU64, NonZeroUsize}, + num::{NonZeroU32, NonZeroU64, NonZeroUsize}, str::FromStr, sync::atomic::AtomicUsize, time::{Duration, SystemTime}, @@ -486,12 +486,11 @@ pub struct EvictionPolicyLayerAccessThreshold { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct ThrottleConfig { pub task_kinds: Vec, // TaskKind - pub initial: usize, + pub initial: u32, #[serde(with = "humantime_serde")] pub refill_interval: Duration, - pub refill_amount: NonZeroUsize, - pub max: usize, - pub fair: bool, + pub refill_amount: NonZeroU32, + pub max: u32, } impl ThrottleConfig { @@ -501,9 +500,8 @@ impl ThrottleConfig { // other values don't matter with emtpy `task_kinds`. initial: 0, refill_interval: Duration::from_millis(1), - refill_amount: NonZeroUsize::new(1).unwrap(), + refill_amount: NonZeroU32::new(1).unwrap(), max: 1, - fair: true, } } /// The requests per second allowed by the given config. @@ -1063,7 +1061,7 @@ impl TryFrom for PagestreamBeMessageTag { } } -// In the V2 protocol version, a GetPage request contains two LSN values: +// A GetPage request contains two LSN values: // // request_lsn: Get the page version at this point in time. Lsn::Max is a special value that means // "get the latest version present". It's used by the primary server, which knows that no one else @@ -1076,7 +1074,7 @@ impl TryFrom for PagestreamBeMessageTag { // passing an earlier LSN can speed up the request, by allowing the pageserver to process the // request without waiting for 'request_lsn' to arrive. // -// The legacy V1 interface contained only one LSN, and a boolean 'latest' flag. The V1 interface was +// The now-defunct V1 interface contained only one LSN, and a boolean 'latest' flag. The V1 interface was // sufficient for the primary; the 'lsn' was equivalent to the 'not_modified_since' value, and // 'latest' was set to true. The V2 interface was added because there was no correct way for a // standby to request a page at a particular non-latest LSN, and also include the @@ -1084,15 +1082,11 @@ impl TryFrom for PagestreamBeMessageTag { // request, if the standby knows that the page hasn't been modified since, and risk getting an error // if that LSN has fallen behind the GC horizon, or requesting the current replay LSN, which could // require the pageserver unnecessarily to wait for the WAL to arrive up to that point. The new V2 -// interface allows sending both LSNs, and let the pageserver do the right thing. There is no +// interface allows sending both LSNs, and let the pageserver do the right thing. There was no // difference in the responses between V1 and V2. // -// The Request structs below reflect the V2 interface. If V1 is used, the parse function -// maps the old format requests to the new format. -// #[derive(Clone, Copy)] pub enum PagestreamProtocolVersion { - V1, V2, } @@ -1231,36 +1225,17 @@ impl PagestreamFeMessage { bytes.into() } - pub fn parse( - body: &mut R, - protocol_version: PagestreamProtocolVersion, - ) -> anyhow::Result { + pub fn parse(body: &mut R) -> anyhow::Result { // these correspond to the NeonMessageTag enum in pagestore_client.h // // TODO: consider using protobuf or serde bincode for less error prone // serialization. let msg_tag = body.read_u8()?; - let (request_lsn, not_modified_since) = match protocol_version { - PagestreamProtocolVersion::V2 => ( - Lsn::from(body.read_u64::()?), - Lsn::from(body.read_u64::()?), - ), - PagestreamProtocolVersion::V1 => { - // In the old protocol, each message starts with a boolean 'latest' flag, - // followed by 'lsn'. Convert that to the two LSNs, 'request_lsn' and - // 'not_modified_since', used in the new protocol version. - let latest = body.read_u8()? != 0; - let request_lsn = Lsn::from(body.read_u64::()?); - if latest { - (Lsn::MAX, request_lsn) // get latest version - } else { - (request_lsn, request_lsn) // get version at specified LSN - } - } - }; + // these two fields are the same for every request type + let request_lsn = Lsn::from(body.read_u64::()?); + let not_modified_since = Lsn::from(body.read_u64::()?); - // The rest of the messages are the same between V1 and V2 match msg_tag { 0 => Ok(PagestreamFeMessage::Exists(PagestreamExistsRequest { request_lsn, @@ -1468,9 +1443,7 @@ mod tests { ]; for msg in messages { let bytes = msg.serialize(); - let reconstructed = - PagestreamFeMessage::parse(&mut bytes.reader(), PagestreamProtocolVersion::V2) - .unwrap(); + let reconstructed = PagestreamFeMessage::parse(&mut bytes.reader()).unwrap(); assert!(msg == reconstructed); } } diff --git a/libs/postgres_ffi/src/lib.rs b/libs/postgres_ffi/src/lib.rs index 0940ad207f..9acb105e9b 100644 --- a/libs/postgres_ffi/src/lib.rs +++ b/libs/postgres_ffi/src/lib.rs @@ -136,9 +136,9 @@ pub const MAX_SEND_SIZE: usize = XLOG_BLCKSZ * 16; // Export some version independent functions that are used outside of this mod pub use v14::xlog_utils::encode_logical_message; -pub use v14::xlog_utils::from_pg_timestamp; pub use v14::xlog_utils::get_current_timestamp; pub use v14::xlog_utils::to_pg_timestamp; +pub use v14::xlog_utils::try_from_pg_timestamp; pub use v14::xlog_utils::XLogFileName; pub use v14::bindings::DBState_DB_SHUTDOWNED; diff --git a/libs/postgres_ffi/src/xlog_utils.rs b/libs/postgres_ffi/src/xlog_utils.rs index 9fe7e8198b..0cfd56962e 100644 --- a/libs/postgres_ffi/src/xlog_utils.rs +++ b/libs/postgres_ffi/src/xlog_utils.rs @@ -135,6 +135,8 @@ pub fn get_current_timestamp() -> TimestampTz { mod timestamp_conversions { use std::time::Duration; + use anyhow::Context; + use super::*; const UNIX_EPOCH_JDATE: u64 = 2440588; // == date2j(1970, 1, 1) @@ -154,18 +156,18 @@ mod timestamp_conversions { } } - pub fn from_pg_timestamp(time: TimestampTz) -> SystemTime { + pub fn try_from_pg_timestamp(time: TimestampTz) -> anyhow::Result { let time: u64 = time .try_into() - .expect("timestamp before millenium (postgres epoch)"); + .context("timestamp before millenium (postgres epoch)")?; let since_unix_epoch = time + SECS_DIFF_UNIX_TO_POSTGRES_EPOCH * USECS_PER_SEC; SystemTime::UNIX_EPOCH .checked_add(Duration::from_micros(since_unix_epoch)) - .expect("SystemTime overflow") + .context("SystemTime overflow") } } -pub use timestamp_conversions::{from_pg_timestamp, to_pg_timestamp}; +pub use timestamp_conversions::{to_pg_timestamp, try_from_pg_timestamp}; // Returns (aligned) end_lsn of the last record in data_dir with WAL segments. // start_lsn must point to some previously known record boundary (beginning of @@ -545,14 +547,14 @@ mod tests { #[test] fn test_ts_conversion() { let now = SystemTime::now(); - let round_trip = from_pg_timestamp(to_pg_timestamp(now)); + let round_trip = try_from_pg_timestamp(to_pg_timestamp(now)).unwrap(); let now_since = now.duration_since(SystemTime::UNIX_EPOCH).unwrap(); let round_trip_since = round_trip.duration_since(SystemTime::UNIX_EPOCH).unwrap(); assert_eq!(now_since.as_micros(), round_trip_since.as_micros()); let now_pg = get_current_timestamp(); - let round_trip_pg = to_pg_timestamp(from_pg_timestamp(now_pg)); + let round_trip_pg = to_pg_timestamp(try_from_pg_timestamp(now_pg).unwrap()); assert_eq!(now_pg, round_trip_pg); } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 6e593eeac1..19deaab63f 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -14,7 +14,6 @@ testing = ["fail/failpoints"] arc-swap.workspace = true sentry.workspace = true async-compression.workspace = true -async-trait.workspace = true anyhow.workspace = true bincode.workspace = true bytes.workspace = true @@ -26,7 +25,6 @@ hyper = { workspace = true, features = ["full"] } fail.workspace = true futures = { workspace = true} jsonwebtoken.workspace = true -leaky-bucket.workspace = true nix.workspace = true once_cell.workspace = true pin-project-lite.workspace = true diff --git a/libs/utils/src/leaky_bucket.rs b/libs/utils/src/leaky_bucket.rs new file mode 100644 index 0000000000..a120dc0ac5 --- /dev/null +++ b/libs/utils/src/leaky_bucket.rs @@ -0,0 +1,280 @@ +//! This module implements the Generic Cell Rate Algorithm for a simplified +//! version of the Leaky Bucket rate limiting system. +//! +//! # Leaky Bucket +//! +//! If the bucket is full, no new requests are allowed and are throttled/errored. +//! If the bucket is partially full/empty, new requests are added to the bucket in +//! terms of "tokens". +//! +//! Over time, tokens are removed from the bucket, naturally allowing new requests at a steady rate. +//! +//! The bucket size tunes the burst support. The drain rate tunes the steady-rate requests per second. +//! +//! # [GCRA](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm) +//! +//! GCRA is a continuous rate leaky-bucket impl that stores minimal state and requires +//! no background jobs to drain tokens, as the design utilises timestamps to drain automatically over time. +//! +//! We store an "empty_at" timestamp as the only state. As time progresses, we will naturally approach +//! the empty state. The full-bucket state is calculated from `empty_at - config.bucket_width`. +//! +//! Another explaination can be found here: + +use std::{sync::Mutex, time::Duration}; + +use tokio::{sync::Notify, time::Instant}; + +pub struct LeakyBucketConfig { + /// This is the "time cost" of a single request unit. + /// Should loosely represent how long it takes to handle a request unit in active resource time. + /// Loosely speaking this is the inverse of the steady-rate requests-per-second + pub cost: Duration, + + /// total size of the bucket + pub bucket_width: Duration, +} + +impl LeakyBucketConfig { + pub fn new(rps: f64, bucket_size: f64) -> Self { + let cost = Duration::from_secs_f64(rps.recip()); + let bucket_width = cost.mul_f64(bucket_size); + Self { cost, bucket_width } + } +} + +pub struct LeakyBucketState { + /// Bucket is represented by `allow_at..empty_at` where `allow_at = empty_at - config.bucket_width`. + /// + /// At any given time, `empty_at - now` represents the number of tokens in the bucket, multiplied by the "time_cost". + /// Adding `n` tokens to the bucket is done by moving `empty_at` forward by `n * config.time_cost`. + /// If `now < allow_at`, the bucket is considered filled and cannot accept any more tokens. + /// Draining the bucket will happen naturally as `now` moves forward. + /// + /// Let `n` be some "time cost" for the request, + /// If now is after empty_at, the bucket is empty and the empty_at is reset to now, + /// If now is within the `bucket window + n`, we are within time budget. + /// If now is before the `bucket window + n`, we have run out of budget. + /// + /// This is inspired by the generic cell rate algorithm (GCRA) and works + /// exactly the same as a leaky-bucket. + pub empty_at: Instant, +} + +impl LeakyBucketState { + pub fn with_initial_tokens(config: &LeakyBucketConfig, initial_tokens: f64) -> Self { + LeakyBucketState { + empty_at: Instant::now() + config.cost.mul_f64(initial_tokens), + } + } + + pub fn bucket_is_empty(&self, now: Instant) -> bool { + // if self.end is after now, the bucket is not empty + self.empty_at <= now + } + + /// Immediately adds tokens to the bucket, if there is space. + /// + /// In a scenario where you are waiting for available rate, + /// rather than just erroring immediately, `started` corresponds to when this waiting started. + /// + /// `n` is the number of tokens that will be filled in the bucket. + /// + /// # Errors + /// + /// If there is not enough space, no tokens are added. Instead, an error is returned with the time when + /// there will be space again. + pub fn add_tokens( + &mut self, + config: &LeakyBucketConfig, + started: Instant, + n: f64, + ) -> Result<(), Instant> { + let now = Instant::now(); + + // invariant: started <= now + debug_assert!(started <= now); + + // If the bucket was empty when we started our search, + // we should update the `empty_at` value accordingly. + // this prevents us from having negative tokens in the bucket. + let mut empty_at = self.empty_at; + if empty_at < started { + empty_at = started; + } + + let n = config.cost.mul_f64(n); + let new_empty_at = empty_at + n; + let allow_at = new_empty_at.checked_sub(config.bucket_width); + + // empty_at + // allow_at | new_empty_at + // / | / + // -------o-[---------o-|--]--------- + // now1 ^ now2 ^ + // + // at now1, the bucket would be completely filled if we add n tokens. + // at now2, the bucket would be partially filled if we add n tokens. + + match allow_at { + Some(allow_at) if now < allow_at => Err(allow_at), + _ => { + self.empty_at = new_empty_at; + Ok(()) + } + } + } +} + +pub struct RateLimiter { + pub config: LeakyBucketConfig, + pub state: Mutex, + /// a queue to provide this fair ordering. + pub queue: Notify, +} + +struct Requeue<'a>(&'a Notify); + +impl Drop for Requeue<'_> { + fn drop(&mut self) { + self.0.notify_one(); + } +} + +impl RateLimiter { + pub fn with_initial_tokens(config: LeakyBucketConfig, initial_tokens: f64) -> Self { + RateLimiter { + state: Mutex::new(LeakyBucketState::with_initial_tokens( + &config, + initial_tokens, + )), + config, + queue: { + let queue = Notify::new(); + queue.notify_one(); + queue + }, + } + } + + pub fn steady_rps(&self) -> f64 { + self.config.cost.as_secs_f64().recip() + } + + /// returns true if we did throttle + pub async fn acquire(&self, count: usize) -> bool { + let mut throttled = false; + + let start = tokio::time::Instant::now(); + + // wait until we are the first in the queue + let mut notified = std::pin::pin!(self.queue.notified()); + if !notified.as_mut().enable() { + throttled = true; + notified.await; + } + + // notify the next waiter in the queue when we are done. + let _guard = Requeue(&self.queue); + + loop { + let res = self + .state + .lock() + .unwrap() + .add_tokens(&self.config, start, count as f64); + match res { + Ok(()) => return throttled, + Err(ready_at) => { + throttled = true; + tokio::time::sleep_until(ready_at).await; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use tokio::time::Instant; + + use super::{LeakyBucketConfig, LeakyBucketState}; + + #[tokio::test(start_paused = true)] + async fn check() { + let config = LeakyBucketConfig { + // average 100rps + cost: Duration::from_millis(10), + // burst up to 100 requests + bucket_width: Duration::from_millis(1000), + }; + + let mut state = LeakyBucketState { + empty_at: Instant::now(), + }; + + // supports burst + { + // should work for 100 requests this instant + for _ in 0..100 { + state.add_tokens(&config, Instant::now(), 1.0).unwrap(); + } + let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err(); + assert_eq!(ready - Instant::now(), Duration::from_millis(10)); + } + + // doesn't overfill + { + // after 1s we should have an empty bucket again. + tokio::time::advance(Duration::from_secs(1)).await; + assert!(state.bucket_is_empty(Instant::now())); + + // after 1s more, we should not over count the tokens and allow more than 200 requests. + tokio::time::advance(Duration::from_secs(1)).await; + for _ in 0..100 { + state.add_tokens(&config, Instant::now(), 1.0).unwrap(); + } + let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err(); + assert_eq!(ready - Instant::now(), Duration::from_millis(10)); + } + + // supports sustained rate over a long period + { + tokio::time::advance(Duration::from_secs(1)).await; + + // should sustain 100rps + for _ in 0..2000 { + tokio::time::advance(Duration::from_millis(10)).await; + state.add_tokens(&config, Instant::now(), 1.0).unwrap(); + } + } + + // supports requesting more tokens than can be stored in the bucket + // we just wait a little bit longer upfront. + { + // start the bucket completely empty + tokio::time::advance(Duration::from_secs(5)).await; + assert!(state.bucket_is_empty(Instant::now())); + + // requesting 200 tokens of space should take 200*cost = 2s + // but we already have 1s available, so we wait 1s from start. + let start = Instant::now(); + + let ready = state.add_tokens(&config, start, 200.0).unwrap_err(); + assert_eq!(ready - Instant::now(), Duration::from_secs(1)); + + tokio::time::advance(Duration::from_millis(500)).await; + let ready = state.add_tokens(&config, start, 200.0).unwrap_err(); + assert_eq!(ready - Instant::now(), Duration::from_millis(500)); + + tokio::time::advance(Duration::from_millis(500)).await; + state.add_tokens(&config, start, 200.0).unwrap(); + + // bucket should be completely full now + let ready = state.add_tokens(&config, Instant::now(), 1.0).unwrap_err(); + assert_eq!(ready - Instant::now(), Duration::from_millis(10)); + } + } +} diff --git a/libs/utils/src/lib.rs b/libs/utils/src/lib.rs index f4fc0ba57b..218dd468b1 100644 --- a/libs/utils/src/lib.rs +++ b/libs/utils/src/lib.rs @@ -71,6 +71,7 @@ pub mod postgres_client; pub mod tracing_span_assert; +pub mod leaky_bucket; pub mod rate_limit; /// Simple once-barrier and a guard which keeps barrier awaiting. diff --git a/libs/utils/src/rate_limit.rs b/libs/utils/src/rate_limit.rs index 557955bb88..f3f8f219e3 100644 --- a/libs/utils/src/rate_limit.rs +++ b/libs/utils/src/rate_limit.rs @@ -5,6 +5,15 @@ use std::time::{Duration, Instant}; pub struct RateLimit { last: Option, interval: Duration, + dropped: u64, +} + +pub struct RateLimitStats(u64); + +impl std::fmt::Display for RateLimitStats { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{} dropped calls", self.0) + } } impl RateLimit { @@ -12,20 +21,27 @@ impl RateLimit { Self { last: None, interval, + dropped: 0, } } /// Call `f` if the rate limit allows. /// Don't call it otherwise. pub fn call(&mut self, f: F) { + self.call2(|_| f()) + } + + pub fn call2(&mut self, f: F) { let now = Instant::now(); match self.last { Some(last) if now - last <= self.interval => { // ratelimit + self.dropped += 1; } _ => { self.last = Some(now); - f(); + f(RateLimitStats(self.dropped)); + self.dropped = 0; } } } diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 0e748ee3db..9c02ce3fbc 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -16,6 +16,7 @@ arc-swap.workspace = true async-compression.workspace = true async-stream.workspace = true async-trait.workspace = true +bit_field.workspace = true byteorder.workspace = true bytes.workspace = true camino.workspace = true @@ -36,7 +37,6 @@ humantime.workspace = true humantime-serde.workspace = true hyper.workspace = true itertools.workspace = true -leaky-bucket.workspace = true md5.workspace = true nix.workspace = true # hack to get the number of worker threads tokio uses @@ -52,6 +52,7 @@ rand.workspace = true range-set-blaze = { version = "0.1.16", features = ["alloc"] } regex.workspace = true scopeguard.workspace = true +send-future.workspace = true serde.workspace = true serde_json = { workspace = true, features = ["raw_value"] } serde_path_to_error.workspace = true diff --git a/pageserver/benches/bench_ingest.rs b/pageserver/benches/bench_ingest.rs index bd99f5289d..1be4391d81 100644 --- a/pageserver/benches/bench_ingest.rs +++ b/pageserver/benches/bench_ingest.rs @@ -4,7 +4,7 @@ use bytes::Bytes; use camino::Utf8PathBuf; use criterion::{criterion_group, criterion_main, Criterion}; use pageserver::{ - config::PageServerConf, + config::{defaults::DEFAULT_IO_BUFFER_ALIGNMENT, PageServerConf}, context::{DownloadBehavior, RequestContext}, l0_flush::{L0FlushConfig, L0FlushGlobalState}, page_cache, @@ -103,13 +103,13 @@ async fn ingest( batch.push((key.to_compact(), lsn, data_ser_size, data.clone())); if batch.len() >= BATCH_SIZE { let this_batch = std::mem::take(&mut batch); - let serialized = SerializedBatch::from_values(this_batch); + let serialized = SerializedBatch::from_values(this_batch).unwrap(); layer.put_batch(serialized, &ctx).await?; } } if !batch.is_empty() { let this_batch = std::mem::take(&mut batch); - let serialized = SerializedBatch::from_values(this_batch); + let serialized = SerializedBatch::from_values(this_batch).unwrap(); layer.put_batch(serialized, &ctx).await?; } layer.freeze(lsn + 1).await; @@ -164,7 +164,11 @@ fn criterion_benchmark(c: &mut Criterion) { let conf: &'static PageServerConf = Box::leak(Box::new( pageserver::config::PageServerConf::dummy_conf(temp_dir.path().to_path_buf()), )); - virtual_file::init(16384, virtual_file::io_engine_for_bench()); + virtual_file::init( + 16384, + virtual_file::io_engine_for_bench(), + DEFAULT_IO_BUFFER_ALIGNMENT, + ); page_cache::init(conf.page_cache_size); { diff --git a/pageserver/client/Cargo.toml b/pageserver/client/Cargo.toml index a938367334..d9b36bf3d4 100644 --- a/pageserver/client/Cargo.toml +++ b/pageserver/client/Cargo.toml @@ -7,7 +7,6 @@ license.workspace = true [dependencies] pageserver_api.workspace = true thiserror.workspace = true -async-trait.workspace = true reqwest = { workspace = true, features = [ "stream" ] } utils.workspace = true serde.workspace = true diff --git a/pageserver/client/src/mgmt_api.rs b/pageserver/client/src/mgmt_api.rs index ac3ff1bb89..71d36f3113 100644 --- a/pageserver/client/src/mgmt_api.rs +++ b/pageserver/client/src/mgmt_api.rs @@ -506,6 +506,16 @@ impl Client { .map_err(Error::ReceiveBody) } + /// Configs io buffer alignment at runtime. + pub async fn put_io_alignment(&self, align: usize) -> Result<()> { + let uri = format!("{}/v1/io_alignment", self.mgmt_api_endpoint); + self.request(Method::PUT, uri, align) + .await? + .json() + .await + .map_err(Error::ReceiveBody) + } + pub async fn get_utilization(&self) -> Result { let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint); self.get(uri) diff --git a/pageserver/ctl/src/layer_map_analyzer.rs b/pageserver/ctl/src/layer_map_analyzer.rs index b4bb239f44..8092c203c3 100644 --- a/pageserver/ctl/src/layer_map_analyzer.rs +++ b/pageserver/ctl/src/layer_map_analyzer.rs @@ -4,6 +4,7 @@ use anyhow::Result; use camino::{Utf8Path, Utf8PathBuf}; +use pageserver::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT; use pageserver::context::{DownloadBehavior, RequestContext}; use pageserver::task_mgr::TaskKind; use pageserver::tenant::{TENANTS_SEGMENT_NAME, TIMELINES_SEGMENT_NAME}; @@ -144,7 +145,11 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> { let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); // Initialize virtual_file (file desriptor cache) and page cache which are needed to access layer persistent B-Tree. - pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); + pageserver::virtual_file::init( + 10, + virtual_file::api::IoEngineKind::StdFs, + DEFAULT_IO_BUFFER_ALIGNMENT, + ); pageserver::page_cache::init(100); let mut total_delta_layers = 0usize; diff --git a/pageserver/ctl/src/layers.rs b/pageserver/ctl/src/layers.rs index 3611b0baab..e0f978eaa2 100644 --- a/pageserver/ctl/src/layers.rs +++ b/pageserver/ctl/src/layers.rs @@ -3,6 +3,7 @@ use std::path::{Path, PathBuf}; use anyhow::Result; use camino::{Utf8Path, Utf8PathBuf}; use clap::Subcommand; +use pageserver::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT; use pageserver::context::{DownloadBehavior, RequestContext}; use pageserver::task_mgr::TaskKind; use pageserver::tenant::block_io::BlockCursor; @@ -59,7 +60,7 @@ pub(crate) enum LayerCmd { async fn read_delta_file(path: impl AsRef, ctx: &RequestContext) -> Result<()> { let path = Utf8Path::from_path(path.as_ref()).expect("non-Unicode path"); - virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); + virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, 1); page_cache::init(100); let file = VirtualFile::open(path, ctx).await?; let file_id = page_cache::next_file_id(); @@ -89,6 +90,7 @@ async fn read_delta_file(path: impl AsRef, ctx: &RequestContext) -> Result for (k, v) in all { let value = cursor.read_blob(v.pos(), ctx).await?; println!("key:{} value_len:{}", k, value.len()); + assert!(k.is_i128_representable(), "invalid key: "); } // TODO(chi): special handling for last key? Ok(()) @@ -189,7 +191,11 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> { new_tenant_id, new_timeline_id, } => { - pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); + pageserver::virtual_file::init( + 10, + virtual_file::api::IoEngineKind::StdFs, + DEFAULT_IO_BUFFER_ALIGNMENT, + ); pageserver::page_cache::init(100); let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); diff --git a/pageserver/ctl/src/main.rs b/pageserver/ctl/src/main.rs index 3fabf62987..7a6c7675bb 100644 --- a/pageserver/ctl/src/main.rs +++ b/pageserver/ctl/src/main.rs @@ -20,6 +20,7 @@ use clap::{Parser, Subcommand}; use index_part::IndexPartCmd; use layers::LayerCmd; use pageserver::{ + config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT, context::{DownloadBehavior, RequestContext}, page_cache, task_mgr::TaskKind, @@ -205,7 +206,11 @@ fn read_pg_control_file(control_file_path: &Utf8Path) -> anyhow::Result<()> { async fn print_layerfile(path: &Utf8Path) -> anyhow::Result<()> { // Basic initialization of things that don't change after startup - virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); + virtual_file::init( + 10, + virtual_file::api::IoEngineKind::StdFs, + DEFAULT_IO_BUFFER_ALIGNMENT, + ); page_cache::init(100); let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); dump_layerfile_from_path(path, true, &ctx).await diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index 4992f37465..ac4a732377 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -58,6 +58,11 @@ pub(crate) struct Args { /// [`pageserver_api::models::virtual_file::IoEngineKind`]. #[clap(long)] set_io_engine: Option, + + /// Before starting the benchmark, live-reconfigure the pageserver to use specified alignment for io buffers. + #[clap(long)] + set_io_alignment: Option, + targets: Option>, } @@ -124,6 +129,10 @@ async fn main_impl( mgmt_api_client.put_io_engine(engine_str).await?; } + if let Some(align) = args.set_io_alignment { + mgmt_api_client.put_io_alignment(align).await?; + } + // discover targets let timelines: Vec = crate::util::cli::targets::discover( &mgmt_api_client, diff --git a/pageserver/src/assert_u64_eq_usize.rs b/pageserver/src/assert_u64_eq_usize.rs new file mode 100644 index 0000000000..66ca7fd057 --- /dev/null +++ b/pageserver/src/assert_u64_eq_usize.rs @@ -0,0 +1,39 @@ +//! `u64`` and `usize`` aren't guaranteed to be identical in Rust, but life is much simpler if that's the case. + +pub(crate) const _ASSERT_U64_EQ_USIZE: () = { + if std::mem::size_of::() != std::mem::size_of::() { + panic!("the traits defined in this module assume that usize and u64 can be converted to each other without loss of information"); + } +}; + +pub(crate) trait U64IsUsize { + fn into_usize(self) -> usize; +} + +impl U64IsUsize for u64 { + #[inline(always)] + fn into_usize(self) -> usize { + #[allow(clippy::let_unit_value)] + let _ = _ASSERT_U64_EQ_USIZE; + self as usize + } +} + +pub(crate) trait UsizeIsU64 { + fn into_u64(self) -> u64; +} + +impl UsizeIsU64 for usize { + #[inline(always)] + fn into_u64(self) -> u64 { + #[allow(clippy::let_unit_value)] + let _ = _ASSERT_U64_EQ_USIZE; + self as u64 + } +} + +pub const fn u64_to_usize(x: u64) -> usize { + #[allow(clippy::let_unit_value)] + let _ = _ASSERT_U64_EQ_USIZE; + x as usize +} diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 7d404e50a5..850bd87b95 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -125,6 +125,7 @@ fn main() -> anyhow::Result<()> { info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine"); info!(?conf.virtual_file_direct_io, "starting with virtual_file Direct IO settings"); info!(?conf.compact_level0_phase1_value_access, "starting with setting for compact_level0_phase1_value_access"); + info!(?conf.io_buffer_alignment, "starting with setting for IO buffer alignment"); // The tenants directory contains all the pageserver local disk state. // Create if not exists and make sure all the contents are durable before proceeding. @@ -182,7 +183,11 @@ fn main() -> anyhow::Result<()> { let scenario = failpoint_support::init(); // Basic initialization of things that don't change after startup - virtual_file::init(conf.max_file_descriptors, conf.virtual_file_io_engine); + virtual_file::init( + conf.max_file_descriptors, + conf.virtual_file_io_engine, + conf.io_buffer_alignment, + ); page_cache::init(conf.page_cache_size); start_pageserver(launch_ts, conf).context("Failed to start pageserver")?; diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 0ebaf78840..9e4530ba3c 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -31,6 +31,7 @@ use utils::{ use crate::l0_flush::L0FlushConfig; use crate::tenant::config::TenantConfOpt; +use crate::tenant::storage_layer::inmemory_layer::IndexEntry; use crate::tenant::timeline::compaction::CompactL0Phase1ValueAccess; use crate::tenant::vectored_blob_io::MaxVectoredReadBytes; use crate::tenant::{TENANTS_SEGMENT_NAME, TIMELINES_SEGMENT_NAME}; @@ -95,6 +96,8 @@ pub mod defaults { pub const DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB: usize = 0; + pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512; + /// /// Default built-in configuration file. /// @@ -289,6 +292,8 @@ pub struct PageServerConf { /// Direct IO settings pub virtual_file_direct_io: virtual_file::DirectIoMode, + + pub io_buffer_alignment: usize, } /// We do not want to store this in a PageServerConf because the latter may be logged @@ -393,6 +398,8 @@ struct PageServerConfigBuilder { compact_level0_phase1_value_access: BuilderValue, virtual_file_direct_io: BuilderValue, + + io_buffer_alignment: BuilderValue, } impl PageServerConfigBuilder { @@ -481,6 +488,7 @@ impl PageServerConfigBuilder { l0_flush: Set(L0FlushConfig::default()), compact_level0_phase1_value_access: Set(CompactL0Phase1ValueAccess::default()), virtual_file_direct_io: Set(virtual_file::DirectIoMode::default()), + io_buffer_alignment: Set(DEFAULT_IO_BUFFER_ALIGNMENT), } } } @@ -660,6 +668,10 @@ impl PageServerConfigBuilder { self.virtual_file_direct_io = BuilderValue::Set(value); } + pub fn io_buffer_alignment(&mut self, value: usize) { + self.io_buffer_alignment = BuilderValue::Set(value); + } + pub fn build(self, id: NodeId) -> anyhow::Result { let default = Self::default_values(); @@ -716,6 +728,7 @@ impl PageServerConfigBuilder { l0_flush, compact_level0_phase1_value_access, virtual_file_direct_io, + io_buffer_alignment, } CUSTOM LOGIC { @@ -985,6 +998,9 @@ impl PageServerConf { "virtual_file_direct_io" => { builder.virtual_file_direct_io(utils::toml_edit_ext::deserialize_item(item).context("virtual_file_direct_io")?) } + "io_buffer_alignment" => { + builder.io_buffer_alignment(parse_toml_u64("io_buffer_alignment", item)? as usize) + } _ => bail!("unrecognized pageserver option '{key}'"), } } @@ -1005,6 +1021,15 @@ impl PageServerConf { conf.default_tenant_conf = t_conf.merge(TenantConf::default()); + IndexEntry::validate_checkpoint_distance(conf.default_tenant_conf.checkpoint_distance) + .map_err(|msg| anyhow::anyhow!("{msg}")) + .with_context(|| { + format!( + "effective checkpoint distance is unsupported: {}", + conf.default_tenant_conf.checkpoint_distance + ) + })?; + Ok(conf) } @@ -1068,6 +1093,7 @@ impl PageServerConf { l0_flush: L0FlushConfig::default(), compact_level0_phase1_value_access: CompactL0Phase1ValueAccess::default(), virtual_file_direct_io: virtual_file::DirectIoMode::default(), + io_buffer_alignment: defaults::DEFAULT_IO_BUFFER_ALIGNMENT, } } } @@ -1308,6 +1334,7 @@ background_task_maximum_delay = '334 s' l0_flush: L0FlushConfig::default(), compact_level0_phase1_value_access: CompactL0Phase1ValueAccess::default(), virtual_file_direct_io: virtual_file::DirectIoMode::default(), + io_buffer_alignment: defaults::DEFAULT_IO_BUFFER_ALIGNMENT, }, "Correct defaults should be used when no config values are provided" ); @@ -1381,6 +1408,7 @@ background_task_maximum_delay = '334 s' l0_flush: L0FlushConfig::default(), compact_level0_phase1_value_access: CompactL0Phase1ValueAccess::default(), virtual_file_direct_io: virtual_file::DirectIoMode::default(), + io_buffer_alignment: defaults::DEFAULT_IO_BUFFER_ALIGNMENT, }, "Should be able to parse all basic config values correctly" ); diff --git a/pageserver/src/control_plane_client.rs b/pageserver/src/control_plane_client.rs index b5d9267d79..56a536c387 100644 --- a/pageserver/src/control_plane_client.rs +++ b/pageserver/src/control_plane_client.rs @@ -141,12 +141,18 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient { m.other ); + let az_id = m + .other + .get("availability_zone_id") + .and_then(|jv| jv.as_str().map(|str| str.to_owned())); + Some(NodeRegisterRequest { node_id: conf.id, listen_pg_addr: m.postgres_host, listen_pg_port: m.postgres_port, listen_http_addr: m.http_host, listen_http_port: m.http_port, + availability_zone_id: az_id, }) } Err(e) => { diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index cbcc162b32..8cf2c99c09 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -324,6 +324,9 @@ impl From for ApiError { match value { NotFound => ApiError::NotFound(anyhow::anyhow!("timeline not found").into()), Timeout => ApiError::Timeout("hit pageserver internal timeout".into()), + e @ HasArchivedParent(_) => { + ApiError::PreconditionFailed(e.to_string().into_boxed_str()) + } HasUnarchivedChildren(children) => ApiError::PreconditionFailed( format!( "Cannot archive timeline which has non-archived child timelines: {children:?}" @@ -871,7 +874,10 @@ async fn get_timestamp_of_lsn_handler( match result { Some(time) => { - let time = format_rfc3339(postgres_ffi::from_pg_timestamp(time)).to_string(); + let time = format_rfc3339( + postgres_ffi::try_from_pg_timestamp(time).map_err(ApiError::InternalServerError)?, + ) + .to_string(); json_response(StatusCode::OK, time) } None => Err(ApiError::NotFound( @@ -1727,6 +1733,10 @@ async fn timeline_compact_handler( if Some(true) == parse_query_param::<_, bool>(&request, "enhanced_gc_bottom_most_compaction")? { flags |= CompactFlags::EnhancedGcBottomMostCompaction; } + if Some(true) == parse_query_param::<_, bool>(&request, "dry_run")? { + flags |= CompactFlags::DryRun; + } + let wait_until_uploaded = parse_query_param::<_, bool>(&request, "wait_until_uploaded")?.unwrap_or(false); @@ -2344,6 +2354,20 @@ async fn put_io_engine_handler( json_response(StatusCode::OK, ()) } +async fn put_io_alignment_handler( + mut r: Request, + _cancel: CancellationToken, +) -> Result, ApiError> { + check_permission(&r, None)?; + let align: usize = json_request(&mut r).await?; + crate::virtual_file::set_io_buffer_alignment(align).map_err(|align| { + ApiError::PreconditionFailed( + format!("Requested io alignment ({align}) is not a power of two").into(), + ) + })?; + json_response(StatusCode::OK, ()) +} + /// Polled by control plane. /// /// See [`crate::utilization`]. @@ -3031,6 +3055,9 @@ pub fn make_router( |r| api_handler(r, timeline_collect_keyspace), ) .put("/v1/io_engine", |r| api_handler(r, put_io_engine_handler)) + .put("/v1/io_alignment", |r| { + api_handler(r, put_io_alignment_handler) + }) .put( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/force_aux_policy_switch", |r| api_handler(r, force_aux_policy_switch_handler), diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index dbfc9f3544..7a9cf495c7 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -16,6 +16,7 @@ pub mod l0_flush; use futures::{stream::FuturesUnordered, StreamExt}; pub use pageserver_api::keyspace; use tokio_util::sync::CancellationToken; +mod assert_u64_eq_usize; pub mod aux_file; pub mod metrics; pub mod page_cache; diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 1f8634df93..c4011d593c 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -1552,7 +1552,6 @@ pub(crate) static LIVE_CONNECTIONS: Lazy = Lazy::new(|| { #[derive(Clone, Copy, enum_map::Enum, IntoStaticStr)] pub(crate) enum ComputeCommandKind { PageStreamV2, - PageStream, Basebackup, Fullbackup, LeaseLsn, diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 81294291a9..39c6a6fb74 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -557,7 +557,7 @@ impl PageServerHandler { pgb: &mut PostgresBackend, tenant_id: TenantId, timeline_id: TimelineId, - protocol_version: PagestreamProtocolVersion, + _protocol_version: PagestreamProtocolVersion, ctx: RequestContext, ) -> Result<(), QueryError> where @@ -601,8 +601,7 @@ impl PageServerHandler { fail::fail_point!("ps::handle-pagerequest-message"); // parse request - let neon_fe_msg = - PagestreamFeMessage::parse(&mut copy_data_bytes.reader(), protocol_version)?; + let neon_fe_msg = PagestreamFeMessage::parse(&mut copy_data_bytes.reader())?; // invoke handler function let (handler_result, span) = match neon_fe_msg { @@ -754,16 +753,21 @@ impl PageServerHandler { } if request_lsn < **latest_gc_cutoff_lsn { - // Check explicitly for INVALID just to get a less scary error message if the - // request is obviously bogus - return Err(if request_lsn == Lsn::INVALID { - PageStreamError::BadRequest("invalid LSN(0) in request".into()) - } else { - PageStreamError::BadRequest(format!( + let gc_info = &timeline.gc_info.read().unwrap(); + if !gc_info.leases.contains_key(&request_lsn) { + // The requested LSN is below gc cutoff and is not guarded by a lease. + + // Check explicitly for INVALID just to get a less scary error message if the + // request is obviously bogus + return Err(if request_lsn == Lsn::INVALID { + PageStreamError::BadRequest("invalid LSN(0) in request".into()) + } else { + PageStreamError::BadRequest(format!( "tried to request a page version that was garbage collected. requested at {} gc cutoff {}", request_lsn, **latest_gc_cutoff_lsn ).into()) - }); + }); + } } // Wait for WAL up to 'not_modified_since' to arrive, if necessary @@ -790,6 +794,8 @@ impl PageServerHandler { } } + /// Handles the lsn lease request. + /// If a lease cannot be obtained, the client will receive NULL. #[instrument(skip_all, fields(shard_id, %lsn))] async fn handle_make_lsn_lease( &mut self, @@ -812,19 +818,25 @@ impl PageServerHandler { .await?; set_tracing_field_shard_id(&timeline); - let lease = timeline.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)?; - let valid_until = lease - .valid_until - .duration_since(SystemTime::UNIX_EPOCH) - .map_err(|e| QueryError::Other(e.into()))?; + let lease = timeline + .make_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx) + .inspect_err(|e| { + warn!("{e}"); + }) + .ok(); + let valid_until_str = lease.map(|l| { + l.valid_until + .duration_since(SystemTime::UNIX_EPOCH) + .expect("valid_until is earlier than UNIX_EPOCH") + .as_millis() + .to_string() + }); + let bytes = valid_until_str.as_ref().map(|x| x.as_bytes()); pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col( b"valid_until", )]))? - .write_message_noflush(&BeMessage::DataRow(&[Some( - &valid_until.as_millis().to_be_bytes(), - )]))? - .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?; + .write_message_noflush(&BeMessage::DataRow(&[bytes]))?; Ok(()) } @@ -1275,35 +1287,6 @@ where ctx, ) .await?; - } else if let Some(params) = parts.strip_prefix(&["pagestream"]) { - if params.len() != 2 { - return Err(QueryError::Other(anyhow::anyhow!( - "invalid param number for pagestream command" - ))); - } - let tenant_id = TenantId::from_str(params[0]) - .with_context(|| format!("Failed to parse tenant id from {}", params[0]))?; - let timeline_id = TimelineId::from_str(params[1]) - .with_context(|| format!("Failed to parse timeline id from {}", params[1]))?; - - tracing::Span::current() - .record("tenant_id", field::display(tenant_id)) - .record("timeline_id", field::display(timeline_id)); - - self.check_permission(Some(tenant_id))?; - - COMPUTE_COMMANDS_COUNTERS - .for_command(ComputeCommandKind::PageStream) - .inc(); - - self.handle_pagerequests( - pgb, - tenant_id, - timeline_id, - PagestreamProtocolVersion::V1, - ctx, - ) - .await?; } else if let Some(params) = parts.strip_prefix(&["basebackup"]) { if params.len() < 2 { return Err(QueryError::Other(anyhow::anyhow!( diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index b7110d69b6..edcbac970b 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -12,7 +12,7 @@ use crate::keyspace::{KeySpace, KeySpaceAccum}; use crate::span::debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id; use crate::walrecord::NeonWalRecord; use crate::{aux_file, repository::*}; -use anyhow::{ensure, Context}; +use anyhow::{bail, ensure, Context}; use bytes::{Buf, Bytes, BytesMut}; use enum_map::Enum; use pageserver_api::key::{ @@ -1791,6 +1791,11 @@ impl<'a> DatadirModification<'a> { // Flush relation and SLRU data blocks, keep metadata. let mut retained_pending_updates = HashMap::<_, Vec<_>>::new(); for (key, values) in self.pending_updates.drain() { + if !key.is_valid_key_on_write_path() { + bail!( + "the request contains data not supported by pageserver at TimelineWriter::put: {}", key + ); + } let mut write_batch = Vec::new(); for (lsn, value_ser_size, value) in values { if key.is_rel_block_key() || key.is_slru_block_key() { @@ -1843,10 +1848,13 @@ impl<'a> DatadirModification<'a> { .drain() .flat_map(|(key, values)| { values.into_iter().map(move |(lsn, val_ser_size, value)| { - (key.to_compact(), lsn, val_ser_size, value) + if !key.is_valid_key_on_write_path() { + bail!("the request contains data not supported by pageserver at TimelineWriter::put: {}", key); + } + Ok((key.to_compact(), lsn, val_ser_size, value)) }) }) - .collect::>(); + .collect::>>()?; writer.put_batch(batch, ctx).await?; } diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index ed9e001fd2..6a4e90dd55 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -146,6 +146,12 @@ impl FromStr for TokioRuntimeMode { } } +static TOKIO_THREAD_STACK_SIZE: Lazy = Lazy::new(|| { + env::var("NEON_PAGESERVER_TOKIO_THREAD_STACK_SIZE") + // the default 2MiB are insufficent, especially in debug mode + .unwrap_or_else(|| NonZeroUsize::new(4 * 1024 * 1024).unwrap()) +}); + static ONE_RUNTIME: Lazy> = Lazy::new(|| { let thread_name = "pageserver-tokio"; let Some(mode) = env::var("NEON_PAGESERVER_USE_ONE_RUNTIME") else { @@ -164,6 +170,7 @@ static ONE_RUNTIME: Lazy> = Lazy::new(|| { tokio::runtime::Builder::new_current_thread() .thread_name(thread_name) .enable_all() + .thread_stack_size(TOKIO_THREAD_STACK_SIZE.get()) .build() .expect("failed to create one single runtime") } @@ -173,6 +180,7 @@ static ONE_RUNTIME: Lazy> = Lazy::new(|| { .thread_name(thread_name) .enable_all() .worker_threads(num_workers.get()) + .thread_stack_size(TOKIO_THREAD_STACK_SIZE.get()) .build() .expect("failed to create one multi-threaded runtime") } @@ -199,6 +207,7 @@ macro_rules! pageserver_runtime { .thread_name($name) .worker_threads(TOKIO_WORKER_THREADS.get()) .enable_all() + .thread_stack_size(TOKIO_THREAD_STACK_SIZE.get()) .build() .expect(std::concat!("Failed to create runtime ", $name)) }); diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 0364d521b6..fb30857ddf 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -509,6 +509,9 @@ pub enum TimelineArchivalError { #[error("Timeout")] Timeout, + #[error("ancestor is archived: {}", .0)] + HasArchivedParent(TimelineId), + #[error("HasUnarchivedChildren")] HasUnarchivedChildren(Vec), @@ -524,6 +527,7 @@ impl Debug for TimelineArchivalError { match self { Self::NotFound => write!(f, "NotFound"), Self::Timeout => write!(f, "Timeout"), + Self::HasArchivedParent(p) => f.debug_tuple("HasArchivedParent").field(p).finish(), Self::HasUnarchivedChildren(c) => { f.debug_tuple("HasUnarchivedChildren").field(c).finish() } @@ -877,6 +881,12 @@ impl Tenant { }); }; + // TODO: should also be rejecting tenant conf changes that violate this check. + if let Err(e) = crate::tenant::storage_layer::inmemory_layer::IndexEntry::validate_checkpoint_distance(tenant_clone.get_checkpoint_distance()) { + make_broken(&tenant_clone, anyhow::anyhow!(e), BrokenVerbosity::Error); + return Ok(()); + } + let mut init_order = init_order; // take the completion because initial tenant loading will complete when all of // these tasks complete. @@ -1363,11 +1373,20 @@ impl Tenant { let timeline = { let timelines = self.timelines.lock().unwrap(); - let timeline = match timelines.get(&timeline_id) { - Some(t) => t, - None => return Err(TimelineArchivalError::NotFound), + let Some(timeline) = timelines.get(&timeline_id) else { + return Err(TimelineArchivalError::NotFound); }; + if state == TimelineArchivalState::Unarchived { + if let Some(ancestor_timeline) = timeline.ancestor_timeline() { + if ancestor_timeline.is_archived() == Some(true) { + return Err(TimelineArchivalError::HasArchivedParent( + ancestor_timeline.timeline_id, + )); + } + } + } + // Ensure that there are no non-archived child timelines let children: Vec = timelines .iter() diff --git a/pageserver/src/tenant/blob_io.rs b/pageserver/src/tenant/blob_io.rs index a245c99a88..dd70f6bbff 100644 --- a/pageserver/src/tenant/blob_io.rs +++ b/pageserver/src/tenant/blob_io.rs @@ -148,7 +148,7 @@ pub(super) const LEN_COMPRESSION_BIT_MASK: u8 = 0xf0; /// The maximum size of blobs we support. The highest few bits /// are reserved for compression and other further uses. -const MAX_SUPPORTED_LEN: usize = 0x0fff_ffff; +pub(crate) const MAX_SUPPORTED_BLOB_LEN: usize = 0x0fff_ffff; pub(super) const BYTE_UNCOMPRESSED: u8 = 0x80; pub(super) const BYTE_ZSTD: u8 = BYTE_UNCOMPRESSED | 0x10; @@ -326,7 +326,7 @@ impl BlobWriter { (self.write_all(io_buf.slice_len(), ctx).await, srcbuf) } else { // Write a 4-byte length header - if len > MAX_SUPPORTED_LEN { + if len > MAX_SUPPORTED_BLOB_LEN { return ( ( io_buf.slice_len(), diff --git a/pageserver/src/tenant/block_io.rs b/pageserver/src/tenant/block_io.rs index 601b095155..3afa3a86b9 100644 --- a/pageserver/src/tenant/block_io.rs +++ b/pageserver/src/tenant/block_io.rs @@ -2,7 +2,6 @@ //! Low-level Block-oriented I/O functions //! -use super::ephemeral_file::EphemeralFile; use super::storage_layer::delta_layer::{Adapter, DeltaLayerInner}; use crate::context::RequestContext; use crate::page_cache::{self, FileId, PageReadGuard, PageWriteGuard, ReadBufResult, PAGE_SZ}; @@ -81,9 +80,7 @@ impl<'a> Deref for BlockLease<'a> { /// Unlike traits, we also support the read function to be async though. pub(crate) enum BlockReaderRef<'a> { FileBlockReader(&'a FileBlockReader<'a>), - EphemeralFile(&'a EphemeralFile), Adapter(Adapter<&'a DeltaLayerInner>), - Slice(&'a [u8]), #[cfg(test)] TestDisk(&'a super::disk_btree::tests::TestDisk), #[cfg(test)] @@ -100,9 +97,7 @@ impl<'a> BlockReaderRef<'a> { use BlockReaderRef::*; match self { FileBlockReader(r) => r.read_blk(blknum, ctx).await, - EphemeralFile(r) => r.read_blk(blknum, ctx).await, Adapter(r) => r.read_blk(blknum, ctx).await, - Slice(s) => Self::read_blk_slice(s, blknum), #[cfg(test)] TestDisk(r) => r.read_blk(blknum), #[cfg(test)] @@ -111,24 +106,6 @@ impl<'a> BlockReaderRef<'a> { } } -impl<'a> BlockReaderRef<'a> { - fn read_blk_slice(slice: &[u8], blknum: u32) -> std::io::Result { - let start = (blknum as usize).checked_mul(PAGE_SZ).unwrap(); - let end = start.checked_add(PAGE_SZ).unwrap(); - if end > slice.len() { - return Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - format!("slice too short, len={} end={}", slice.len(), end), - )); - } - let slice = &slice[start..end]; - let page_sized: &[u8; PAGE_SZ] = slice - .try_into() - .expect("we add PAGE_SZ to start, so the slice must have PAGE_SZ"); - Ok(BlockLease::Slice(page_sized)) - } -} - /// /// A "cursor" for efficiently reading multiple pages from a BlockReader /// diff --git a/pageserver/src/tenant/ephemeral_file.rs b/pageserver/src/tenant/ephemeral_file.rs index 44f0fc7ab1..5324e1807d 100644 --- a/pageserver/src/tenant/ephemeral_file.rs +++ b/pageserver/src/tenant/ephemeral_file.rs @@ -1,13 +1,21 @@ //! Implementation of append-only file data structure //! used to keep in-memory layers spilled on disk. +use crate::assert_u64_eq_usize::{U64IsUsize, UsizeIsU64}; use crate::config::PageServerConf; use crate::context::RequestContext; use crate::page_cache; -use crate::tenant::block_io::{BlockCursor, BlockLease, BlockReader}; -use crate::virtual_file::{self, VirtualFile}; +use crate::tenant::storage_layer::inmemory_layer::vectored_dio_read::File; +use crate::virtual_file::owned_buffers_io::slice::SliceMutExt; +use crate::virtual_file::owned_buffers_io::util::size_tracking_writer; +use crate::virtual_file::owned_buffers_io::write::Buffer; +use crate::virtual_file::{self, owned_buffers_io, VirtualFile}; +use bytes::BytesMut; use camino::Utf8PathBuf; +use num_traits::Num; use pageserver_api::shard::TenantShardId; +use tokio_epoll_uring::{BoundedBuf, Slice}; +use tracing::error; use std::io; use std::sync::atomic::AtomicU64; @@ -16,12 +24,17 @@ use utils::id::TimelineId; pub struct EphemeralFile { _tenant_shard_id: TenantShardId, _timeline_id: TimelineId, - - rw: page_caching::RW, + page_cache_file_id: page_cache::FileId, + bytes_written: u64, + buffered_writer: owned_buffers_io::write::BufferedWriter< + BytesMut, + size_tracking_writer::Writer, + >, + /// Gate guard is held on as long as we need to do operations in the path (delete on drop) + _gate_guard: utils::sync::gate::GateGuard, } -mod page_caching; -mod zero_padded_read_write; +const TAIL_SZ: usize = 64 * 1024; impl EphemeralFile { pub async fn create( @@ -51,75 +64,178 @@ impl EphemeralFile { ) .await?; + let page_cache_file_id = page_cache::next_file_id(); // XXX get rid, we're not page-caching anymore + Ok(EphemeralFile { _tenant_shard_id: tenant_shard_id, _timeline_id: timeline_id, - rw: page_caching::RW::new(file, gate_guard), + page_cache_file_id, + bytes_written: 0, + buffered_writer: owned_buffers_io::write::BufferedWriter::new( + size_tracking_writer::Writer::new(file), + BytesMut::with_capacity(TAIL_SZ), + ), + _gate_guard: gate_guard, }) } +} +impl Drop for EphemeralFile { + fn drop(&mut self) { + // unlink the file + // we are clear to do this, because we have entered a gate + let path = &self.buffered_writer.as_inner().as_inner().path; + let res = std::fs::remove_file(path); + if let Err(e) = res { + if e.kind() != std::io::ErrorKind::NotFound { + // just never log the not found errors, we cannot do anything for them; on detach + // the tenant directory is already gone. + // + // not found files might also be related to https://github.com/neondatabase/neon/issues/2442 + error!("could not remove ephemeral file '{path}': {e}"); + } + } + } +} + +impl EphemeralFile { pub(crate) fn len(&self) -> u64 { - self.rw.bytes_written() + self.bytes_written } pub(crate) fn page_cache_file_id(&self) -> page_cache::FileId { - self.rw.page_cache_file_id() + self.page_cache_file_id } - /// See [`self::page_caching::RW::load_to_vec`]. pub(crate) async fn load_to_vec(&self, ctx: &RequestContext) -> Result, io::Error> { - self.rw.load_to_vec(ctx).await - } - - pub(crate) async fn read_blk( - &self, - blknum: u32, - ctx: &RequestContext, - ) -> Result { - self.rw.read_blk(blknum, ctx).await - } - - #[cfg(test)] - // This is a test helper: outside of tests, we are always written to via a pre-serialized batch. - pub(crate) async fn write_blob( - &mut self, - srcbuf: &[u8], - ctx: &RequestContext, - ) -> Result { - let pos = self.rw.bytes_written(); - - let mut len_bytes = std::io::Cursor::new(Vec::new()); - crate::tenant::storage_layer::inmemory_layer::SerializedBatch::write_blob_length( - srcbuf.len(), - &mut len_bytes, - ); - let len_bytes = len_bytes.into_inner(); - - // Write the length field - self.rw.write_all_borrowed(&len_bytes, ctx).await?; - - // Write the payload - self.rw.write_all_borrowed(srcbuf, ctx).await?; - - Ok(pos) + let size = self.len().into_usize(); + let vec = Vec::with_capacity(size); + let (slice, nread) = self.read_exact_at_eof_ok(0, vec.slice_full(), ctx).await?; + assert_eq!(nread, size); + let vec = slice.into_inner(); + assert_eq!(vec.len(), nread); + assert_eq!(vec.capacity(), size, "we shouldn't be reallocating"); + Ok(vec) } /// Returns the offset at which the first byte of the input was written, for use /// in constructing indices over the written value. + /// + /// Panics if the write is short because there's no way we can recover from that. + /// TODO: make upstack handle this as an error. pub(crate) async fn write_raw( &mut self, srcbuf: &[u8], ctx: &RequestContext, - ) -> Result { - let pos = self.rw.bytes_written(); + ) -> std::io::Result { + let pos = self.bytes_written; + + let new_bytes_written = pos.checked_add(srcbuf.len().into_u64()).ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!( + "write would grow EphemeralFile beyond u64::MAX: len={pos} writen={srcbuf_len}", + srcbuf_len = srcbuf.len(), + ), + ) + })?; // Write the payload - self.rw.write_all_borrowed(srcbuf, ctx).await?; + let nwritten = self + .buffered_writer + .write_buffered_borrowed(srcbuf, ctx) + .await?; + assert_eq!( + nwritten, + srcbuf.len(), + "buffered writer has no short writes" + ); + + self.bytes_written = new_bytes_written; Ok(pos) } } +impl super::storage_layer::inmemory_layer::vectored_dio_read::File for EphemeralFile { + async fn read_exact_at_eof_ok<'a, 'b, B: tokio_epoll_uring::IoBufMut + Send>( + &'b self, + start: u64, + dst: tokio_epoll_uring::Slice, + ctx: &'a RequestContext, + ) -> std::io::Result<(tokio_epoll_uring::Slice, usize)> { + let file_size_tracking_writer = self.buffered_writer.as_inner(); + let flushed_offset = file_size_tracking_writer.bytes_written(); + + let buffer = self.buffered_writer.inspect_buffer(); + let buffered = &buffer[0..buffer.pending()]; + + let dst_cap = dst.bytes_total().into_u64(); + let end = { + // saturating_add is correct here because the max file size is u64::MAX, so, + // if start + dst.len() > u64::MAX, then we know it will be a short read + let mut end: u64 = start.saturating_add(dst_cap); + if end > self.bytes_written { + end = self.bytes_written; + } + end + }; + + // inclusive, exclusive + #[derive(Debug)] + struct Range(N, N); + impl Range { + fn len(&self) -> N { + if self.0 > self.1 { + N::zero() + } else { + self.1 - self.0 + } + } + } + let written_range = Range(start, std::cmp::min(end, flushed_offset)); + let buffered_range = Range(std::cmp::max(start, flushed_offset), end); + + let dst = if written_range.len() > 0 { + let file: &VirtualFile = file_size_tracking_writer.as_inner(); + let bounds = dst.bounds(); + let slice = file + .read_exact_at(dst.slice(0..written_range.len().into_usize()), start, ctx) + .await?; + Slice::from_buf_bounds(Slice::into_inner(slice), bounds) + } else { + dst + }; + + let dst = if buffered_range.len() > 0 { + let offset_in_buffer = buffered_range + .0 + .checked_sub(flushed_offset) + .unwrap() + .into_usize(); + let to_copy = + &buffered[offset_in_buffer..(offset_in_buffer + buffered_range.len().into_usize())]; + let bounds = dst.bounds(); + let mut view = dst.slice({ + let start = written_range.len().into_usize(); + let end = start + .checked_add(buffered_range.len().into_usize()) + .unwrap(); + start..end + }); + view.as_mut_rust_slice_full_zeroed() + .copy_from_slice(to_copy); + Slice::from_buf_bounds(Slice::into_inner(view), bounds) + } else { + dst + }; + + // TODO: in debug mode, randomize the remaining bytes in `dst` to catch bugs + + Ok((dst, (end - start).into_usize())) + } +} + /// Does the given filename look like an ephemeral file? pub fn is_ephemeral_file(filename: &str) -> bool { if let Some(rest) = filename.strip_prefix("ephemeral-") { @@ -129,19 +245,13 @@ pub fn is_ephemeral_file(filename: &str) -> bool { } } -impl BlockReader for EphemeralFile { - fn block_cursor(&self) -> super::block_io::BlockCursor<'_> { - BlockCursor::new(super::block_io::BlockReaderRef::EphemeralFile(self)) - } -} - #[cfg(test)] mod tests { + use rand::Rng; + use super::*; use crate::context::DownloadBehavior; use crate::task_mgr::TaskKind; - use crate::tenant::block_io::BlockReaderRef; - use rand::{thread_rng, RngCore}; use std::fs; use std::str::FromStr; @@ -172,69 +282,6 @@ mod tests { Ok((conf, tenant_shard_id, timeline_id, ctx)) } - #[tokio::test] - async fn test_ephemeral_blobs() -> Result<(), io::Error> { - let (conf, tenant_id, timeline_id, ctx) = harness("ephemeral_blobs")?; - - let gate = utils::sync::gate::Gate::default(); - - let entered = gate.enter().unwrap(); - - let mut file = EphemeralFile::create(conf, tenant_id, timeline_id, entered, &ctx).await?; - - let pos_foo = file.write_blob(b"foo", &ctx).await?; - assert_eq!( - b"foo", - file.block_cursor() - .read_blob(pos_foo, &ctx) - .await? - .as_slice() - ); - let pos_bar = file.write_blob(b"bar", &ctx).await?; - assert_eq!( - b"foo", - file.block_cursor() - .read_blob(pos_foo, &ctx) - .await? - .as_slice() - ); - assert_eq!( - b"bar", - file.block_cursor() - .read_blob(pos_bar, &ctx) - .await? - .as_slice() - ); - - let mut blobs = Vec::new(); - for i in 0..10000 { - let data = Vec::from(format!("blob{}", i).as_bytes()); - let pos = file.write_blob(&data, &ctx).await?; - blobs.push((pos, data)); - } - // also test with a large blobs - for i in 0..100 { - let data = format!("blob{}", i).as_bytes().repeat(100); - let pos = file.write_blob(&data, &ctx).await?; - blobs.push((pos, data)); - } - - let cursor = BlockCursor::new(BlockReaderRef::EphemeralFile(&file)); - for (pos, expected) in blobs { - let actual = cursor.read_blob(pos, &ctx).await?; - assert_eq!(actual, expected); - } - - // Test a large blob that spans multiple pages - let mut large_data = vec![0; 20000]; - thread_rng().fill_bytes(&mut large_data); - let pos_large = file.write_blob(&large_data, &ctx).await?; - let result = file.block_cursor().read_blob(pos_large, &ctx).await?; - assert_eq!(result, large_data); - - Ok(()) - } - #[tokio::test] async fn ephemeral_file_holds_gate_open() { const FOREVER: std::time::Duration = std::time::Duration::from_secs(5); @@ -268,4 +315,151 @@ mod tests { .expect("closing completes right away") .expect("closing does not panic"); } + + #[tokio::test] + async fn test_ephemeral_file_basics() { + let (conf, tenant_id, timeline_id, ctx) = harness("test_ephemeral_file_basics").unwrap(); + + let gate = utils::sync::gate::Gate::default(); + + let mut file = + EphemeralFile::create(conf, tenant_id, timeline_id, gate.enter().unwrap(), &ctx) + .await + .unwrap(); + + let cap = file.buffered_writer.inspect_buffer().capacity(); + + let write_nbytes = cap + cap / 2; + + let content: Vec = rand::thread_rng() + .sample_iter(rand::distributions::Standard) + .take(write_nbytes) + .collect(); + + let mut value_offsets = Vec::new(); + for i in 0..write_nbytes { + let off = file.write_raw(&content[i..i + 1], &ctx).await.unwrap(); + value_offsets.push(off); + } + + assert!(file.len() as usize == write_nbytes); + for i in 0..write_nbytes { + assert_eq!(value_offsets[i], i.into_u64()); + let buf = Vec::with_capacity(1); + let (buf_slice, nread) = file + .read_exact_at_eof_ok(i.into_u64(), buf.slice_full(), &ctx) + .await + .unwrap(); + let buf = buf_slice.into_inner(); + assert_eq!(nread, 1); + assert_eq!(&buf, &content[i..i + 1]); + } + + let file_contents = + std::fs::read(&file.buffered_writer.as_inner().as_inner().path).unwrap(); + assert_eq!(file_contents, &content[0..cap]); + + let buffer_contents = file.buffered_writer.inspect_buffer(); + assert_eq!(buffer_contents, &content[cap..write_nbytes]); + } + + #[tokio::test] + async fn test_flushes_do_happen() { + let (conf, tenant_id, timeline_id, ctx) = harness("test_flushes_do_happen").unwrap(); + + let gate = utils::sync::gate::Gate::default(); + + let mut file = + EphemeralFile::create(conf, tenant_id, timeline_id, gate.enter().unwrap(), &ctx) + .await + .unwrap(); + + let cap = file.buffered_writer.inspect_buffer().capacity(); + + let content: Vec = rand::thread_rng() + .sample_iter(rand::distributions::Standard) + .take(cap + cap / 2) + .collect(); + + file.write_raw(&content, &ctx).await.unwrap(); + + // assert the state is as this test expects it to be + assert_eq!( + &file.load_to_vec(&ctx).await.unwrap(), + &content[0..cap + cap / 2] + ); + let md = file + .buffered_writer + .as_inner() + .as_inner() + .path + .metadata() + .unwrap(); + assert_eq!( + md.len(), + cap.into_u64(), + "buffered writer does one write if we write 1.5x buffer capacity" + ); + assert_eq!( + &file.buffered_writer.inspect_buffer()[0..cap / 2], + &content[cap..cap + cap / 2] + ); + } + + #[tokio::test] + async fn test_read_split_across_file_and_buffer() { + // This test exercises the logic on the read path that splits the logical read + // into a read from the flushed part (= the file) and a copy from the buffered writer's buffer. + // + // This test build on the assertions in test_flushes_do_happen + + let (conf, tenant_id, timeline_id, ctx) = + harness("test_read_split_across_file_and_buffer").unwrap(); + + let gate = utils::sync::gate::Gate::default(); + + let mut file = + EphemeralFile::create(conf, tenant_id, timeline_id, gate.enter().unwrap(), &ctx) + .await + .unwrap(); + + let cap = file.buffered_writer.inspect_buffer().capacity(); + + let content: Vec = rand::thread_rng() + .sample_iter(rand::distributions::Standard) + .take(cap + cap / 2) + .collect(); + + file.write_raw(&content, &ctx).await.unwrap(); + + let test_read = |start: usize, len: usize| { + let file = &file; + let ctx = &ctx; + let content = &content; + async move { + let (buf, nread) = file + .read_exact_at_eof_ok( + start.into_u64(), + Vec::with_capacity(len).slice_full(), + ctx, + ) + .await + .unwrap(); + assert_eq!(nread, len); + assert_eq!(&buf.into_inner(), &content[start..(start + len)]); + } + }; + + // completely within the file range + assert!(20 < cap, "test assumption"); + test_read(10, 10).await; + // border onto edge of file + test_read(cap - 10, 10).await; + // read across file and buffer + test_read(cap - 10, 20).await; + // stay from start of buffer + test_read(cap, 10).await; + // completely within buffer + test_read(cap + 10, 10).await; + } } diff --git a/pageserver/src/tenant/ephemeral_file/page_caching.rs b/pageserver/src/tenant/ephemeral_file/page_caching.rs deleted file mode 100644 index 48926354f1..0000000000 --- a/pageserver/src/tenant/ephemeral_file/page_caching.rs +++ /dev/null @@ -1,153 +0,0 @@ -//! Wrapper around [`super::zero_padded_read_write::RW`] that uses the -//! [`crate::page_cache`] to serve reads that need to go to the underlying [`VirtualFile`]. -//! -//! Subject to removal in - -use crate::context::RequestContext; -use crate::page_cache::{self, PAGE_SZ}; -use crate::tenant::block_io::BlockLease; -use crate::virtual_file::owned_buffers_io::util::size_tracking_writer; -use crate::virtual_file::VirtualFile; - -use std::io::{self}; -use tokio_epoll_uring::BoundedBuf; -use tracing::*; - -use super::zero_padded_read_write; - -/// See module-level comment. -pub struct RW { - page_cache_file_id: page_cache::FileId, - rw: super::zero_padded_read_write::RW>, - /// Gate guard is held on as long as we need to do operations in the path (delete on drop). - _gate_guard: utils::sync::gate::GateGuard, -} - -impl RW { - pub fn new(file: VirtualFile, _gate_guard: utils::sync::gate::GateGuard) -> Self { - let page_cache_file_id = page_cache::next_file_id(); - Self { - page_cache_file_id, - rw: super::zero_padded_read_write::RW::new(size_tracking_writer::Writer::new(file)), - _gate_guard, - } - } - - pub fn page_cache_file_id(&self) -> page_cache::FileId { - self.page_cache_file_id - } - - pub(crate) async fn write_all_borrowed( - &mut self, - srcbuf: &[u8], - ctx: &RequestContext, - ) -> Result { - // It doesn't make sense to proactively fill the page cache on the Pageserver write path - // because Compute is unlikely to access recently written data. - self.rw.write_all_borrowed(srcbuf, ctx).await - } - - pub(crate) fn bytes_written(&self) -> u64 { - self.rw.bytes_written() - } - - /// Load all blocks that can be read via [`Self::read_blk`] into a contiguous memory buffer. - /// - /// This includes the blocks that aren't yet flushed to disk by the internal buffered writer. - /// The last block is zero-padded to [`PAGE_SZ`], so, the returned buffer is always a multiple of [`PAGE_SZ`]. - pub(super) async fn load_to_vec(&self, ctx: &RequestContext) -> Result, io::Error> { - // round up to the next PAGE_SZ multiple, required by blob_io - let size = { - let s = usize::try_from(self.bytes_written()).unwrap(); - if s % PAGE_SZ == 0 { - s - } else { - s.checked_add(PAGE_SZ - (s % PAGE_SZ)).unwrap() - } - }; - let vec = Vec::with_capacity(size); - - // read from disk what we've already flushed - let file_size_tracking_writer = self.rw.as_writer(); - let flushed_range = 0..usize::try_from(file_size_tracking_writer.bytes_written()).unwrap(); - let mut vec = file_size_tracking_writer - .as_inner() - .read_exact_at( - vec.slice(0..(flushed_range.end - flushed_range.start)), - u64::try_from(flushed_range.start).unwrap(), - ctx, - ) - .await? - .into_inner(); - - // copy from in-memory buffer what we haven't flushed yet but would return when accessed via read_blk - let buffered = self.rw.get_tail_zero_padded(); - vec.extend_from_slice(buffered); - assert_eq!(vec.len(), size); - assert_eq!(vec.len() % PAGE_SZ, 0); - Ok(vec) - } - - pub(crate) async fn read_blk( - &self, - blknum: u32, - ctx: &RequestContext, - ) -> Result { - match self.rw.read_blk(blknum).await? { - zero_padded_read_write::ReadResult::NeedsReadFromWriter { writer } => { - let cache = page_cache::get(); - match cache - .read_immutable_buf(self.page_cache_file_id, blknum, ctx) - .await - .map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::Other, - // order path before error because error is anyhow::Error => might have many contexts - format!( - "ephemeral file: read immutable page #{}: {}: {:#}", - blknum, - self.rw.as_writer().as_inner().path, - e, - ), - ) - })? { - page_cache::ReadBufResult::Found(guard) => { - return Ok(BlockLease::PageReadGuard(guard)) - } - page_cache::ReadBufResult::NotFound(write_guard) => { - let write_guard = writer - .as_inner() - .read_exact_at_page(write_guard, blknum as u64 * PAGE_SZ as u64, ctx) - .await?; - let read_guard = write_guard.mark_valid(); - return Ok(BlockLease::PageReadGuard(read_guard)); - } - } - } - zero_padded_read_write::ReadResult::ServedFromZeroPaddedMutableTail { buffer } => { - Ok(BlockLease::EphemeralFileMutableTail(buffer)) - } - } - } -} - -impl Drop for RW { - fn drop(&mut self) { - // There might still be pages in the [`crate::page_cache`] for this file. - // We leave them there, [`crate::page_cache::PageCache::find_victim`] will evict them when needed. - - // unlink the file - // we are clear to do this, because we have entered a gate - let path = &self.rw.as_writer().as_inner().path; - let res = std::fs::remove_file(path); - if let Err(e) = res { - if e.kind() != std::io::ErrorKind::NotFound { - // just never log the not found errors, we cannot do anything for them; on detach - // the tenant directory is already gone. - // - // not found files might also be related to https://github.com/neondatabase/neon/issues/2442 - error!("could not remove ephemeral file '{path}': {e}"); - } - } - } -} diff --git a/pageserver/src/tenant/ephemeral_file/zero_padded_read_write.rs b/pageserver/src/tenant/ephemeral_file/zero_padded_read_write.rs deleted file mode 100644 index fe310acab8..0000000000 --- a/pageserver/src/tenant/ephemeral_file/zero_padded_read_write.rs +++ /dev/null @@ -1,145 +0,0 @@ -//! The heart of how [`super::EphemeralFile`] does its reads and writes. -//! -//! # Writes -//! -//! [`super::EphemeralFile`] writes small, borrowed buffers using [`RW::write_all_borrowed`]. -//! The [`RW`] batches these into [`TAIL_SZ`] bigger writes, using [`owned_buffers_io::write::BufferedWriter`]. -//! -//! # Reads -//! -//! [`super::EphemeralFile`] always reads full [`PAGE_SZ`]ed blocks using [`RW::read_blk`]. -//! -//! The [`RW`] serves these reads either from the buffered writer's in-memory buffer -//! or redirects the caller to read from the underlying [`OwnedAsyncWriter`] -//! if the read is for the prefix that has already been flushed. -//! -//! # Current Usage -//! -//! The current user of this module is [`super::page_caching::RW`]. - -mod zero_padded; - -use crate::{ - context::RequestContext, - page_cache::PAGE_SZ, - virtual_file::owned_buffers_io::{ - self, - write::{Buffer, OwnedAsyncWriter}, - }, -}; - -const TAIL_SZ: usize = 64 * 1024; - -/// See module-level comment. -pub struct RW { - buffered_writer: owned_buffers_io::write::BufferedWriter< - zero_padded::Buffer, - owned_buffers_io::util::size_tracking_writer::Writer, - >, -} - -pub enum ReadResult<'a, W> { - NeedsReadFromWriter { writer: &'a W }, - ServedFromZeroPaddedMutableTail { buffer: &'a [u8; PAGE_SZ] }, -} - -impl RW -where - W: OwnedAsyncWriter, -{ - pub fn new(writer: W) -> Self { - let bytes_flushed_tracker = - owned_buffers_io::util::size_tracking_writer::Writer::new(writer); - let buffered_writer = owned_buffers_io::write::BufferedWriter::new( - bytes_flushed_tracker, - zero_padded::Buffer::default(), - ); - Self { buffered_writer } - } - - pub(crate) fn as_writer(&self) -> &W { - self.buffered_writer.as_inner().as_inner() - } - - pub async fn write_all_borrowed( - &mut self, - buf: &[u8], - ctx: &RequestContext, - ) -> std::io::Result { - self.buffered_writer.write_buffered_borrowed(buf, ctx).await - } - - pub fn bytes_written(&self) -> u64 { - let flushed_offset = self.buffered_writer.as_inner().bytes_written(); - let buffer: &zero_padded::Buffer = self.buffered_writer.inspect_buffer(); - flushed_offset + u64::try_from(buffer.pending()).unwrap() - } - - /// Get a slice of all blocks that [`Self::read_blk`] would return as [`ReadResult::ServedFromZeroPaddedMutableTail`]. - pub fn get_tail_zero_padded(&self) -> &[u8] { - let buffer: &zero_padded::Buffer = self.buffered_writer.inspect_buffer(); - let buffer_written_up_to = buffer.pending(); - // pad to next page boundary - let read_up_to = if buffer_written_up_to % PAGE_SZ == 0 { - buffer_written_up_to - } else { - buffer_written_up_to - .checked_add(PAGE_SZ - (buffer_written_up_to % PAGE_SZ)) - .unwrap() - }; - &buffer.as_zero_padded_slice()[0..read_up_to] - } - - pub(crate) async fn read_blk(&self, blknum: u32) -> Result, std::io::Error> { - let flushed_offset = self.buffered_writer.as_inner().bytes_written(); - let buffer: &zero_padded::Buffer = self.buffered_writer.inspect_buffer(); - let buffered_offset = flushed_offset + u64::try_from(buffer.pending()).unwrap(); - let read_offset = (blknum as u64) * (PAGE_SZ as u64); - - // The trailing page ("block") might only be partially filled, - // yet the blob_io code relies on us to return a full PAGE_SZed slice anyway. - // Moreover, it has to be zero-padded, because when we still had - // a write-back page cache, it provided pre-zeroed pages, and blob_io came to rely on it. - // DeltaLayer probably has the same issue, not sure why it needs no special treatment. - // => check here that the read doesn't go beyond this potentially trailing - // => the zero-padding is done in the `else` branch below - let blocks_written = if buffered_offset % (PAGE_SZ as u64) == 0 { - buffered_offset / (PAGE_SZ as u64) - } else { - (buffered_offset / (PAGE_SZ as u64)) + 1 - }; - if (blknum as u64) >= blocks_written { - return Err(std::io::Error::new(std::io::ErrorKind::Other, anyhow::anyhow!("read past end of ephemeral_file: read=0x{read_offset:x} buffered=0x{buffered_offset:x} flushed=0x{flushed_offset}"))); - } - - // assertions for the `if-else` below - assert_eq!( - flushed_offset % (TAIL_SZ as u64), 0, - "we only use write_buffered_borrowed to write to the buffered writer, so it's guaranteed that flushes happen buffer.cap()-sized chunks" - ); - assert_eq!( - flushed_offset % (PAGE_SZ as u64), - 0, - "the logic below can't handle if the page is spread across the flushed part and the buffer" - ); - - if read_offset < flushed_offset { - assert!(read_offset + (PAGE_SZ as u64) <= flushed_offset); - Ok(ReadResult::NeedsReadFromWriter { - writer: self.as_writer(), - }) - } else { - let read_offset_in_buffer = read_offset - .checked_sub(flushed_offset) - .expect("would have taken `if` branch instead of this one"); - let read_offset_in_buffer = usize::try_from(read_offset_in_buffer).unwrap(); - let zero_padded_slice = buffer.as_zero_padded_slice(); - let page = &zero_padded_slice[read_offset_in_buffer..(read_offset_in_buffer + PAGE_SZ)]; - Ok(ReadResult::ServedFromZeroPaddedMutableTail { - buffer: page - .try_into() - .expect("the slice above got it as page-size slice"), - }) - } - } -} diff --git a/pageserver/src/tenant/ephemeral_file/zero_padded_read_write/zero_padded.rs b/pageserver/src/tenant/ephemeral_file/zero_padded_read_write/zero_padded.rs deleted file mode 100644 index 2dc0277638..0000000000 --- a/pageserver/src/tenant/ephemeral_file/zero_padded_read_write/zero_padded.rs +++ /dev/null @@ -1,110 +0,0 @@ -//! A [`crate::virtual_file::owned_buffers_io::write::Buffer`] whose -//! unwritten range is guaranteed to be zero-initialized. -//! This is used by [`crate::tenant::ephemeral_file::zero_padded_read_write::RW::read_blk`] -//! to serve page-sized reads of the trailing page when the trailing page has only been partially filled. - -use std::mem::MaybeUninit; - -use crate::virtual_file::owned_buffers_io::io_buf_ext::FullSlice; - -/// See module-level comment. -pub struct Buffer { - allocation: Box<[u8; N]>, - written: usize, -} - -impl Default for Buffer { - fn default() -> Self { - Self { - allocation: Box::new( - // SAFETY: zeroed memory is a valid [u8; N] - unsafe { MaybeUninit::zeroed().assume_init() }, - ), - written: 0, - } - } -} - -impl Buffer { - #[inline(always)] - fn invariants(&self) { - // don't check by default, unoptimized is too expensive even for debug mode - if false { - debug_assert!(self.written <= N, "{}", self.written); - debug_assert!(self.allocation[self.written..N].iter().all(|v| *v == 0)); - } - } - - pub fn as_zero_padded_slice(&self) -> &[u8; N] { - &self.allocation - } -} - -impl crate::virtual_file::owned_buffers_io::write::Buffer for Buffer { - type IoBuf = Self; - - fn cap(&self) -> usize { - self.allocation.len() - } - - fn extend_from_slice(&mut self, other: &[u8]) { - self.invariants(); - let remaining = self.allocation.len() - self.written; - if other.len() > remaining { - panic!("calling extend_from_slice() with insufficient remaining capacity"); - } - self.allocation[self.written..(self.written + other.len())].copy_from_slice(other); - self.written += other.len(); - self.invariants(); - } - - fn pending(&self) -> usize { - self.written - } - - fn flush(self) -> FullSlice { - self.invariants(); - let written = self.written; - FullSlice::must_new(tokio_epoll_uring::BoundedBuf::slice(self, 0..written)) - } - - fn reuse_after_flush(iobuf: Self::IoBuf) -> Self { - let Self { - mut allocation, - written, - } = iobuf; - allocation[0..written].fill(0); - let new = Self { - allocation, - written: 0, - }; - new.invariants(); - new - } -} - -/// We have this trait impl so that the `flush` method in the `Buffer` impl above can produce a -/// [`tokio_epoll_uring::BoundedBuf::slice`] of the [`Self::written`] range of the data. -/// -/// Remember that bytes_init is generally _not_ a tracker of the amount -/// of valid data in the io buffer; we use `Slice` for that. -/// The `IoBuf` is _only_ for keeping track of uninitialized memory, a bit like MaybeUninit. -/// -/// SAFETY: -/// -/// The [`Self::allocation`] is stable becauses boxes are stable. -/// The memory is zero-initialized, so, bytes_init is always N. -unsafe impl tokio_epoll_uring::IoBuf for Buffer { - fn stable_ptr(&self) -> *const u8 { - self.allocation.as_ptr() - } - - fn bytes_init(&self) -> usize { - // Yes, N, not self.written; Read the full comment of this impl block! - N - } - - fn bytes_total(&self) -> usize { - N - } -} diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index f4a2957972..885eb13b29 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -40,7 +40,7 @@ use crate::tenant::storage_layer::layer::S3_UPLOAD_LIMIT; use crate::tenant::timeline::GetVectoredError; use crate::tenant::vectored_blob_io::{ BlobFlag, MaxVectoredReadBytes, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead, - VectoredReadPlanner, + VectoredReadCoalesceMode, VectoredReadPlanner, }; use crate::tenant::PageReconstructError; use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt}; @@ -65,7 +65,7 @@ use std::os::unix::fs::FileExt; use std::str::FromStr; use std::sync::Arc; use tokio::sync::OnceCell; -use tokio_epoll_uring::IoBufMut; +use tokio_epoll_uring::IoBuf; use tracing::*; use utils::{ @@ -471,7 +471,7 @@ impl DeltaLayerWriterInner { ctx: &RequestContext, ) -> (FullSlice, anyhow::Result<()>) where - Buf: IoBufMut + Send, + Buf: IoBuf + Send, { assert!( self.lsn_range.start <= lsn, @@ -678,7 +678,7 @@ impl DeltaLayerWriter { ctx: &RequestContext, ) -> (FullSlice, anyhow::Result<()>) where - Buf: IoBufMut + Send, + Buf: IoBuf + Send, { self.inner .as_mut() @@ -1205,6 +1205,7 @@ impl DeltaLayerInner { let mut prev: Option<(Key, Lsn, BlobRef)> = None; let mut read_builder: Option = None; + let read_mode = VectoredReadCoalesceMode::get(); let max_read_size = self .max_vectored_read_bytes @@ -1253,6 +1254,7 @@ impl DeltaLayerInner { offsets.end.pos(), meta, max_read_size, + read_mode, )) } } else { @@ -2295,7 +2297,7 @@ pub(crate) mod test { // every key should be a batch b/c the value is larger than max_read_size assert_eq!(iter.key_values_batch.len(), 1); } else { - assert_eq!(iter.key_values_batch.len(), batch_size); + assert!(iter.key_values_batch.len() <= batch_size); } if num_items >= N { break; diff --git a/pageserver/src/tenant/storage_layer/image_layer.rs b/pageserver/src/tenant/storage_layer/image_layer.rs index 3cb2b1c83a..4c22541e02 100644 --- a/pageserver/src/tenant/storage_layer/image_layer.rs +++ b/pageserver/src/tenant/storage_layer/image_layer.rs @@ -1381,7 +1381,7 @@ mod test { // every key should be a batch b/c the value is larger than max_read_size assert_eq!(iter.key_values_batch.len(), 1); } else { - assert_eq!(iter.key_values_batch.len(), batch_size); + assert!(iter.key_values_batch.len() <= batch_size); } if num_items >= N { break; diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer.rs b/pageserver/src/tenant/storage_layer/inmemory_layer.rs index a71b4dd83b..f31ab4b1e8 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer.rs @@ -4,23 +4,23 @@ //! held in an ephemeral file, not in memory. The metadata for each page version, i.e. //! its position in the file, is kept in memory, though. //! +use crate::assert_u64_eq_usize::{u64_to_usize, U64IsUsize, UsizeIsU64}; use crate::config::PageServerConf; use crate::context::{PageContentKind, RequestContext, RequestContextBuilder}; -use crate::page_cache::PAGE_SZ; use crate::repository::{Key, Value}; -use crate::tenant::block_io::{BlockCursor, BlockReader, BlockReaderRef}; use crate::tenant::ephemeral_file::EphemeralFile; use crate::tenant::timeline::GetVectoredError; use crate::tenant::PageReconstructError; use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt; use crate::{l0_flush, page_cache}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; +use bytes::Bytes; use camino::Utf8PathBuf; use pageserver_api::key::CompactKey; use pageserver_api::keyspace::KeySpace; use pageserver_api::models::InMemoryLayerInfo; use pageserver_api::shard::TenantShardId; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::sync::{Arc, OnceLock}; use std::time::Instant; use tracing::*; @@ -39,6 +39,8 @@ use super::{ DeltaLayerWriter, PersistentLayerDesc, ValueReconstructSituation, ValuesReconstructState, }; +pub(crate) mod vectored_dio_read; + #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub(crate) struct InMemoryLayerFileId(page_cache::FileId); @@ -78,9 +80,9 @@ impl std::fmt::Debug for InMemoryLayer { pub struct InMemoryLayerInner { /// All versions of all pages in the layer are kept here. Indexed - /// by block number and LSN. The value is an offset into the + /// by block number and LSN. The [`IndexEntry`] is an offset into the /// ephemeral file where the page version is stored. - index: BTreeMap>, + index: BTreeMap>, /// The values are stored in a serialized format in this file. /// Each serialized Value is preceded by a 'u32' length field. @@ -90,6 +92,154 @@ pub struct InMemoryLayerInner { resource_units: GlobalResourceUnits, } +/// Support the same max blob length as blob_io, because ultimately +/// all the InMemoryLayer contents end up being written into a delta layer, +/// using the [`crate::tenant::blob_io`]. +const MAX_SUPPORTED_BLOB_LEN: usize = crate::tenant::blob_io::MAX_SUPPORTED_BLOB_LEN; +const MAX_SUPPORTED_BLOB_LEN_BITS: usize = { + let trailing_ones = MAX_SUPPORTED_BLOB_LEN.trailing_ones() as usize; + let leading_zeroes = MAX_SUPPORTED_BLOB_LEN.leading_zeros() as usize; + assert!(trailing_ones + leading_zeroes == std::mem::size_of::() * 8); + trailing_ones +}; + +/// See [`InMemoryLayerInner::index`]. +/// +/// For memory efficiency, the data is packed into a u64. +/// +/// Layout: +/// - 1 bit: `will_init` +/// - [`MAX_SUPPORTED_BLOB_LEN_BITS`]: `len` +/// - [`MAX_SUPPORTED_POS_BITS`]: `pos` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct IndexEntry(u64); + +impl IndexEntry { + /// See [`Self::MAX_SUPPORTED_POS`]. + const MAX_SUPPORTED_POS_BITS: usize = { + let remainder = 64 - 1 - MAX_SUPPORTED_BLOB_LEN_BITS; + if remainder < 32 { + panic!("pos can be u32 as per type system, support that"); + } + remainder + }; + /// The maximum supported blob offset that can be represented by [`Self`]. + /// See also [`Self::validate_checkpoint_distance`]. + const MAX_SUPPORTED_POS: usize = (1 << Self::MAX_SUPPORTED_POS_BITS) - 1; + + // Layout + const WILL_INIT_RANGE: Range = 0..1; + const LEN_RANGE: Range = + Self::WILL_INIT_RANGE.end..Self::WILL_INIT_RANGE.end + MAX_SUPPORTED_BLOB_LEN_BITS; + const POS_RANGE: Range = + Self::LEN_RANGE.end..Self::LEN_RANGE.end + Self::MAX_SUPPORTED_POS_BITS; + const _ASSERT: () = { + if Self::POS_RANGE.end != 64 { + panic!("we don't want undefined bits for our own sanity") + } + }; + + /// Fails if and only if the offset or length encoded in `arg` is too large to be represented by [`Self`]. + /// + /// The only reason why that can happen in the system is if the [`InMemoryLayer`] grows too long. + /// The [`InMemoryLayer`] size is determined by the checkpoint distance, enforced by [`crate::tenant::Timeline::should_roll`]. + /// + /// Thus, to avoid failure of this function, whenever we start up and/or change checkpoint distance, + /// call [`Self::validate_checkpoint_distance`] with the new checkpoint distance value. + /// + /// TODO: this check should happen ideally at config parsing time (and in the request handler when a change to checkpoint distance is requested) + /// When cleaning this up, also look into the s3 max file size check that is performed in delta layer writer. + #[inline(always)] + fn new(arg: IndexEntryNewArgs) -> anyhow::Result { + let IndexEntryNewArgs { + base_offset, + batch_offset, + len, + will_init, + } = arg; + + let pos = base_offset + .checked_add(batch_offset) + .ok_or_else(|| anyhow::anyhow!("base_offset + batch_offset overflows u64: base_offset={base_offset} batch_offset={batch_offset}"))?; + + if pos.into_usize() > Self::MAX_SUPPORTED_POS { + anyhow::bail!( + "base_offset+batch_offset exceeds the maximum supported value: base_offset={base_offset} batch_offset={batch_offset} (+)={pos} max={max}", + max = Self::MAX_SUPPORTED_POS + ); + } + + if len > MAX_SUPPORTED_BLOB_LEN { + anyhow::bail!( + "len exceeds the maximum supported length: len={len} max={MAX_SUPPORTED_BLOB_LEN}", + ); + } + + let mut data: u64 = 0; + use bit_field::BitField; + data.set_bits(Self::WILL_INIT_RANGE, if will_init { 1 } else { 0 }); + data.set_bits(Self::LEN_RANGE, len.into_u64()); + data.set_bits(Self::POS_RANGE, pos); + + Ok(Self(data)) + } + + #[inline(always)] + fn unpack(&self) -> IndexEntryUnpacked { + use bit_field::BitField; + IndexEntryUnpacked { + will_init: self.0.get_bits(Self::WILL_INIT_RANGE) != 0, + len: self.0.get_bits(Self::LEN_RANGE), + pos: self.0.get_bits(Self::POS_RANGE), + } + } + + /// See [`Self::new`]. + pub(crate) const fn validate_checkpoint_distance( + checkpoint_distance: u64, + ) -> Result<(), &'static str> { + if checkpoint_distance > Self::MAX_SUPPORTED_POS as u64 { + return Err("exceeds the maximum supported value"); + } + let res = u64_to_usize(checkpoint_distance).checked_add(MAX_SUPPORTED_BLOB_LEN); + if res.is_none() { + return Err( + "checkpoint distance + max supported blob len overflows in-memory addition", + ); + } + + // NB: it is ok for the result of the addition to be larger than MAX_SUPPORTED_POS + + Ok(()) + } + + const _ASSERT_DEFAULT_CHECKPOINT_DISTANCE_IS_VALID: () = { + let res = Self::validate_checkpoint_distance( + crate::tenant::config::defaults::DEFAULT_CHECKPOINT_DISTANCE, + ); + if res.is_err() { + panic!("default checkpoint distance is valid") + } + }; +} + +/// Args to [`IndexEntry::new`]. +#[derive(Clone, Copy)] +struct IndexEntryNewArgs { + base_offset: u64, + batch_offset: u64, + len: usize, + will_init: bool, +} + +/// Unpacked representation of the bitfielded [`IndexEntry`]. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +struct IndexEntryUnpacked { + will_init: bool, + len: u64, + pos: u64, +} + impl std::fmt::Debug for InMemoryLayerInner { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("InMemoryLayerInner").finish() @@ -276,7 +426,12 @@ impl InMemoryLayer { .build(); let inner = self.inner.read().await; - let reader = inner.file.block_cursor(); + + struct ValueRead { + entry_lsn: Lsn, + read: vectored_dio_read::LogicalRead>, + } + let mut reads: HashMap> = HashMap::new(); for range in keyspace.ranges.iter() { for (key, vec_map) in inner @@ -291,24 +446,62 @@ impl InMemoryLayer { let slice = vec_map.slice_range(lsn_range); - for (entry_lsn, pos) in slice.iter().rev() { - // TODO: this uses the page cache => https://github.com/neondatabase/neon/issues/8183 - let buf = reader.read_blob(*pos, &ctx).await; - if let Err(e) = buf { - reconstruct_state.on_key_error(key, PageReconstructError::from(anyhow!(e))); + for (entry_lsn, index_entry) in slice.iter().rev() { + let IndexEntryUnpacked { + pos, + len, + will_init, + } = index_entry.unpack(); + reads.entry(key).or_default().push(ValueRead { + entry_lsn: *entry_lsn, + read: vectored_dio_read::LogicalRead::new( + pos, + Vec::with_capacity(len as usize), + ), + }); + if will_init { break; } + } + } + } - let value = Value::des(&buf.unwrap()); - if let Err(e) = value { + // Execute the reads. + + let f = vectored_dio_read::execute( + &inner.file, + reads + .iter() + .flat_map(|(_, value_reads)| value_reads.iter().map(|v| &v.read)), + &ctx, + ); + send_future::SendFuture::send(f) // https://github.com/rust-lang/rust/issues/96865 + .await; + + // Process results into the reconstruct state + 'next_key: for (key, value_reads) in reads { + for ValueRead { entry_lsn, read } in value_reads { + match read.into_result().expect("we run execute() above") { + Err(e) => { reconstruct_state.on_key_error(key, PageReconstructError::from(anyhow!(e))); - break; + continue 'next_key; } + Ok(value_buf) => { + let value = Value::des(&value_buf); + if let Err(e) = value { + reconstruct_state + .on_key_error(key, PageReconstructError::from(anyhow!(e))); + continue 'next_key; + } - let key_situation = - reconstruct_state.update_key(&key, *entry_lsn, value.unwrap()); - if key_situation == ValueReconstructSituation::Complete { - break; + let key_situation = + reconstruct_state.update_key(&key, entry_lsn, value.unwrap()); + if key_situation == ValueReconstructSituation::Complete { + // TODO: metric to see if we fetched more values than necessary + continue 'next_key; + } + + // process the next value in the next iteration of the loop } } } @@ -324,8 +517,9 @@ impl InMemoryLayer { struct SerializedBatchOffset { key: CompactKey, lsn: Lsn, - /// offset in bytes from the start of the batch's buffer to the Value's serialized size header. - offset: u64, + // TODO: separate type when we start serde-serializing this value, to avoid coupling + // in-memory representation to serialization format. + index_entry: IndexEntry, } pub struct SerializedBatch { @@ -340,30 +534,10 @@ pub struct SerializedBatch { } impl SerializedBatch { - /// Write a blob length in the internal format of the EphemeralFile - pub(crate) fn write_blob_length(len: usize, cursor: &mut std::io::Cursor>) { - use std::io::Write; - - if len < 0x80 { - // short one-byte length header - let len_buf = [len as u8]; - - cursor - .write_all(&len_buf) - .expect("Writing to Vec is infallible"); - } else { - let mut len_buf = u32::to_be_bytes(len as u32); - len_buf[0] |= 0x80; - cursor - .write_all(&len_buf) - .expect("Writing to Vec is infallible"); - } - } - - pub fn from_values(batch: Vec<(CompactKey, Lsn, usize, Value)>) -> Self { + pub fn from_values(batch: Vec<(CompactKey, Lsn, usize, Value)>) -> anyhow::Result { // Pre-allocate a big flat buffer to write into. This should be large but not huge: it is soft-limited in practice by // [`crate::pgdatadir_mapping::DatadirModification::MAX_PENDING_BYTES`] - let buffer_size = batch.iter().map(|i| i.2).sum::() + 4 * batch.len(); + let buffer_size = batch.iter().map(|i| i.2).sum::(); let mut cursor = std::io::Cursor::new(Vec::::with_capacity(buffer_size)); let mut offsets: Vec = Vec::with_capacity(batch.len()); @@ -371,14 +545,19 @@ impl SerializedBatch { for (key, lsn, val_ser_size, val) in batch { let relative_off = cursor.position(); - Self::write_blob_length(val_ser_size, &mut cursor); val.ser_into(&mut cursor) .expect("Writing into in-memory buffer is infallible"); offsets.push(SerializedBatchOffset { key, lsn, - offset: relative_off, + index_entry: IndexEntry::new(IndexEntryNewArgs { + base_offset: 0, + batch_offset: relative_off, + len: val_ser_size, + will_init: val.will_init(), + }) + .context("higher-level code ensures that values are within supported ranges")?, }); max_lsn = std::cmp::max(max_lsn, lsn); } @@ -388,11 +567,11 @@ impl SerializedBatch { // Assert that we didn't do any extra allocations while building buffer. debug_assert!(buffer.len() <= buffer_size); - Self { + Ok(Self { raw: buffer, offsets, max_lsn, - } + }) } } @@ -456,44 +635,69 @@ impl InMemoryLayer { }) } - // Write path. + /// Write path. + /// + /// Errors are not retryable, the [`InMemoryLayer`] must be discarded, and not be read from. + /// The reason why it's not retryable is that the [`EphemeralFile`] writes are not retryable. + /// TODO: it can be made retryable if we aborted the process on EphemeralFile write errors. pub async fn put_batch( &self, serialized_batch: SerializedBatch, ctx: &RequestContext, - ) -> Result<()> { + ) -> anyhow::Result<()> { let mut inner = self.inner.write().await; self.assert_writable(); - let base_off = { - inner - .file - .write_raw( - &serialized_batch.raw, - &RequestContextBuilder::extend(ctx) - .page_content_kind(PageContentKind::InMemoryLayer) - .build(), - ) - .await? - }; + let base_offset = inner.file.len(); + let SerializedBatch { + raw, + mut offsets, + max_lsn: _, + } = serialized_batch; + + // Add the base_offset to the batch's index entries which are relative to the batch start. + for offset in &mut offsets { + let IndexEntryUnpacked { + will_init, + len, + pos, + } = offset.index_entry.unpack(); + offset.index_entry = IndexEntry::new(IndexEntryNewArgs { + base_offset, + batch_offset: pos, + len: len.into_usize(), + will_init, + })?; + } + + // Write the batch to the file + inner.file.write_raw(&raw, ctx).await?; + let new_size = inner.file.len(); + let expected_new_len = base_offset + .checked_add(raw.len().into_u64()) + // write_raw would error if we were to overflow u64. + // also IndexEntry and higher levels in + //the code don't allow the file to grow that large + .unwrap(); + assert_eq!(new_size, expected_new_len); + + // Update the index with the new entries for SerializedBatchOffset { key, lsn, - offset: relative_off, - } in serialized_batch.offsets + index_entry, + } in offsets { - let off = base_off + relative_off; let vec_map = inner.index.entry(key).or_default(); - let old = vec_map.append_or_update_last(lsn, off).unwrap().0; + let old = vec_map.append_or_update_last(lsn, index_entry).unwrap().0; if old.is_some() { // We already had an entry for this LSN. That's odd.. warn!("Key {} at {} already exists", key, lsn); } } - let size = inner.file.len(); - inner.resource_units.maybe_publish_size(size); + inner.resource_units.maybe_publish_size(new_size); Ok(()) } @@ -537,7 +741,7 @@ impl InMemoryLayer { { let inner = self.inner.write().await; for vec_map in inner.index.values() { - for (lsn, _pos) in vec_map.as_slice() { + for (lsn, _) in vec_map.as_slice() { assert!(*lsn < end_lsn); } } @@ -601,36 +805,23 @@ impl InMemoryLayer { match l0_flush_global_state { l0_flush::Inner::Direct { .. } => { let file_contents: Vec = inner.file.load_to_vec(ctx).await?; - assert_eq!( - file_contents.len() % PAGE_SZ, - 0, - "needed by BlockReaderRef::Slice" - ); - assert_eq!(file_contents.len(), { - let written = usize::try_from(inner.file.len()).unwrap(); - if written % PAGE_SZ == 0 { - written - } else { - written.checked_add(PAGE_SZ - (written % PAGE_SZ)).unwrap() - } - }); - let cursor = BlockCursor::new(BlockReaderRef::Slice(&file_contents)); - - let mut buf = Vec::new(); + let file_contents = Bytes::from(file_contents); for (key, vec_map) in inner.index.iter() { // Write all page versions - for (lsn, pos) in vec_map.as_slice() { - // TODO: once we have blob lengths in the in-memory index, we can - // 1. get rid of the blob_io / BlockReaderRef::Slice business and - // 2. load the file contents into a Bytes and - // 3. the use `Bytes::slice` to get the `buf` that is our blob - // 4. pass that `buf` into `put_value_bytes` - // => https://github.com/neondatabase/neon/issues/8183 - cursor.read_blob_into_buf(*pos, &mut buf, ctx).await?; - let will_init = Value::des(&buf)?.will_init(); - let (tmp, res) = delta_layer_writer + for (lsn, entry) in vec_map + .as_slice() + .iter() + .map(|(lsn, entry)| (lsn, entry.unpack())) + { + let IndexEntryUnpacked { + pos, + len, + will_init, + } = entry; + let buf = Bytes::slice(&file_contents, pos as usize..(pos + len) as usize); + let (_buf, res) = delta_layer_writer .put_value_bytes( Key::from_compact(*key), *lsn, @@ -640,7 +831,6 @@ impl InMemoryLayer { ) .await; res?; - buf = tmp.into_raw_slice().into_inner(); } } } @@ -662,3 +852,134 @@ impl InMemoryLayer { Ok(Some((desc, path))) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_index_entry() { + const MAX_SUPPORTED_POS: usize = IndexEntry::MAX_SUPPORTED_POS; + use IndexEntryNewArgs as Args; + use IndexEntryUnpacked as Unpacked; + + let roundtrip = |args, expect: Unpacked| { + let res = IndexEntry::new(args).expect("this tests expects no errors"); + let IndexEntryUnpacked { + will_init, + len, + pos, + } = res.unpack(); + assert_eq!(will_init, expect.will_init); + assert_eq!(len, expect.len); + assert_eq!(pos, expect.pos); + }; + + // basic roundtrip + for pos in [0, MAX_SUPPORTED_POS] { + for len in [0, MAX_SUPPORTED_BLOB_LEN] { + for will_init in [true, false] { + let expect = Unpacked { + will_init, + len: len.into_u64(), + pos: pos.into_u64(), + }; + roundtrip( + Args { + will_init, + base_offset: pos.into_u64(), + batch_offset: 0, + len, + }, + expect, + ); + roundtrip( + Args { + will_init, + base_offset: 0, + batch_offset: pos.into_u64(), + len, + }, + expect, + ); + } + } + } + + // too-large len + let too_large = Args { + will_init: false, + len: MAX_SUPPORTED_BLOB_LEN + 1, + base_offset: 0, + batch_offset: 0, + }; + assert!(IndexEntry::new(too_large).is_err()); + + // too-large pos + { + let too_large = Args { + will_init: false, + len: 0, + base_offset: MAX_SUPPORTED_POS.into_u64() + 1, + batch_offset: 0, + }; + assert!(IndexEntry::new(too_large).is_err()); + let too_large = Args { + will_init: false, + len: 0, + base_offset: 0, + batch_offset: MAX_SUPPORTED_POS.into_u64() + 1, + }; + assert!(IndexEntry::new(too_large).is_err()); + } + + // too large (base_offset + batch_offset) + { + let too_large = Args { + will_init: false, + len: 0, + base_offset: MAX_SUPPORTED_POS.into_u64(), + batch_offset: 1, + }; + assert!(IndexEntry::new(too_large).is_err()); + let too_large = Args { + will_init: false, + len: 0, + base_offset: MAX_SUPPORTED_POS.into_u64() - 1, + batch_offset: MAX_SUPPORTED_POS.into_u64() - 1, + }; + assert!(IndexEntry::new(too_large).is_err()); + } + + // valid special cases + // - area past the max supported pos that is accessible by len + for len in [1, MAX_SUPPORTED_BLOB_LEN] { + roundtrip( + Args { + will_init: false, + len, + base_offset: MAX_SUPPORTED_POS.into_u64(), + batch_offset: 0, + }, + Unpacked { + will_init: false, + len: len as u64, + pos: MAX_SUPPORTED_POS.into_u64(), + }, + ); + roundtrip( + Args { + will_init: false, + len, + base_offset: 0, + batch_offset: MAX_SUPPORTED_POS.into_u64(), + }, + Unpacked { + will_init: false, + len: len as u64, + pos: MAX_SUPPORTED_POS.into_u64(), + }, + ); + } + } +} diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs new file mode 100644 index 0000000000..0683e15659 --- /dev/null +++ b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs @@ -0,0 +1,937 @@ +use std::{ + collections::BTreeMap, + sync::{Arc, RwLock}, +}; + +use itertools::Itertools; +use tokio_epoll_uring::{BoundedBuf, IoBufMut, Slice}; + +use crate::{ + assert_u64_eq_usize::{U64IsUsize, UsizeIsU64}, + context::RequestContext, +}; + +/// The file interface we require. At runtime, this is a [`crate::tenant::ephemeral_file::EphemeralFile`]. +pub trait File: Send { + /// Attempt to read the bytes in `self` in range `[start,start+dst.bytes_total())` + /// and return the number of bytes read (let's call it `nread`). + /// The bytes read are placed in `dst`, i.e., `&dst[..nread]` will contain the read bytes. + /// + /// The only reason why the read may be short (i.e., `nread != dst.bytes_total()`) + /// is if the file is shorter than `start+dst.len()`. + /// + /// This is unlike [`std::os::unix::fs::FileExt::read_exact_at`] which returns an + /// [`std::io::ErrorKind::UnexpectedEof`] error if the file is shorter than `start+dst.len()`. + /// + /// No guarantees are made about the remaining bytes in `dst` in case of a short read. + async fn read_exact_at_eof_ok<'a, 'b, B: IoBufMut + Send>( + &'b self, + start: u64, + dst: Slice, + ctx: &'a RequestContext, + ) -> std::io::Result<(Slice, usize)>; +} + +/// A logical read from [`File`]. See [`Self::new`]. +pub struct LogicalRead { + pos: u64, + state: RwLockRefCell>, +} + +enum LogicalReadState { + NotStarted(B), + Ongoing(B), + Ok(B), + Error(Arc), + Undefined, +} + +impl LogicalRead { + /// Create a new [`LogicalRead`] from [`File`] of the data in the file in range `[ pos, pos + buf.cap() )`. + pub fn new(pos: u64, buf: B) -> Self { + Self { + pos, + state: RwLockRefCell::new(LogicalReadState::NotStarted(buf)), + } + } + pub fn into_result(self) -> Option>> { + match self.state.into_inner() { + LogicalReadState::Ok(buf) => Some(Ok(buf)), + LogicalReadState::Error(e) => Some(Err(e)), + LogicalReadState::NotStarted(_) | LogicalReadState::Ongoing(_) => None, + LogicalReadState::Undefined => unreachable!(), + } + } +} + +/// The buffer into which a [`LogicalRead`] result is placed. +pub trait Buffer: std::ops::Deref { + /// Immutable. + fn cap(&self) -> usize; + /// Changes only through [`Self::extend_from_slice`]. + fn len(&self) -> usize; + /// Panics if the total length would exceed the initialized capacity. + fn extend_from_slice(&mut self, src: &[u8]); +} + +/// The minimum alignment and size requirement for disk offsets and memory buffer size for direct IO. +const DIO_CHUNK_SIZE: usize = 512; + +/// If multiple chunks need to be read, merge adjacent chunk reads into batches of max size `MAX_CHUNK_BATCH_SIZE`. +/// (The unit is the number of chunks.) +const MAX_CHUNK_BATCH_SIZE: usize = { + let desired = 128 * 1024; // 128k + if desired % DIO_CHUNK_SIZE != 0 { + panic!("MAX_CHUNK_BATCH_SIZE must be a multiple of DIO_CHUNK_SIZE") + // compile-time error + } + desired / DIO_CHUNK_SIZE +}; + +/// Execute the given logical `reads` against `file`. +/// The results are placed in the buffers of the [`LogicalRead`]s. +/// Retrieve the results by calling [`LogicalRead::into_result`] on each [`LogicalRead`]. +/// +/// The [`LogicalRead`]s must be freshly created using [`LogicalRead::new`] when calling this function. +/// Otherwise, this function panics. +pub async fn execute<'a, I, F, B>(file: &F, reads: I, ctx: &RequestContext) +where + I: IntoIterator>, + F: File, + B: Buffer + IoBufMut + Send, +{ + // Terminology: + // logical read = a request to read an arbitrary range of bytes from `file`; byte-level granularity + // chunk = we conceptually divide up the byte range of `file` into DIO_CHUNK_SIZEs ranges + // interest = a range within a chunk that a logical read is interested in; one logical read gets turned into many interests + // physical read = the read request we're going to issue to the OS; covers a range of chunks; chunk-level granularity + + // Preserve a copy of the logical reads for debug assertions at the end + #[cfg(debug_assertions)] + let (reads, assert_logical_reads) = { + let (reads, assert) = reads.into_iter().tee(); + (reads, Some(Vec::from_iter(assert))) + }; + #[cfg(not(debug_assertions))] + let (reads, assert_logical_reads): (_, Option>>) = (reads, None); + + // Plan which parts of which chunks need to be appended to which buffer + let mut by_chunk: BTreeMap>> = BTreeMap::new(); + struct Interest<'a, B: Buffer> { + logical_read: &'a LogicalRead, + offset_in_chunk: u64, + len: u64, + } + for logical_read in reads { + let LogicalRead { pos, state } = logical_read; + let mut state = state.borrow_mut(); + + // transition from NotStarted to Ongoing + let cur = std::mem::replace(&mut *state, LogicalReadState::Undefined); + let req_len = match cur { + LogicalReadState::NotStarted(buf) => { + if buf.len() != 0 { + panic!("The `LogicalRead`s that are passed in must be freshly created using `LogicalRead::new`"); + } + // buf.cap() == 0 is ok + + // transition into Ongoing state + let req_len = buf.cap(); + *state = LogicalReadState::Ongoing(buf); + req_len + } + x => panic!("must only call with fresh LogicalReads, got another state, leaving Undefined state behind state={x:?}"), + }; + + // plan which chunks we need to read from + let mut remaining = req_len; + let mut chunk_no = *pos / (DIO_CHUNK_SIZE.into_u64()); + let mut offset_in_chunk = pos.into_usize() % DIO_CHUNK_SIZE; + while remaining > 0 { + let remaining_in_chunk = std::cmp::min(remaining, DIO_CHUNK_SIZE - offset_in_chunk); + by_chunk.entry(chunk_no).or_default().push(Interest { + logical_read, + offset_in_chunk: offset_in_chunk.into_u64(), + len: remaining_in_chunk.into_u64(), + }); + offset_in_chunk = 0; + chunk_no += 1; + remaining -= remaining_in_chunk; + } + } + + // At this point, we could iterate over by_chunk, in chunk order, + // read each chunk from disk, and fill the buffers. + // However, we can merge adjacent chunks into batches of MAX_CHUNK_BATCH_SIZE + // so we issue fewer IOs = fewer roundtrips = lower overall latency. + struct PhysicalRead<'a, B: Buffer> { + start_chunk_no: u64, + nchunks: usize, + dsts: Vec>, + } + struct PhysicalInterest<'a, B: Buffer> { + logical_read: &'a LogicalRead, + offset_in_physical_read: u64, + len: u64, + } + let mut physical_reads: Vec> = Vec::new(); + let mut by_chunk = by_chunk.into_iter().peekable(); + loop { + let mut last_chunk_no = None; + let to_merge: Vec<(u64, Vec>)> = by_chunk + .peeking_take_while(|(chunk_no, _)| { + if let Some(last_chunk_no) = last_chunk_no { + if *chunk_no != last_chunk_no + 1 { + return false; + } + } + last_chunk_no = Some(*chunk_no); + true + }) + .take(MAX_CHUNK_BATCH_SIZE) + .collect(); // TODO: avoid this .collect() + let Some(start_chunk_no) = to_merge.first().map(|(chunk_no, _)| *chunk_no) else { + break; + }; + let nchunks = to_merge.len(); + let dsts = to_merge + .into_iter() + .enumerate() + .flat_map(|(i, (_, dsts))| { + dsts.into_iter().map( + move |Interest { + logical_read, + offset_in_chunk, + len, + }| { + PhysicalInterest { + logical_read, + offset_in_physical_read: i + .checked_mul(DIO_CHUNK_SIZE) + .unwrap() + .into_u64() + + offset_in_chunk, + len, + } + }, + ) + }) + .collect(); + physical_reads.push(PhysicalRead { + start_chunk_no, + nchunks, + dsts, + }); + } + drop(by_chunk); + + // Execute physical reads and fill the logical read buffers + // TODO: pipelined reads; prefetch; + let get_io_buffer = |nchunks| Vec::with_capacity(nchunks * DIO_CHUNK_SIZE); + for PhysicalRead { + start_chunk_no, + nchunks, + dsts, + } in physical_reads + { + let all_done = dsts + .iter() + .all(|PhysicalInterest { logical_read, .. }| logical_read.state.borrow().is_terminal()); + if all_done { + continue; + } + let read_offset = start_chunk_no + .checked_mul(DIO_CHUNK_SIZE.into_u64()) + .expect("we produce chunk_nos by dividing by DIO_CHUNK_SIZE earlier"); + let io_buf = get_io_buffer(nchunks).slice_full(); + let req_len = io_buf.len(); + let (io_buf_slice, nread) = match file.read_exact_at_eof_ok(read_offset, io_buf, ctx).await + { + Ok(t) => t, + Err(e) => { + let e = Arc::new(e); + for PhysicalInterest { logical_read, .. } in dsts { + *logical_read.state.borrow_mut() = LogicalReadState::Error(Arc::clone(&e)); + // this will make later reads for the given LogicalRead short-circuit, see top of loop body + } + continue; + } + }; + let io_buf = io_buf_slice.into_inner(); + assert!( + nread <= io_buf.len(), + "the last chunk in the file can be a short read, so, no ==" + ); + let io_buf = &io_buf[..nread]; + for PhysicalInterest { + logical_read, + offset_in_physical_read, + len, + } in dsts + { + let mut logical_read_state_borrow = logical_read.state.borrow_mut(); + let logical_read_buf = match &mut *logical_read_state_borrow { + LogicalReadState::NotStarted(_) => { + unreachable!("we transition it into Ongoing at function entry") + } + LogicalReadState::Ongoing(buf) => buf, + LogicalReadState::Ok(_) | LogicalReadState::Error(_) => { + continue; + } + LogicalReadState::Undefined => unreachable!(), + }; + let range_in_io_buf = std::ops::Range { + start: offset_in_physical_read as usize, + end: offset_in_physical_read as usize + len as usize, + }; + assert!(range_in_io_buf.end >= range_in_io_buf.start); + if range_in_io_buf.end > nread { + let msg = format!( + "physical read returned EOF where this logical read expected more data in the file: offset=0x{read_offset:x} req_len=0x{req_len:x} nread=0x{nread:x} {:?}", + &*logical_read_state_borrow + ); + logical_read_state_borrow.transition_to_terminal(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + msg, + ))); + continue; + } + let data = &io_buf[range_in_io_buf]; + + // Copy data from io buffer into the logical read buffer. + // (And in debug mode, validate that the buffer impl adheres to the Buffer trait spec.) + let pre = if cfg!(debug_assertions) { + Some((logical_read_buf.len(), logical_read_buf.cap())) + } else { + None + }; + logical_read_buf.extend_from_slice(data); + let post = if cfg!(debug_assertions) { + Some((logical_read_buf.len(), logical_read_buf.cap())) + } else { + None + }; + match (pre, post) { + (None, None) => {} + (Some(_), None) | (None, Some(_)) => unreachable!(), + (Some((pre_len, pre_cap)), Some((post_len, post_cap))) => { + assert_eq!(pre_len + len as usize, post_len); + assert_eq!(pre_cap, post_cap); + } + } + + if logical_read_buf.len() == logical_read_buf.cap() { + logical_read_state_borrow.transition_to_terminal(Ok(())); + } + } + } + + if let Some(assert_logical_reads) = assert_logical_reads { + for logical_read in assert_logical_reads { + assert!(logical_read.state.borrow().is_terminal()); + } + } +} + +impl LogicalReadState { + fn is_terminal(&self) -> bool { + match self { + LogicalReadState::NotStarted(_) | LogicalReadState::Ongoing(_) => false, + LogicalReadState::Ok(_) | LogicalReadState::Error(_) => true, + LogicalReadState::Undefined => unreachable!(), + } + } + fn transition_to_terminal(&mut self, err: std::io::Result<()>) { + let cur = std::mem::replace(self, LogicalReadState::Undefined); + let buf = match cur { + LogicalReadState::Ongoing(buf) => buf, + x => panic!("must only call in state Ongoing, got {x:?}"), + }; + *self = match err { + Ok(()) => LogicalReadState::Ok(buf), + Err(e) => LogicalReadState::Error(Arc::new(e)), + }; + } +} + +impl std::fmt::Debug for LogicalReadState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + #[derive(Debug)] + #[allow(unused)] + struct BufferDebug { + len: usize, + cap: usize, + } + impl<'a> From<&'a dyn Buffer> for BufferDebug { + fn from(buf: &'a dyn Buffer) -> Self { + Self { + len: buf.len(), + cap: buf.cap(), + } + } + } + match self { + LogicalReadState::NotStarted(b) => { + write!(f, "NotStarted({:?})", BufferDebug::from(b as &dyn Buffer)) + } + LogicalReadState::Ongoing(b) => { + write!(f, "Ongoing({:?})", BufferDebug::from(b as &dyn Buffer)) + } + LogicalReadState::Ok(b) => write!(f, "Ok({:?})", BufferDebug::from(b as &dyn Buffer)), + LogicalReadState::Error(e) => write!(f, "Error({:?})", e), + LogicalReadState::Undefined => write!(f, "Undefined"), + } + } +} + +#[derive(Debug)] +struct RwLockRefCell(RwLock); +impl RwLockRefCell { + fn new(value: T) -> Self { + Self(RwLock::new(value)) + } + fn borrow(&self) -> impl std::ops::Deref + '_ { + self.0.try_read().unwrap() + } + fn borrow_mut(&self) -> impl std::ops::DerefMut + '_ { + self.0.try_write().unwrap() + } + fn into_inner(self) -> T { + self.0.into_inner().unwrap() + } +} + +impl Buffer for Vec { + fn cap(&self) -> usize { + self.capacity() + } + + fn len(&self) -> usize { + self.len() + } + + fn extend_from_slice(&mut self, src: &[u8]) { + if self.len() + src.len() > self.cap() { + panic!("Buffer capacity exceeded"); + } + Vec::extend_from_slice(self, src); + } +} + +#[cfg(test)] +#[allow(clippy::assertions_on_constants)] +mod tests { + use rand::Rng; + + use crate::{ + context::DownloadBehavior, task_mgr::TaskKind, + virtual_file::owned_buffers_io::slice::SliceMutExt, + }; + + use super::*; + use std::{cell::RefCell, collections::VecDeque}; + + struct InMemoryFile { + content: Vec, + } + + impl InMemoryFile { + fn new_random(len: usize) -> Self { + Self { + content: rand::thread_rng() + .sample_iter(rand::distributions::Standard) + .take(len) + .collect(), + } + } + fn test_logical_read(&self, pos: u64, len: usize) -> TestLogicalRead { + let expected_result = if pos as usize + len > self.content.len() { + Err("InMemoryFile short read".to_string()) + } else { + Ok(self.content[pos as usize..pos as usize + len].to_vec()) + }; + TestLogicalRead::new(pos, len, expected_result) + } + } + + #[test] + fn test_in_memory_file() { + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + let file = InMemoryFile::new_random(10); + let test_read = |pos, len| { + let buf = vec![0; len]; + let fut = file.read_exact_at_eof_ok(pos, buf.slice_full(), &ctx); + use futures::FutureExt; + let (slice, nread) = fut + .now_or_never() + .expect("impl never awaits") + .expect("impl never errors"); + let mut buf = slice.into_inner(); + buf.truncate(nread); + buf + }; + assert_eq!(test_read(0, 1), &file.content[0..1]); + assert_eq!(test_read(1, 2), &file.content[1..3]); + assert_eq!(test_read(9, 2), &file.content[9..]); + assert!(test_read(10, 2).is_empty()); + assert!(test_read(11, 2).is_empty()); + } + + impl File for InMemoryFile { + async fn read_exact_at_eof_ok<'a, 'b, B: IoBufMut + Send>( + &'b self, + start: u64, + mut dst: Slice, + _ctx: &'a RequestContext, + ) -> std::io::Result<(Slice, usize)> { + let dst_slice: &mut [u8] = dst.as_mut_rust_slice_full_zeroed(); + let nread = { + let req_len = dst_slice.len(); + let len = std::cmp::min(req_len, self.content.len().saturating_sub(start as usize)); + if start as usize >= self.content.len() { + 0 + } else { + dst_slice[..len] + .copy_from_slice(&self.content[start as usize..start as usize + len]); + len + } + }; + rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[nread..]); // to discover bugs + Ok((dst, nread)) + } + } + + #[derive(Clone)] + struct TestLogicalRead { + pos: u64, + len: usize, + expected_result: Result, String>, + } + + impl TestLogicalRead { + fn new(pos: u64, len: usize, expected_result: Result, String>) -> Self { + Self { + pos, + len, + expected_result, + } + } + fn make_logical_read(&self) -> LogicalRead> { + LogicalRead::new(self.pos, Vec::with_capacity(self.len)) + } + } + + async fn execute_and_validate_test_logical_reads( + file: &F, + test_logical_reads: I, + ctx: &RequestContext, + ) where + I: IntoIterator, + F: File, + { + let (tmp, test_logical_reads) = test_logical_reads.into_iter().tee(); + let logical_reads = tmp.map(|tr| tr.make_logical_read()).collect::>(); + execute(file, logical_reads.iter(), ctx).await; + for (logical_read, test_logical_read) in logical_reads.into_iter().zip(test_logical_reads) { + let actual = logical_read.into_result().expect("we call execute()"); + match (actual, test_logical_read.expected_result) { + (Ok(actual), Ok(expected)) if actual == expected => {} + (Err(actual), Err(expected)) => { + assert_eq!(actual.to_string(), expected); + } + (actual, expected) => panic!("expected {expected:?}\nactual {actual:?}"), + } + } + } + + #[tokio::test] + async fn test_blackbox() { + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + let cs = DIO_CHUNK_SIZE; + let cs_u64 = cs.into_u64(); + + let file = InMemoryFile::new_random(10 * cs); + + let test_logical_reads = vec![ + file.test_logical_read(0, 1), + // adjacent to logical_read0 + file.test_logical_read(1, 2), + // gap + // spans adjacent chunks + file.test_logical_read(cs_u64 - 1, 2), + // gap + // tail of chunk 3, all of chunk 4, and 2 bytes of chunk 5 + file.test_logical_read(3 * cs_u64 - 1, cs + 2), + // gap + file.test_logical_read(5 * cs_u64, 1), + ]; + let num_test_logical_reads = test_logical_reads.len(); + let test_logical_reads_perms = test_logical_reads + .into_iter() + .permutations(num_test_logical_reads); + + // test all orderings of LogicalReads, the order shouldn't matter for the results + for test_logical_reads in test_logical_reads_perms { + execute_and_validate_test_logical_reads(&file, test_logical_reads, &ctx).await; + } + } + + #[tokio::test] + #[should_panic] + async fn test_reusing_logical_reads_panics() { + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + let file = InMemoryFile::new_random(DIO_CHUNK_SIZE); + let a = file.test_logical_read(23, 10); + let logical_reads = vec![a.make_logical_read()]; + execute(&file, &logical_reads, &ctx).await; + // reuse pancis + execute(&file, &logical_reads, &ctx).await; + } + + struct RecorderFile<'a> { + recorded: RefCell>, + file: &'a InMemoryFile, + } + + struct RecordedRead { + pos: u64, + req_len: usize, + res: Vec, + } + + impl<'a> RecorderFile<'a> { + fn new(file: &'a InMemoryFile) -> RecorderFile<'a> { + Self { + recorded: Default::default(), + file, + } + } + } + + impl<'x> File for RecorderFile<'x> { + async fn read_exact_at_eof_ok<'a, 'b, B: IoBufMut + Send>( + &'b self, + start: u64, + dst: Slice, + ctx: &'a RequestContext, + ) -> std::io::Result<(Slice, usize)> { + let (dst, nread) = self.file.read_exact_at_eof_ok(start, dst, ctx).await?; + self.recorded.borrow_mut().push(RecordedRead { + pos: start, + req_len: dst.bytes_total(), + res: Vec::from(&dst[..nread]), + }); + Ok((dst, nread)) + } + } + + #[tokio::test] + async fn test_logical_reads_to_same_chunk_are_merged_into_one_chunk_read() { + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + + let file = InMemoryFile::new_random(2 * DIO_CHUNK_SIZE); + + let a = file.test_logical_read(DIO_CHUNK_SIZE.into_u64(), 10); + let b = file.test_logical_read(DIO_CHUNK_SIZE.into_u64() + 30, 20); + + let recorder = RecorderFile::new(&file); + + execute_and_validate_test_logical_reads(&recorder, vec![a, b], &ctx).await; + + let recorded = recorder.recorded.borrow(); + assert_eq!(recorded.len(), 1); + let RecordedRead { pos, req_len, .. } = &recorded[0]; + assert_eq!(*pos, DIO_CHUNK_SIZE.into_u64()); + assert_eq!(*req_len, DIO_CHUNK_SIZE); + } + + #[tokio::test] + async fn test_max_chunk_batch_size_is_respected() { + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + + let file = InMemoryFile::new_random(4 * MAX_CHUNK_BATCH_SIZE * DIO_CHUNK_SIZE); + + // read the 10th byte of each chunk 3 .. 3+2*MAX_CHUNK_BATCH_SIZE + assert!(3 < MAX_CHUNK_BATCH_SIZE, "test assumption"); + assert!(10 < DIO_CHUNK_SIZE, "test assumption"); + let mut test_logical_reads = Vec::new(); + for i in 3..3 + MAX_CHUNK_BATCH_SIZE + MAX_CHUNK_BATCH_SIZE / 2 { + test_logical_reads + .push(file.test_logical_read(i.into_u64() * DIO_CHUNK_SIZE.into_u64() + 10, 1)); + } + + let recorder = RecorderFile::new(&file); + + execute_and_validate_test_logical_reads(&recorder, test_logical_reads, &ctx).await; + + let recorded = recorder.recorded.borrow(); + assert_eq!(recorded.len(), 2); + { + let RecordedRead { pos, req_len, .. } = &recorded[0]; + assert_eq!(*pos as usize, 3 * DIO_CHUNK_SIZE); + assert_eq!(*req_len, MAX_CHUNK_BATCH_SIZE * DIO_CHUNK_SIZE); + } + { + let RecordedRead { pos, req_len, .. } = &recorded[1]; + assert_eq!(*pos as usize, (3 + MAX_CHUNK_BATCH_SIZE) * DIO_CHUNK_SIZE); + assert_eq!(*req_len, MAX_CHUNK_BATCH_SIZE / 2 * DIO_CHUNK_SIZE); + } + } + + #[tokio::test] + async fn test_batch_breaks_if_chunk_is_not_interesting() { + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + + assert!(MAX_CHUNK_BATCH_SIZE > 10, "test assumption"); + let file = InMemoryFile::new_random(3 * DIO_CHUNK_SIZE); + + let a = file.test_logical_read(0, 1); // chunk 0 + let b = file.test_logical_read(2 * DIO_CHUNK_SIZE.into_u64(), 1); // chunk 2 + + let recorder = RecorderFile::new(&file); + + execute_and_validate_test_logical_reads(&recorder, vec![a, b], &ctx).await; + + let recorded = recorder.recorded.borrow(); + + assert_eq!(recorded.len(), 2); + { + let RecordedRead { pos, req_len, .. } = &recorded[0]; + assert_eq!(*pos, 0); + assert_eq!(*req_len, DIO_CHUNK_SIZE); + } + { + let RecordedRead { pos, req_len, .. } = &recorded[1]; + assert_eq!(*pos, 2 * DIO_CHUNK_SIZE.into_u64()); + assert_eq!(*req_len, DIO_CHUNK_SIZE); + } + } + + struct ExpectedRead { + expect_pos: u64, + expect_len: usize, + respond: Result, String>, + } + + struct MockFile { + expected: RefCell>, + } + + impl Drop for MockFile { + fn drop(&mut self) { + assert!( + self.expected.borrow().is_empty(), + "expected reads not satisfied" + ); + } + } + + macro_rules! mock_file { + ($($pos:expr , $len:expr => $respond:expr),* $(,)?) => {{ + MockFile { + expected: RefCell::new(VecDeque::from(vec![$(ExpectedRead { + expect_pos: $pos, + expect_len: $len, + respond: $respond, + }),*])), + } + }}; + } + + impl File for MockFile { + async fn read_exact_at_eof_ok<'a, 'b, B: IoBufMut + Send>( + &'b self, + start: u64, + mut dst: Slice, + _ctx: &'a RequestContext, + ) -> std::io::Result<(Slice, usize)> { + let ExpectedRead { + expect_pos, + expect_len, + respond, + } = self + .expected + .borrow_mut() + .pop_front() + .expect("unexpected read"); + assert_eq!(start, expect_pos); + assert_eq!(dst.bytes_total(), expect_len); + match respond { + Ok(mocked_bytes) => { + let len = std::cmp::min(dst.bytes_total(), mocked_bytes.len()); + let dst_slice: &mut [u8] = dst.as_mut_rust_slice_full_zeroed(); + dst_slice[..len].copy_from_slice(&mocked_bytes[..len]); + rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[len..]); // to discover bugs + Ok((dst, len)) + } + Err(e) => Err(std::io::Error::new(std::io::ErrorKind::Other, e)), + } + } + } + + #[tokio::test] + async fn test_mock_file() { + // Self-test to ensure the relevant features of mock file work as expected. + + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + + let mock_file = mock_file! { + 0 , 512 => Ok(vec![0; 512]), + 512 , 512 => Ok(vec![1; 512]), + 1024 , 512 => Ok(vec![2; 10]), + 2048, 1024 => Err("foo".to_owned()), + }; + + let buf = Vec::with_capacity(512); + let (buf, nread) = mock_file + .read_exact_at_eof_ok(0, buf.slice_full(), &ctx) + .await + .unwrap(); + assert_eq!(nread, 512); + assert_eq!(&buf.into_inner()[..nread], &[0; 512]); + + let buf = Vec::with_capacity(512); + let (buf, nread) = mock_file + .read_exact_at_eof_ok(512, buf.slice_full(), &ctx) + .await + .unwrap(); + assert_eq!(nread, 512); + assert_eq!(&buf.into_inner()[..nread], &[1; 512]); + + let buf = Vec::with_capacity(512); + let (buf, nread) = mock_file + .read_exact_at_eof_ok(1024, buf.slice_full(), &ctx) + .await + .unwrap(); + assert_eq!(nread, 10); + assert_eq!(&buf.into_inner()[..nread], &[2; 10]); + + let buf = Vec::with_capacity(1024); + let err = mock_file + .read_exact_at_eof_ok(2048, buf.slice_full(), &ctx) + .await + .err() + .unwrap(); + assert_eq!(err.to_string(), "foo"); + } + + #[tokio::test] + async fn test_error_on_one_chunk_read_fails_only_dependent_logical_reads() { + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + + let test_logical_reads = vec![ + // read spanning two batches + TestLogicalRead::new( + DIO_CHUNK_SIZE.into_u64() / 2, + MAX_CHUNK_BATCH_SIZE * DIO_CHUNK_SIZE, + Err("foo".to_owned()), + ), + // second read in failing chunk + TestLogicalRead::new( + (MAX_CHUNK_BATCH_SIZE * DIO_CHUNK_SIZE).into_u64() + DIO_CHUNK_SIZE.into_u64() - 10, + 5, + Err("foo".to_owned()), + ), + // read unaffected + TestLogicalRead::new( + (MAX_CHUNK_BATCH_SIZE * DIO_CHUNK_SIZE).into_u64() + + 2 * DIO_CHUNK_SIZE.into_u64() + + 10, + 5, + Ok(vec![1; 5]), + ), + ]; + let (tmp, test_logical_reads) = test_logical_reads.into_iter().tee(); + let test_logical_read_perms = tmp.permutations(test_logical_reads.len()); + + for test_logical_reads in test_logical_read_perms { + let file = mock_file!( + 0, MAX_CHUNK_BATCH_SIZE*DIO_CHUNK_SIZE => Ok(vec![0; MAX_CHUNK_BATCH_SIZE*DIO_CHUNK_SIZE]), + (MAX_CHUNK_BATCH_SIZE*DIO_CHUNK_SIZE).into_u64(), DIO_CHUNK_SIZE => Err("foo".to_owned()), + (MAX_CHUNK_BATCH_SIZE*DIO_CHUNK_SIZE + 2*DIO_CHUNK_SIZE).into_u64(), DIO_CHUNK_SIZE => Ok(vec![1; DIO_CHUNK_SIZE]), + ); + execute_and_validate_test_logical_reads(&file, test_logical_reads, &ctx).await; + } + } + + struct TestShortReadsSetup { + ctx: RequestContext, + file: InMemoryFile, + written: u64, + } + fn setup_short_chunk_read_tests() -> TestShortReadsSetup { + let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); + assert!(DIO_CHUNK_SIZE > 20, "test assumption"); + let written = (2 * DIO_CHUNK_SIZE - 10).into_u64(); + let file = InMemoryFile::new_random(written as usize); + TestShortReadsSetup { ctx, file, written } + } + + #[tokio::test] + async fn test_short_chunk_read_from_written_range() { + // Test what happens if there are logical reads + // that start within the last chunk, and + // the last chunk is not the full chunk length. + // + // The read should succeed despite the short chunk length. + let TestShortReadsSetup { ctx, file, written } = setup_short_chunk_read_tests(); + + let a = file.test_logical_read(written - 10, 5); + let recorder = RecorderFile::new(&file); + + execute_and_validate_test_logical_reads(&recorder, vec![a], &ctx).await; + + let recorded = recorder.recorded.borrow(); + assert_eq!(recorded.len(), 1); + let RecordedRead { pos, req_len, res } = &recorded[0]; + assert_eq!(*pos, DIO_CHUNK_SIZE.into_u64()); + assert_eq!(*req_len, DIO_CHUNK_SIZE); + assert_eq!(res, &file.content[DIO_CHUNK_SIZE..(written as usize)]); + } + + #[tokio::test] + async fn test_short_chunk_read_and_logical_read_from_unwritten_range() { + // Test what happens if there are logical reads + // that start within the last chunk, and + // the last chunk is not the full chunk length, and + // the logical reads end in the unwritten range. + // + // All should fail with UnexpectedEof and have the same IO pattern. + async fn the_impl(offset_delta: i64) { + let TestShortReadsSetup { ctx, file, written } = setup_short_chunk_read_tests(); + + let offset = u64::try_from( + i64::try_from(written) + .unwrap() + .checked_add(offset_delta) + .unwrap(), + ) + .unwrap(); + let a = file.test_logical_read(offset, 5); + let recorder = RecorderFile::new(&file); + let a_vr = a.make_logical_read(); + execute(&recorder, vec![&a_vr], &ctx).await; + + // validate the LogicalRead result + let a_res = a_vr.into_result().unwrap(); + let a_err = a_res.unwrap_err(); + assert_eq!(a_err.kind(), std::io::ErrorKind::UnexpectedEof); + + // validate the IO pattern + let recorded = recorder.recorded.borrow(); + assert_eq!(recorded.len(), 1); + let RecordedRead { pos, req_len, res } = &recorded[0]; + assert_eq!(*pos, DIO_CHUNK_SIZE.into_u64()); + assert_eq!(*req_len, DIO_CHUNK_SIZE); + assert_eq!(res, &file.content[DIO_CHUNK_SIZE..(written as usize)]); + } + + the_impl(-1).await; // start == length - 1 + the_impl(0).await; // start == length + the_impl(1).await; // start == length + 1 + } + + // TODO: mixed: some valid, some UnexpectedEof + + // TODO: same tests but with merges +} diff --git a/pageserver/src/tenant/storage_layer/layer.rs b/pageserver/src/tenant/storage_layer/layer.rs index 53bb66b95e..86a200ce28 100644 --- a/pageserver/src/tenant/storage_layer/layer.rs +++ b/pageserver/src/tenant/storage_layer/layer.rs @@ -1494,8 +1494,9 @@ impl LayerInner { let duration = SystemTime::now().duration_since(local_layer_mtime); match duration { Ok(elapsed) => { - let accessed = self.access_stats.accessed(); - if accessed { + let accessed_and_visible = self.access_stats.accessed() + && self.access_stats.visibility() == LayerVisibilityHint::Visible; + if accessed_and_visible { // Only layers used for reads contribute to our "low residence" metric that is used // to detect thrashing. Layers promoted for other reasons (e.g. compaction) are allowed // to be rapidly evicted without contributing to this metric. @@ -1509,7 +1510,7 @@ impl LayerInner { tracing::info!( residence_millis = elapsed.as_millis(), - accessed, + accessed_and_visible, "evicted layer after known residence period" ); } diff --git a/pageserver/src/tenant/tasks.rs b/pageserver/src/tenant/tasks.rs index 12f080f3c1..f5680ced90 100644 --- a/pageserver/src/tenant/tasks.rs +++ b/pageserver/src/tenant/tasks.rs @@ -192,20 +192,28 @@ async fn compaction_loop(tenant: Arc, cancel: CancellationToken) { } } - let started_at = Instant::now(); - let sleep_duration = if period == Duration::ZERO { + + let sleep_duration; + if period == Duration::ZERO { #[cfg(not(feature = "testing"))] info!("automatic compaction is disabled"); // check again in 10 seconds, in case it's been enabled again. - Duration::from_secs(10) + sleep_duration = Duration::from_secs(10) } else { + let iteration = Iteration { + started_at: Instant::now(), + period, + kind: BackgroundLoopKind::Compaction, + }; + // Run compaction - match tenant.compaction_iteration(&cancel, &ctx).await { + let IterationResult { output, elapsed } = iteration.run(tenant.compaction_iteration(&cancel, &ctx)).await; + match output { Ok(has_pending_task) => { error_run_count = 0; // schedule the next compaction immediately in case there is a pending compaction task - if has_pending_task { Duration::ZERO } else { period } + sleep_duration = if has_pending_task { Duration::ZERO } else { period }; } Err(e) => { let wait_duration = backoff::exponential_backoff_duration_seconds( @@ -221,16 +229,14 @@ async fn compaction_loop(tenant: Arc, cancel: CancellationToken) { &wait_duration, cancel.is_cancelled(), ); - wait_duration + sleep_duration = wait_duration; } } + + // the duration is recorded by performance tests by enabling debug in this function + tracing::debug!(elapsed_ms=elapsed.as_millis(), "compaction iteration complete"); }; - let elapsed = started_at.elapsed(); - warn_when_period_overrun(elapsed, period, BackgroundLoopKind::Compaction); - - // the duration is recorded by performance tests by enabling debug in this function - tracing::debug!(elapsed_ms=elapsed.as_millis(), "compaction iteration complete"); // Perhaps we did no work and the walredo process has been idle for some time: // give it a chance to shut down to avoid leaving walredo process running indefinitely. @@ -368,23 +374,27 @@ async fn gc_loop(tenant: Arc, cancel: CancellationToken) { } } - let started_at = Instant::now(); - let gc_horizon = tenant.get_gc_horizon(); - let sleep_duration = if period == Duration::ZERO || gc_horizon == 0 { + let sleep_duration; + if period == Duration::ZERO || gc_horizon == 0 { #[cfg(not(feature = "testing"))] info!("automatic GC is disabled"); // check again in 10 seconds, in case it's been enabled again. - Duration::from_secs(10) + sleep_duration = Duration::from_secs(10); } else { + let iteration = Iteration { + started_at: Instant::now(), + period, + kind: BackgroundLoopKind::Gc, + }; // Run gc - let res = tenant - .gc_iteration(None, gc_horizon, tenant.get_pitr_interval(), &cancel, &ctx) + let IterationResult { output, elapsed: _ } = + iteration.run(tenant.gc_iteration(None, gc_horizon, tenant.get_pitr_interval(), &cancel, &ctx)) .await; - match res { + match output { Ok(_) => { error_run_count = 0; - period + sleep_duration = period; } Err(crate::tenant::GcError::TenantCancelled) => { return; @@ -408,13 +418,11 @@ async fn gc_loop(tenant: Arc, cancel: CancellationToken) { error!("Gc failed {error_run_count} times, retrying in {wait_duration:?}: {e:?}"); } - wait_duration + sleep_duration = wait_duration; } } }; - warn_when_period_overrun(started_at.elapsed(), period, BackgroundLoopKind::Gc); - if tokio::time::timeout(sleep_duration, cancel.cancelled()) .await .is_ok() @@ -468,14 +476,12 @@ async fn ingest_housekeeping_loop(tenant: Arc, cancel: CancellationToken break; } - let started_at = Instant::now(); - tenant.ingest_housekeeping().await; - - warn_when_period_overrun( - started_at.elapsed(), + let iteration = Iteration { + started_at: Instant::now(), period, - BackgroundLoopKind::IngestHouseKeeping, - ); + kind: BackgroundLoopKind::IngestHouseKeeping, + }; + iteration.run(tenant.ingest_housekeeping()).await; } } .await; @@ -553,6 +559,54 @@ pub(crate) async fn delay_by_lease_length( } } +struct Iteration { + started_at: Instant, + period: Duration, + kind: BackgroundLoopKind, +} + +struct IterationResult { + output: O, + elapsed: Duration, +} + +impl Iteration { + #[instrument(skip_all)] + pub(crate) async fn run(self, fut: Fut) -> IterationResult + where + Fut: std::future::Future, + { + let Self { + started_at, + period, + kind, + } = self; + + let mut fut = std::pin::pin!(fut); + + // Wrap `fut` into a future that logs a message every `period` so that we get a + // very obvious breadcrumb in the logs _while_ a slow iteration is happening. + let liveness_logger = async move { + loop { + match tokio::time::timeout(period, &mut fut).await { + Ok(x) => return x, + Err(_) => { + // info level as per the same rationale why warn_when_period_overrun is info + // => https://github.com/neondatabase/neon/pull/5724 + info!("still running"); + } + } + } + }; + + let output = liveness_logger.await; + + let elapsed = started_at.elapsed(); + warn_when_period_overrun(elapsed, period, kind); + + IterationResult { output, elapsed } + } +} /// Attention: the `task` and `period` beocme labels of a pageserver-wide prometheus metric. pub(crate) fn warn_when_period_overrun( elapsed: Duration, diff --git a/pageserver/src/tenant/throttle.rs b/pageserver/src/tenant/throttle.rs index f3f3d5e3ae..f222e708e1 100644 --- a/pageserver/src/tenant/throttle.rs +++ b/pageserver/src/tenant/throttle.rs @@ -10,6 +10,7 @@ use std::{ use arc_swap::ArcSwap; use enumset::EnumSet; use tracing::{error, warn}; +use utils::leaky_bucket::{LeakyBucketConfig, RateLimiter}; use crate::{context::RequestContext, task_mgr::TaskKind}; @@ -33,8 +34,7 @@ pub struct Throttle { pub struct Inner { task_kinds: EnumSet, - rate_limiter: Arc, - config: Config, + rate_limiter: Arc, } pub type Config = pageserver_api::models::ThrottleConfig; @@ -77,8 +77,7 @@ where refill_interval, refill_amount, max, - fair, - } = &config; + } = config; let task_kinds: EnumSet = task_kinds .iter() .filter_map(|s| match TaskKind::from_str(s) { @@ -93,18 +92,21 @@ where } }) .collect(); + + // steady rate, we expect `refill_amount` requests per `refill_interval`. + // dividing gives us the rps. + let rps = f64::from(refill_amount.get()) / refill_interval.as_secs_f64(); + let config = LeakyBucketConfig::new(rps, f64::from(max)); + + // initial tracks how many tokens are available to put in the bucket + // we want how many tokens are currently in the bucket + let initial_tokens = max - initial; + + let rate_limiter = RateLimiter::with_initial_tokens(config, f64::from(initial_tokens)); + Inner { task_kinds, - rate_limiter: Arc::new( - leaky_bucket::RateLimiter::builder() - .initial(*initial) - .interval(*refill_interval) - .refill(refill_amount.get()) - .max(*max) - .fair(*fair) - .build(), - ), - config, + rate_limiter: Arc::new(rate_limiter), } } pub fn reconfigure(&self, config: Config) { @@ -127,7 +129,7 @@ where /// See [`Config::steady_rps`]. pub fn steady_rps(&self) -> f64 { - self.inner.load().config.steady_rps() + self.inner.load().rate_limiter.steady_rps() } pub async fn throttle(&self, ctx: &RequestContext, key_count: usize) -> Option { @@ -136,18 +138,9 @@ where return None; }; let start = std::time::Instant::now(); - let mut did_throttle = false; - let acquire = inner.rate_limiter.acquire(key_count); - // turn off runtime-induced preemption (aka coop) so our `did_throttle` is accurate - let acquire = tokio::task::unconstrained(acquire); - let mut acquire = std::pin::pin!(acquire); - std::future::poll_fn(|cx| { - use std::future::Future; - let poll = acquire.as_mut().poll(cx); - did_throttle = did_throttle || poll.is_pending(); - poll - }) - .await; + + let did_throttle = inner.rate_limiter.acquire(key_count).await; + self.count_accounted.fetch_add(1, Ordering::Relaxed); if did_throttle { self.count_throttled.fetch_add(1, Ordering::Relaxed); diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 098c196ee8..35e0825bac 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -69,7 +69,7 @@ use crate::{ config::defaults::DEFAULT_PITR_INTERVAL, layer_map::{LayerMap, SearchResult}, metadata::TimelineMetadata, - storage_layer::PersistentLayerDesc, + storage_layer::{inmemory_layer::IndexEntry, PersistentLayerDesc}, }, walredo, }; @@ -218,7 +218,7 @@ pub(crate) struct RelSizeCache { } pub struct Timeline { - conf: &'static PageServerConf, + pub(crate) conf: &'static PageServerConf, tenant_conf: Arc>, myself: Weak, @@ -867,6 +867,11 @@ impl Timeline { .map(|ancestor| ancestor.timeline_id) } + /// Get the ancestor timeline + pub(crate) fn ancestor_timeline(&self) -> Option<&Arc> { + self.ancestor_timeline.as_ref() + } + /// Get the bytes written since the PITR cutoff on this branch, and /// whether this branch's ancestor_lsn is within its parent's PITR. pub(crate) fn get_pitr_history_stats(&self) -> (u64, bool) { @@ -1907,6 +1912,8 @@ impl Timeline { true } else if projected_layer_size >= checkpoint_distance { + // NB: this check is relied upon by: + let _ = IndexEntry::validate_checkpoint_distance; info!( "Will roll layer at {} with layer size {} due to layer size ({})", projected_lsn, layer_size, projected_layer_size @@ -5702,7 +5709,7 @@ impl<'a> TimelineWriter<'a> { return Ok(()); } - let serialized_batch = inmemory_layer::SerializedBatch::from_values(batch); + let serialized_batch = inmemory_layer::SerializedBatch::from_values(batch)?; let batch_max_lsn = serialized_batch.max_lsn; let buf_size: u64 = serialized_batch.raw.len() as u64; @@ -5739,6 +5746,12 @@ impl<'a> TimelineWriter<'a> { ctx: &RequestContext, ) -> anyhow::Result<()> { use utils::bin_ser::BeSer; + if !key.is_valid_key_on_write_path() { + bail!( + "the request contains data not supported by pageserver at TimelineWriter::put: {}", + key + ); + } let val_ser_size = value.serialized_size().unwrap() as usize; self.put_batch( vec![(key.to_compact(), lsn, val_ser_size, value.clone())], diff --git a/pageserver/src/tenant/vectored_blob_io.rs b/pageserver/src/tenant/vectored_blob_io.rs index 54a3ad789b..146bcf0e35 100644 --- a/pageserver/src/tenant/vectored_blob_io.rs +++ b/pageserver/src/tenant/vectored_blob_io.rs @@ -27,7 +27,7 @@ use utils::vec_map::VecMap; use crate::context::RequestContext; use crate::tenant::blob_io::{BYTE_UNCOMPRESSED, BYTE_ZSTD, LEN_COMPRESSION_BIT_MASK}; -use crate::virtual_file::VirtualFile; +use crate::virtual_file::{self, VirtualFile}; #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct MaxVectoredReadBytes(pub NonZeroUsize); @@ -60,7 +60,7 @@ pub struct VectoredBlobsBuf { pub struct VectoredRead { pub start: u64, pub end: u64, - /// Starting offsets and metadata for each blob in this read + /// Start offset and metadata for each blob in this read pub blobs_at: VecMap, } @@ -76,14 +76,109 @@ pub(crate) enum VectoredReadExtended { No, } -pub(crate) struct VectoredReadBuilder { +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum VectoredReadCoalesceMode { + /// Only coalesce exactly adjacent reads. + AdjacentOnly, + /// In addition to adjacent reads, also consider reads whose corresponding + /// `end` and `start` offsets reside at the same chunk. + Chunked(usize), +} + +impl VectoredReadCoalesceMode { + /// [`AdjacentVectoredReadBuilder`] is used if alignment requirement is 0, + /// whereas [`ChunkedVectoredReadBuilder`] is used for alignment requirement 1 and higher. + pub(crate) fn get() -> Self { + let align = virtual_file::get_io_buffer_alignment_raw(); + if align == 0 { + VectoredReadCoalesceMode::AdjacentOnly + } else { + VectoredReadCoalesceMode::Chunked(align) + } + } +} + +pub(crate) enum VectoredReadBuilder { + Adjacent(AdjacentVectoredReadBuilder), + Chunked(ChunkedVectoredReadBuilder), +} + +impl VectoredReadBuilder { + fn new_impl( + start_offset: u64, + end_offset: u64, + meta: BlobMeta, + max_read_size: Option, + mode: VectoredReadCoalesceMode, + ) -> Self { + match mode { + VectoredReadCoalesceMode::AdjacentOnly => Self::Adjacent( + AdjacentVectoredReadBuilder::new(start_offset, end_offset, meta, max_read_size), + ), + VectoredReadCoalesceMode::Chunked(chunk_size) => { + Self::Chunked(ChunkedVectoredReadBuilder::new( + start_offset, + end_offset, + meta, + max_read_size, + chunk_size, + )) + } + } + } + + pub(crate) fn new( + start_offset: u64, + end_offset: u64, + meta: BlobMeta, + max_read_size: usize, + mode: VectoredReadCoalesceMode, + ) -> Self { + Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), mode) + } + + pub(crate) fn new_streaming( + start_offset: u64, + end_offset: u64, + meta: BlobMeta, + mode: VectoredReadCoalesceMode, + ) -> Self { + Self::new_impl(start_offset, end_offset, meta, None, mode) + } + + pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended { + match self { + VectoredReadBuilder::Adjacent(builder) => builder.extend(start, end, meta), + VectoredReadBuilder::Chunked(builder) => builder.extend(start, end, meta), + } + } + + pub(crate) fn build(self) -> VectoredRead { + match self { + VectoredReadBuilder::Adjacent(builder) => builder.build(), + VectoredReadBuilder::Chunked(builder) => builder.build(), + } + } + + pub(crate) fn size(&self) -> usize { + match self { + VectoredReadBuilder::Adjacent(builder) => builder.size(), + VectoredReadBuilder::Chunked(builder) => builder.size(), + } + } +} + +pub(crate) struct AdjacentVectoredReadBuilder { + /// Start offset of the read. start: u64, + // End offset of the read. end: u64, + /// Start offset and metadata for each blob in this read blobs_at: VecMap, max_read_size: Option, } -impl VectoredReadBuilder { +impl AdjacentVectoredReadBuilder { /// Start building a new vectored read. /// /// Note that by design, this does not check against reading more than `max_read_size` to @@ -93,7 +188,7 @@ impl VectoredReadBuilder { start_offset: u64, end_offset: u64, meta: BlobMeta, - max_read_size: usize, + max_read_size: Option, ) -> Self { let mut blobs_at = VecMap::default(); blobs_at @@ -104,7 +199,7 @@ impl VectoredReadBuilder { start: start_offset, end: end_offset, blobs_at, - max_read_size: Some(max_read_size), + max_read_size, } } /// Attempt to extend the current read with a new blob if the start @@ -113,13 +208,15 @@ impl VectoredReadBuilder { pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended { tracing::trace!(start, end, "trying to extend"); let size = (end - start) as usize; - if self.end == start && { + let not_limited_by_max_read_size = { if let Some(max_read_size) = self.max_read_size { self.size() + size <= max_read_size } else { true } - } { + }; + + if self.end == start && not_limited_by_max_read_size { self.end = end; self.blobs_at .append(start, meta) @@ -144,6 +241,107 @@ impl VectoredReadBuilder { } } +pub(crate) struct ChunkedVectoredReadBuilder { + /// Start block number + start_blk_no: usize, + /// End block number (exclusive). + end_blk_no: usize, + /// Start offset and metadata for each blob in this read + blobs_at: VecMap, + max_read_size: Option, + /// Chunk size reads are coalesced into. + chunk_size: usize, +} + +/// Computes x / d rounded up. +fn div_round_up(x: usize, d: usize) -> usize { + (x + (d - 1)) / d +} + +impl ChunkedVectoredReadBuilder { + /// Start building a new vectored read. + /// + /// Note that by design, this does not check against reading more than `max_read_size` to + /// support reading larger blobs than the configuration value. The builder will be single use + /// however after that. + pub(crate) fn new( + start_offset: u64, + end_offset: u64, + meta: BlobMeta, + max_read_size: Option, + chunk_size: usize, + ) -> Self { + let mut blobs_at = VecMap::default(); + blobs_at + .append(start_offset, meta) + .expect("First insertion always succeeds"); + + let start_blk_no = start_offset as usize / chunk_size; + let end_blk_no = div_round_up(end_offset as usize, chunk_size); + Self { + start_blk_no, + end_blk_no, + blobs_at, + max_read_size, + chunk_size, + } + } + + /// Attempts to extend the current read with a new blob if the new blob resides in the same or the immediate next chunk. + /// + /// The resulting size also must be below the max read size. + pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended { + tracing::trace!(start, end, "trying to extend"); + let start_blk_no = start as usize / self.chunk_size; + let end_blk_no = div_round_up(end as usize, self.chunk_size); + + let not_limited_by_max_read_size = { + if let Some(max_read_size) = self.max_read_size { + let coalesced_size = (end_blk_no - self.start_blk_no) * self.chunk_size; + coalesced_size <= max_read_size + } else { + true + } + }; + + // True if the second block starts in the same block or the immediate next block where the first block ended. + // + // Note: This automatically handles the case where two blocks are adjacent to each other, + // whether they starts on chunk size boundary or not. + let is_adjacent_chunk_read = { + // 1. first.end & second.start are in the same block + self.end_blk_no == start_blk_no + 1 || + // 2. first.end ends one block before second.start + self.end_blk_no == start_blk_no + }; + + if is_adjacent_chunk_read && not_limited_by_max_read_size { + self.end_blk_no = end_blk_no; + self.blobs_at + .append(start, meta) + .expect("LSNs are ordered within vectored reads"); + + return VectoredReadExtended::Yes; + } + + VectoredReadExtended::No + } + + pub(crate) fn size(&self) -> usize { + (self.end_blk_no - self.start_blk_no) * self.chunk_size + } + + pub(crate) fn build(self) -> VectoredRead { + let start = (self.start_blk_no * self.chunk_size) as u64; + let end = (self.end_blk_no * self.chunk_size) as u64; + VectoredRead { + start, + end, + blobs_at: self.blobs_at, + } + } +} + #[derive(Copy, Clone, Debug)] pub enum BlobFlag { None, @@ -166,14 +364,18 @@ pub struct VectoredReadPlanner { prev: Option<(Key, Lsn, u64, BlobFlag)>, max_read_size: usize, + + mode: VectoredReadCoalesceMode, } impl VectoredReadPlanner { pub fn new(max_read_size: usize) -> Self { + let mode = VectoredReadCoalesceMode::get(); Self { blobs: BTreeMap::new(), prev: None, max_read_size, + mode, } } @@ -252,6 +454,7 @@ impl VectoredReadPlanner { end_offset, BlobMeta { key, lsn }, self.max_read_size, + self.mode, ); let prev_read_builder = current_read_builder.replace(next_read_builder); @@ -303,6 +506,18 @@ impl<'a> VectoredBlobReader<'a> { read.size(), buf.capacity() ); + + if cfg!(debug_assertions) { + let align = virtual_file::get_io_buffer_alignment() as u64; + debug_assert_eq!( + read.start % align, + 0, + "Read start at {} does not satisfy the required io buffer alignment ({} bytes)", + read.start, + align + ); + } + let mut buf = self .file .read_exact_at(buf.slice(0..read.size()), read.start, ctx) @@ -310,27 +525,20 @@ impl<'a> VectoredBlobReader<'a> { .into_inner(); let blobs_at = read.blobs_at.as_slice(); - let start_offset = blobs_at.first().expect("VectoredRead is never empty").0; + + let start_offset = read.start; let mut metas = Vec::with_capacity(blobs_at.len()); - // Blobs in `read` only provide their starting offset. The end offset // of a blob is implicit: the start of the next blob if one exists // or the end of the read. - let pairs = blobs_at.iter().zip( - blobs_at - .iter() - .map(Some) - .skip(1) - .chain(std::iter::once(None)), - ); // Some scratch space, put here for reusing the allocation let mut decompressed_vec = Vec::new(); - for ((offset, meta), next) in pairs { - let offset_in_buf = offset - start_offset; - let first_len_byte = buf[offset_in_buf as usize]; + for (blob_start, meta) in blobs_at { + let blob_start_in_buf = blob_start - start_offset; + let first_len_byte = buf[blob_start_in_buf as usize]; // Each blob is prefixed by a header containing its size and compression information. // Extract the size and skip that header to find the start of the data. @@ -340,7 +548,7 @@ impl<'a> VectoredBlobReader<'a> { (1, first_len_byte as u64, BYTE_UNCOMPRESSED) } else { let mut blob_size_buf = [0u8; 4]; - let offset_in_buf = offset_in_buf as usize; + let offset_in_buf = blob_start_in_buf as usize; blob_size_buf.copy_from_slice(&buf[offset_in_buf..offset_in_buf + 4]); blob_size_buf[0] &= !LEN_COMPRESSION_BIT_MASK; @@ -353,12 +561,8 @@ impl<'a> VectoredBlobReader<'a> { ) }; - let start_raw = offset_in_buf + size_length; - let end_raw = match next { - Some((next_blob_start_offset, _)) => next_blob_start_offset - start_offset, - None => start_raw + blob_size, - }; - assert_eq!(end_raw - start_raw, blob_size); + let start_raw = blob_start_in_buf + size_length; + let end_raw = start_raw + blob_size; let (start, end); if compression_bits == BYTE_UNCOMPRESSED { start = start_raw as usize; @@ -407,18 +611,22 @@ pub struct StreamingVectoredReadPlanner { max_cnt: usize, /// Size of the current batch cnt: usize, + + mode: VectoredReadCoalesceMode, } impl StreamingVectoredReadPlanner { pub fn new(max_read_size: u64, max_cnt: usize) -> Self { assert!(max_cnt > 0); assert!(max_read_size > 0); + let mode = VectoredReadCoalesceMode::get(); Self { read_builder: None, prev: None, max_cnt, max_read_size, cnt: 0, + mode, } } @@ -467,17 +675,12 @@ impl StreamingVectoredReadPlanner { } None => { self.read_builder = { - let mut blobs_at = VecMap::default(); - blobs_at - .append(start_offset, BlobMeta { key, lsn }) - .expect("First insertion always succeeds"); - - Some(VectoredReadBuilder { - start: start_offset, - end: end_offset, - blobs_at, - max_read_size: None, - }) + Some(VectoredReadBuilder::new_streaming( + start_offset, + end_offset, + BlobMeta { key, lsn }, + self.mode, + )) }; } } @@ -511,7 +714,9 @@ mod tests { use super::*; fn validate_read(read: &VectoredRead, offset_range: &[(Key, Lsn, u64, BlobFlag)]) { - assert_eq!(read.start, offset_range.first().unwrap().2); + let align = virtual_file::get_io_buffer_alignment() as u64; + assert_eq!(read.start % align, 0); + assert_eq!(read.start / align, offset_range.first().unwrap().2 / align); let expected_offsets_in_read: Vec<_> = offset_range.iter().map(|o| o.2).collect(); @@ -525,6 +730,68 @@ mod tests { assert_eq!(expected_offsets_in_read, offsets_in_read); } + #[test] + fn planner_chunked_coalesce_all_test() { + use crate::virtual_file; + + let chunk_size = virtual_file::get_io_buffer_alignment() as u64; + + // The test explicitly does not check chunk size < 512 + if chunk_size < 512 { + return; + } + + let max_read_size = chunk_size as usize * 8; + let key = Key::MIN; + let lsn = Lsn(0); + + let blob_descriptions = [ + (key, lsn, chunk_size / 8, BlobFlag::None), // Read 1 BEGIN + (key, lsn, chunk_size / 4, BlobFlag::Ignore), // Gap + (key, lsn, chunk_size / 2, BlobFlag::None), + (key, lsn, chunk_size - 2, BlobFlag::Ignore), // Gap + (key, lsn, chunk_size, BlobFlag::None), + (key, lsn, chunk_size * 2 - 1, BlobFlag::None), + (key, lsn, chunk_size * 2 + 1, BlobFlag::Ignore), // Gap + (key, lsn, chunk_size * 3 + 1, BlobFlag::None), + (key, lsn, chunk_size * 5 + 1, BlobFlag::None), + (key, lsn, chunk_size * 6 + 1, BlobFlag::Ignore), // skipped chunk size, but not a chunk: should coalesce. + (key, lsn, chunk_size * 7 + 1, BlobFlag::None), + (key, lsn, chunk_size * 8, BlobFlag::None), // Read 2 BEGIN (b/c max_read_size) + (key, lsn, chunk_size * 9, BlobFlag::Ignore), // ==== skipped a chunk + (key, lsn, chunk_size * 10, BlobFlag::None), // Read 3 BEGIN (cannot coalesce) + ]; + + let ranges = [ + &[ + blob_descriptions[0], + blob_descriptions[2], + blob_descriptions[4], + blob_descriptions[5], + blob_descriptions[7], + blob_descriptions[8], + blob_descriptions[10], + ], + &blob_descriptions[11..12], + &blob_descriptions[13..], + ]; + + let mut planner = VectoredReadPlanner::new(max_read_size); + for (key, lsn, offset, flag) in blob_descriptions { + planner.handle(key, lsn, offset, flag); + } + + planner.handle_range_end(652 * 1024); + + let reads = planner.finish(); + + assert_eq!(reads.len(), ranges.len()); + + for (idx, read) in reads.iter().enumerate() { + validate_read(read, ranges[idx]); + } + } + #[test] fn planner_max_read_size_test() { let max_read_size = 128 * 1024; @@ -571,18 +838,19 @@ mod tests { #[test] fn planner_replacement_test() { - let max_read_size = 128 * 1024; + let chunk_size = virtual_file::get_io_buffer_alignment() as u64; + let max_read_size = 128 * chunk_size as usize; let first_key = Key::MIN; let second_key = first_key.next(); let lsn = Lsn(0); let blob_descriptions = vec![ - (first_key, lsn, 0, BlobFlag::None), // First in read 1 - (first_key, lsn, 1024, BlobFlag::None), // Last in read 1 - (second_key, lsn, 2 * 1024, BlobFlag::ReplaceAll), - (second_key, lsn, 3 * 1024, BlobFlag::None), - (second_key, lsn, 4 * 1024, BlobFlag::ReplaceAll), // First in read 2 - (second_key, lsn, 5 * 1024, BlobFlag::None), // Last in read 2 + (first_key, lsn, 0, BlobFlag::None), // First in read 1 + (first_key, lsn, chunk_size, BlobFlag::None), // Last in read 1 + (second_key, lsn, 2 * chunk_size, BlobFlag::ReplaceAll), + (second_key, lsn, 3 * chunk_size, BlobFlag::None), + (second_key, lsn, 4 * chunk_size, BlobFlag::ReplaceAll), // First in read 2 + (second_key, lsn, 5 * chunk_size, BlobFlag::None), // Last in read 2 ]; let ranges = [&blob_descriptions[0..2], &blob_descriptions[4..]]; @@ -592,7 +860,7 @@ mod tests { planner.handle(key, lsn, offset, flag); } - planner.handle_range_end(6 * 1024); + planner.handle_range_end(6 * chunk_size); let reads = planner.finish(); assert_eq!(reads.len(), 2); @@ -737,6 +1005,7 @@ mod tests { let reserved_bytes = blobs.iter().map(|bl| bl.len()).max().unwrap() * 2 + 16; let mut buf = BytesMut::with_capacity(reserved_bytes); + let mode = VectoredReadCoalesceMode::get(); let vectored_blob_reader = VectoredBlobReader::new(&file); let meta = BlobMeta { key: Key::MIN, @@ -748,7 +1017,7 @@ mod tests { if idx + 1 == offsets.len() { continue; } - let read_builder = VectoredReadBuilder::new(*offset, *end, meta, 16 * 4096); + let read_builder = VectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, mode); let read = read_builder.build(); let result = vectored_blob_reader.read_blobs(&read, buf, &ctx).await?; assert_eq!(result.blobs.len(), 1); @@ -784,4 +1053,12 @@ mod tests { round_trip_test_compressed(&blobs, true).await?; Ok(()) } + + #[test] + fn test_div_round_up() { + const CHUNK_SIZE: usize = 512; + assert_eq!(1, div_round_up(200, CHUNK_SIZE)); + assert_eq!(1, div_round_up(CHUNK_SIZE, CHUNK_SIZE)); + assert_eq!(2, div_round_up(CHUNK_SIZE + 1, CHUNK_SIZE)); + } } diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index c0017280fd..97d966e2da 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -10,6 +10,7 @@ //! This is similar to PostgreSQL's virtual file descriptor facility in //! src/backend/storage/file/fd.c //! +use crate::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT; use crate::context::RequestContext; use crate::metrics::{StorageIoOperation, STORAGE_IO_SIZE, STORAGE_IO_TIME_METRIC}; @@ -1140,10 +1141,13 @@ impl OpenFiles { /// server startup. /// #[cfg(not(test))] -pub fn init(num_slots: usize, engine: IoEngineKind) { +pub fn init(num_slots: usize, engine: IoEngineKind, io_buffer_alignment: usize) { if OPEN_FILES.set(OpenFiles::new(num_slots)).is_err() { panic!("virtual_file::init called twice"); } + if set_io_buffer_alignment(io_buffer_alignment).is_err() { + panic!("IO buffer alignment ({io_buffer_alignment}) is not a power of two"); + } io_engine::init(engine); crate::metrics::virtual_file_descriptor_cache::SIZE_MAX.set(num_slots as u64); } @@ -1167,6 +1171,53 @@ fn get_open_files() -> &'static OpenFiles { } } +static IO_BUFFER_ALIGNMENT: AtomicUsize = AtomicUsize::new(DEFAULT_IO_BUFFER_ALIGNMENT); + +/// Returns true if `x` is zero or a power of two. +fn is_zero_or_power_of_two(x: usize) -> bool { + (x == 0) || ((x & (x - 1)) == 0) +} + +#[allow(unused)] +pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> { + if is_zero_or_power_of_two(align) { + IO_BUFFER_ALIGNMENT.store(align, std::sync::atomic::Ordering::Relaxed); + Ok(()) + } else { + Err(align) + } +} + +/// Gets the io buffer alignment requirement. Returns 0 if there is no requirement specified. +/// +/// This function should be used to check the raw config value. +pub(crate) fn get_io_buffer_alignment_raw() -> usize { + let align = IO_BUFFER_ALIGNMENT.load(std::sync::atomic::Ordering::Relaxed); + + if cfg!(test) { + let env_var_name = "NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT"; + if let Some(test_align) = utils::env::var(env_var_name) { + if is_zero_or_power_of_two(test_align) { + test_align + } else { + panic!("IO buffer alignment ({test_align}) is not a power of two"); + } + } else { + align + } + } else { + align + } +} + +/// Gets the io buffer alignment requirement. Returns 1 if the alignment config is set to zero. +/// +/// This function should be used for getting the actual alignment value to use. +pub(crate) fn get_io_buffer_alignment() -> usize { + let align = get_io_buffer_alignment_raw(); + align.max(1) +} + #[cfg(test)] mod tests { use crate::context::DownloadBehavior; diff --git a/pageserver/src/virtual_file/owned_buffers_io/write.rs b/pageserver/src/virtual_file/owned_buffers_io/write.rs index f8f37b17e3..568cf62e56 100644 --- a/pageserver/src/virtual_file/owned_buffers_io/write.rs +++ b/pageserver/src/virtual_file/owned_buffers_io/write.rs @@ -78,6 +78,7 @@ where .expect("must not use after we returned an error") } + /// Guarantees that if Ok() is returned, all bytes in `chunk` have been accepted. #[cfg_attr(target_os = "macos", allow(dead_code))] pub async fn write_buffered( &mut self, diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index 8425528740..8ccd20adb1 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -21,19 +21,25 @@ //! redo Postgres process, but some records it can handle directly with //! bespoken Rust code. +use std::time::Duration; +use std::time::SystemTime; + use pageserver_api::shard::ShardIdentity; use postgres_ffi::v14::nonrelfile_utils::clogpage_precedes; use postgres_ffi::v14::nonrelfile_utils::slru_may_delete_clogsegment; +use postgres_ffi::TimestampTz; use postgres_ffi::{fsm_logical_to_physical, page_is_new, page_set_lsn}; use anyhow::{bail, Context, Result}; use bytes::{Buf, Bytes, BytesMut}; use tracing::*; use utils::failpoint_support; +use utils::rate_limit::RateLimit; use crate::context::RequestContext; use crate::metrics::WAL_INGEST; use crate::pgdatadir_mapping::{DatadirModification, Version}; +use crate::span::debug_assert_current_span_has_tenant_and_timeline_id; use crate::tenant::PageReconstructError; use crate::tenant::Timeline; use crate::walrecord::*; @@ -53,6 +59,13 @@ pub struct WalIngest { shard: ShardIdentity, checkpoint: CheckPoint, checkpoint_modified: bool, + warn_ingest_lag: WarnIngestLag, +} + +struct WarnIngestLag { + lag_msg_ratelimit: RateLimit, + future_lsn_msg_ratelimit: RateLimit, + timestamp_invalid_msg_ratelimit: RateLimit, } impl WalIngest { @@ -71,6 +84,11 @@ impl WalIngest { shard: *timeline.get_shard_identity(), checkpoint, checkpoint_modified: false, + warn_ingest_lag: WarnIngestLag { + lag_msg_ratelimit: RateLimit::new(std::time::Duration::from_secs(10)), + future_lsn_msg_ratelimit: RateLimit::new(std::time::Duration::from_secs(10)), + timestamp_invalid_msg_ratelimit: RateLimit::new(std::time::Duration::from_secs(10)), + }, }) } @@ -1212,6 +1230,48 @@ impl WalIngest { Ok(()) } + fn warn_on_ingest_lag( + &mut self, + conf: &crate::config::PageServerConf, + wal_timestmap: TimestampTz, + ) { + debug_assert_current_span_has_tenant_and_timeline_id(); + let now = SystemTime::now(); + let rate_limits = &mut self.warn_ingest_lag; + match try_from_pg_timestamp(wal_timestmap) { + Ok(ts) => { + match now.duration_since(ts) { + Ok(lag) => { + if lag > conf.wait_lsn_timeout { + rate_limits.lag_msg_ratelimit.call2(|rate_limit_stats| { + let lag = humantime::format_duration(lag); + warn!(%rate_limit_stats, %lag, "ingesting record with timestamp lagging more than wait_lsn_timeout"); + }) + } + }, + Err(e) => { + let delta_t = e.duration(); + // determined by prod victoriametrics query: 1000 * (timestamp(node_time_seconds{neon_service="pageserver"}) - node_time_seconds) + // => https://www.robustperception.io/time-metric-from-the-node-exporter/ + const IGNORED_DRIFT: Duration = Duration::from_millis(100); + if delta_t > IGNORED_DRIFT { + let delta_t = humantime::format_duration(delta_t); + rate_limits.future_lsn_msg_ratelimit.call2(|rate_limit_stats| { + warn!(%rate_limit_stats, %delta_t, "ingesting record with timestamp from future"); + }) + } + } + }; + + } + Err(error) => { + rate_limits.timestamp_invalid_msg_ratelimit.call2(|rate_limit_stats| { + warn!(%rate_limit_stats, %error, "ingesting record with invalid timestamp, cannot calculate lag and will fail find-lsn-for-timestamp type queries"); + }) + } + } + } + /// Subroutine of ingest_record(), to handle an XLOG_XACT_* records. /// async fn ingest_xact_record( @@ -1228,6 +1288,8 @@ impl WalIngest { let mut rpageno = pageno % pg_constants::SLRU_PAGES_PER_SEGMENT; let mut page_xids: Vec = vec![parsed.xid]; + self.warn_on_ingest_lag(modification.tline.conf, parsed.xact_time); + for subxact in &parsed.subxacts { let subxact_pageno = subxact / pg_constants::CLOG_XACTS_PER_PAGE; if subxact_pageno != pageno { @@ -2303,6 +2365,9 @@ mod tests { let _endpoint = Lsn::from_hex("1FFFF98").unwrap(); let harness = TenantHarness::create("test_ingest_real_wal").await.unwrap(); + let span = harness + .span() + .in_scope(|| info_span!("timeline_span", timeline_id=%TIMELINE_ID)); let (tenant, ctx) = harness.load().await; let remote_initdb_path = @@ -2354,6 +2419,7 @@ mod tests { while let Some((lsn, recdata)) = decoder.poll_decode().unwrap() { walingest .ingest_record(recdata, lsn, &mut modification, &mut decoded, &ctx) + .instrument(span.clone()) .await .unwrap(); } diff --git a/patches/pg_hintplan.patch b/patches/pg_hint_plan.patch similarity index 55% rename from patches/pg_hintplan.patch rename to patches/pg_hint_plan.patch index 61a5ecbb90..4039a036df 100644 --- a/patches/pg_hintplan.patch +++ b/patches/pg_hint_plan.patch @@ -1,13 +1,7 @@ -commit f7925d4d1406c0f0229e3c691c94b69e381899b1 (HEAD -> master) -Author: Alexey Masterov -Date: Thu Jun 6 08:02:42 2024 +0000 - - Patch expected files to consider Neon's log messages - -diff --git a/ext-src/pg_hint_plan-src/expected/ut-A.out b/ext-src/pg_hint_plan-src/expected/ut-A.out -index da723b8..f8d0102 100644 ---- a/ext-src/pg_hint_plan-src/expected/ut-A.out -+++ b/ext-src/pg_hint_plan-src/expected/ut-A.out +diff --git a/expected/ut-A.out b/expected/ut-A.out +index da723b8..5328114 100644 +--- a/expected/ut-A.out ++++ b/expected/ut-A.out @@ -9,13 +9,16 @@ SET search_path TO public; ---- -- No.A-1-1-3 @@ -25,10 +19,18 @@ index da723b8..f8d0102 100644 DROP SCHEMA other_schema; ---- ---- No. A-5-1 comment pattern -diff --git a/ext-src/pg_hint_plan-src/expected/ut-fdw.out b/ext-src/pg_hint_plan-src/expected/ut-fdw.out +@@ -3175,6 +3178,7 @@ SELECT s.query, s.calls + FROM public.pg_stat_statements s + JOIN pg_catalog.pg_database d + ON (s.dbid = d.oid) ++ WHERE s.query LIKE 'SELECT * FROM s1.t1%' OR s.query LIKE '%pg_stat_statements_reset%' + ORDER BY 1; + query | calls + --------------------------------------+------- +diff --git a/expected/ut-fdw.out b/expected/ut-fdw.out index d372459..6282afe 100644 ---- a/ext-src/pg_hint_plan-src/expected/ut-fdw.out -+++ b/ext-src/pg_hint_plan-src/expected/ut-fdw.out +--- a/expected/ut-fdw.out ++++ b/expected/ut-fdw.out @@ -7,6 +7,7 @@ SET pg_hint_plan.debug_print TO on; SET client_min_messages TO LOG; SET pg_hint_plan.enable_hint TO on; @@ -37,3 +39,15 @@ index d372459..6282afe 100644 CREATE SERVER file_server FOREIGN DATA WRAPPER file_fdw; CREATE USER MAPPING FOR PUBLIC SERVER file_server; CREATE FOREIGN TABLE ft1 (id int, val int) SERVER file_server OPTIONS (format 'csv', filename :'filename'); +diff --git a/sql/ut-A.sql b/sql/ut-A.sql +index 7c7d58a..4fd1a07 100644 +--- a/sql/ut-A.sql ++++ b/sql/ut-A.sql +@@ -963,6 +963,7 @@ SELECT s.query, s.calls + FROM public.pg_stat_statements s + JOIN pg_catalog.pg_database d + ON (s.dbid = d.oid) ++ WHERE s.query LIKE 'SELECT * FROM s1.t1%' OR s.query LIKE '%pg_stat_statements_reset%' + ORDER BY 1; + + ---- diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c index 73a001b6ba..5126c26c5d 100644 --- a/pgxn/neon/libpagestore.c +++ b/pgxn/neon/libpagestore.c @@ -550,9 +550,6 @@ pageserver_connect(shardno_t shard_no, int elevel) case 2: pagestream_query = psprintf("pagestream_v2 %s %s", neon_tenant, neon_timeline); break; - case 1: - pagestream_query = psprintf("pagestream %s %s", neon_tenant, neon_timeline); - break; default: elog(ERROR, "unexpected neon_protocol_version %d", neon_protocol_version); } @@ -1063,7 +1060,7 @@ pg_init_libpagestore(void) NULL, &neon_protocol_version, 2, /* use protocol version 2 */ - 1, /* min */ + 2, /* min */ 2, /* max */ PGC_SU_BACKEND, 0, /* no flags required */ diff --git a/pgxn/neon/pagestore_client.h b/pgxn/neon/pagestore_client.h index 8951e6607b..1f196d016c 100644 --- a/pgxn/neon/pagestore_client.h +++ b/pgxn/neon/pagestore_client.h @@ -87,9 +87,8 @@ typedef enum { * can skip traversing through recent layers which we know to not contain any * versions for the requested page. * - * These structs describe the V2 of these requests. The old V1 protocol contained - * just one LSN and a boolean 'latest' flag. If the neon_protocol_version GUC is - * set to 1, we will convert these to the V1 requests before sending. + * These structs describe the V2 of these requests. (The old now-defunct V1 + * protocol contained just one LSN and a boolean 'latest' flag.) */ typedef struct { diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index 8edaf65639..7f39c7d026 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -1001,51 +1001,10 @@ nm_pack_request(NeonRequest *msg) initStringInfo(&s); - if (neon_protocol_version >= 2) - { - pq_sendbyte(&s, msg->tag); - pq_sendint64(&s, msg->lsn); - pq_sendint64(&s, msg->not_modified_since); - } - else - { - bool latest; - XLogRecPtr lsn; + pq_sendbyte(&s, msg->tag); + pq_sendint64(&s, msg->lsn); + pq_sendint64(&s, msg->not_modified_since); - /* - * In primary, we always request the latest page version. - */ - if (!RecoveryInProgress()) - { - latest = true; - lsn = msg->not_modified_since; - } - else - { - /* - * In the protocol V1, we cannot represent that we want to read - * page at LSN X, and we know that it hasn't been modified since - * Y. We can either use 'not_modified_lsn' as the request LSN, and - * risk getting an error if that LSN is too old and has already - * fallen out of the pageserver's GC horizon, or we can send - * 'request_lsn', causing the pageserver to possibly wait for the - * recent WAL to arrive unnecessarily. Or something in between. We - * choose to use the old LSN and risk GC errors, because that's - * what we've done historically. - */ - latest = false; - lsn = msg->not_modified_since; - } - - pq_sendbyte(&s, msg->tag); - pq_sendbyte(&s, latest); - pq_sendint64(&s, lsn); - } - - /* - * The rest of the request messages are the same between protocol V1 and - * V2 - */ switch (messageTag(msg)) { /* pagestore_client -> pagestore */ diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index f3ddc64061..65ef588ba5 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -220,6 +220,64 @@ nwp_register_gucs(void) NULL, NULL, NULL); } + +static int +split_safekeepers_list(char *safekeepers_list, char *safekeepers[]) +{ + int n_safekeepers = 0; + char *curr_sk = safekeepers_list; + + for (char *coma = safekeepers_list; coma != NULL && *coma != '\0'; curr_sk = coma) + { + if (++n_safekeepers >= MAX_SAFEKEEPERS) { + wpg_log(FATAL, "too many safekeepers"); + } + + coma = strchr(coma, ','); + safekeepers[n_safekeepers-1] = curr_sk; + + if (coma != NULL) { + *coma++ = '\0'; + } + } + + return n_safekeepers; +} + +/* + * Accept two coma-separated strings with list of safekeeper host:port addresses. + * Split them into arrays and return false if two sets do not match, ignoring the order. + */ +static bool +safekeepers_cmp(char *old, char *new) +{ + char *safekeepers_old[MAX_SAFEKEEPERS]; + char *safekeepers_new[MAX_SAFEKEEPERS]; + int len_old = 0; + int len_new = 0; + + len_old = split_safekeepers_list(old, safekeepers_old); + len_new = split_safekeepers_list(new, safekeepers_new); + + if (len_old != len_new) + { + return false; + } + + qsort(&safekeepers_old, len_old, sizeof(char *), pg_qsort_strcmp); + qsort(&safekeepers_new, len_new, sizeof(char *), pg_qsort_strcmp); + + for (int i = 0; i < len_new; i++) + { + if (strcmp(safekeepers_old[i], safekeepers_new[i]) != 0) + { + return false; + } + } + + return true; +} + /* * GUC assign_hook for neon.safekeepers. Restarts walproposer through FATAL if * the list changed. @@ -235,19 +293,26 @@ assign_neon_safekeepers(const char *newval, void *extra) wpg_log(FATAL, "neon.safekeepers is empty"); } + /* Copy values because we will modify them in split_safekeepers_list() */ + char *newval_copy = pstrdup(newval); + char *oldval = pstrdup(wal_acceptors_list); + /* * TODO: restarting through FATAL is stupid and introduces 1s delay before * next bgw start. We should refactor walproposer to allow graceful exit and * thus remove this delay. + * XXX: If you change anything here, sync with test_safekeepers_reconfigure_reorder. */ - if (strcmp(wal_acceptors_list, newval) != 0) + if (!safekeepers_cmp(oldval, newval_copy)) { wpg_log(FATAL, "restarting walproposer to change safekeeper list from %s to %s", wal_acceptors_list, newval); } + pfree(newval_copy); + pfree(oldval); } -/* Check if we need to suspend inserts because of lagging replication. */ +/* Check if we need to suspend inserts because of lagging replication. */ static uint64 backpressure_lag_impl(void) { diff --git a/proxy/README.md b/proxy/README.md index afc8b77db8..8d850737be 100644 --- a/proxy/README.md +++ b/proxy/README.md @@ -6,7 +6,7 @@ Proxy binary accepts `--auth-backend` CLI option, which determines auth scheme a new SCRAM-based console API; uses SNI info to select the destination project (endpoint soon) * postgres uses postgres to select auth secrets of existing roles. Useful for local testing -* link +* web (or link) sends login link for all usernames Also proxy can expose following services to the external world: diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 3b3c571129..7c408f817c 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,20 +1,20 @@ //! Client authentication mechanisms. pub mod backend; -pub use backend::BackendType; +pub use backend::Backend; mod credentials; -pub use credentials::{ +pub(crate) use credentials::{ check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, ComputeUserInfoParseError, IpPattern, }; mod password_hack; -pub use password_hack::parse_endpoint_param; +pub(crate) use password_hack::parse_endpoint_param; use password_hack::PasswordHackPayload; mod flow; -pub use flow::*; +pub(crate) use flow::*; use tokio::time::error::Elapsed; use crate::{ @@ -25,13 +25,13 @@ use std::{io, net::IpAddr}; use thiserror::Error; /// Convenience wrapper for the authentication error. -pub type Result = std::result::Result; +pub(crate) type Result = std::result::Result; /// Common authentication error. #[derive(Debug, Error)] -pub enum AuthErrorImpl { +pub(crate) enum AuthErrorImpl { #[error(transparent)] - Link(#[from] backend::LinkAuthError), + Web(#[from] backend::WebAuthError), #[error(transparent)] GetAuthInfo(#[from] console::errors::GetAuthInfoError), @@ -77,30 +77,30 @@ pub enum AuthErrorImpl { #[derive(Debug, Error)] #[error(transparent)] -pub struct AuthError(Box); +pub(crate) struct AuthError(Box); impl AuthError { - pub fn bad_auth_method(name: impl Into>) -> Self { + pub(crate) fn bad_auth_method(name: impl Into>) -> Self { AuthErrorImpl::BadAuthMethod(name.into()).into() } - pub fn auth_failed(user: impl Into>) -> Self { + pub(crate) fn auth_failed(user: impl Into>) -> Self { AuthErrorImpl::AuthFailed(user.into()).into() } - pub fn ip_address_not_allowed(ip: IpAddr) -> Self { + pub(crate) fn ip_address_not_allowed(ip: IpAddr) -> Self { AuthErrorImpl::IpAddressNotAllowed(ip).into() } - pub fn too_many_connections() -> Self { + pub(crate) fn too_many_connections() -> Self { AuthErrorImpl::TooManyConnections.into() } - pub fn is_auth_failed(&self) -> bool { + pub(crate) fn is_auth_failed(&self) -> bool { matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_)) } - pub fn user_timeout(elapsed: Elapsed) -> Self { + pub(crate) fn user_timeout(elapsed: Elapsed) -> Self { AuthErrorImpl::UserTimeout(elapsed).into() } } @@ -114,7 +114,7 @@ impl> From for AuthError { impl UserFacingError for AuthError { fn to_string_client(&self) -> String { match self.0.as_ref() { - AuthErrorImpl::Link(e) => e.to_string_client(), + AuthErrorImpl::Web(e) => e.to_string_client(), AuthErrorImpl::GetAuthInfo(e) => e.to_string_client(), AuthErrorImpl::Sasl(e) => e.to_string_client(), AuthErrorImpl::AuthFailed(_) => self.to_string(), @@ -132,7 +132,7 @@ impl UserFacingError for AuthError { impl ReportableError for AuthError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self.0.as_ref() { - AuthErrorImpl::Link(e) => e.get_error_kind(), + AuthErrorImpl::Web(e) => e.get_error_kind(), AuthErrorImpl::GetAuthInfo(e) => e.get_error_kind(), AuthErrorImpl::Sasl(e) => e.get_error_kind(), AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index bb9a0ddffc..1d28c6df31 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -1,19 +1,19 @@ mod classic; mod hacks; pub mod jwt; -mod link; pub mod local; +mod web; use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; use ipnet::{Ipv4Net, Ipv6Net}; -pub use link::LinkAuthError; use local::LocalBackend; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_postgres::config::AuthKeys; use tracing::{info, warn}; +pub(crate) use web::WebAuthError; use crate::auth::credentials::check_peer_addr_is_in_list; use crate::auth::{validate_password_and_exchange, AuthError}; @@ -65,24 +65,24 @@ impl std::ops::Deref for MaybeOwned<'_, T> { /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. -pub enum BackendType<'a, T, D> { +pub enum Backend<'a, T, D> { /// Cloud API (V2). Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. - Link(MaybeOwned<'a, url::ApiUrl>, D), + Web(MaybeOwned<'a, url::ApiUrl>, D), /// Local proxy uses configured auth credentials and does not wake compute Local(MaybeOwned<'a, LocalBackend>), } -pub trait TestBackend: Send + Sync + 'static { +#[cfg(test)] +pub(crate) trait TestBackend: Send + Sync + 'static { fn wake_compute(&self) -> Result; fn get_allowed_ips_and_secret( &self, ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError>; - fn get_role_secret(&self) -> Result; } -impl std::fmt::Display for BackendType<'_, (), ()> { +impl std::fmt::Display for Backend<'_, (), ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Console(api, ()) => match &**api { @@ -96,73 +96,73 @@ impl std::fmt::Display for BackendType<'_, (), ()> { #[cfg(test)] ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, - Self::Link(url, ()) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), + Self::Web(url, ()) => fmt.debug_tuple("Web").field(&url.as_str()).finish(), Self::Local(_) => fmt.debug_tuple("Local").finish(), } } } -impl BackendType<'_, T, D> { +impl Backend<'_, T, D> { /// Very similar to [`std::option::Option::as_ref`]. /// This helps us pass structured config to async tasks. - pub fn as_ref(&self) -> BackendType<'_, &T, &D> { + pub(crate) fn as_ref(&self) -> Backend<'_, &T, &D> { match self { - Self::Console(c, x) => BackendType::Console(MaybeOwned::Borrowed(c), x), - Self::Link(c, x) => BackendType::Link(MaybeOwned::Borrowed(c), x), - Self::Local(l) => BackendType::Local(MaybeOwned::Borrowed(l)), + Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x), + Self::Web(c, x) => Backend::Web(MaybeOwned::Borrowed(c), x), + Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)), } } } -impl<'a, T, D> BackendType<'a, T, D> { +impl<'a, T, D> Backend<'a, T, D> { /// Very similar to [`std::option::Option::map`]. - /// Maps [`BackendType`] to [`BackendType`] by applying + /// Maps [`Backend`] to [`Backend`] by applying /// a function to a contained value. - pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> { + pub(crate) fn map(self, f: impl FnOnce(T) -> R) -> Backend<'a, R, D> { match self { - Self::Console(c, x) => BackendType::Console(c, f(x)), - Self::Link(c, x) => BackendType::Link(c, x), - Self::Local(l) => BackendType::Local(l), + Self::Console(c, x) => Backend::Console(c, f(x)), + Self::Web(c, x) => Backend::Web(c, x), + Self::Local(l) => Backend::Local(l), } } } -impl<'a, T, D, E> BackendType<'a, Result, D> { +impl<'a, T, D, E> Backend<'a, Result, D> { /// Very similar to [`std::option::Option::transpose`]. /// This is most useful for error handling. - pub fn transpose(self) -> Result, E> { + pub(crate) fn transpose(self) -> Result, E> { match self { - Self::Console(c, x) => x.map(|x| BackendType::Console(c, x)), - Self::Link(c, x) => Ok(BackendType::Link(c, x)), - Self::Local(l) => Ok(BackendType::Local(l)), + Self::Console(c, x) => x.map(|x| Backend::Console(c, x)), + Self::Web(c, x) => Ok(Backend::Web(c, x)), + Self::Local(l) => Ok(Backend::Local(l)), } } } -pub struct ComputeCredentials { - pub info: ComputeUserInfo, - pub keys: ComputeCredentialKeys, +pub(crate) struct ComputeCredentials { + pub(crate) info: ComputeUserInfo, + pub(crate) keys: ComputeCredentialKeys, } #[derive(Debug, Clone)] -pub struct ComputeUserInfoNoEndpoint { - pub user: RoleName, - pub options: NeonOptions, +pub(crate) struct ComputeUserInfoNoEndpoint { + pub(crate) user: RoleName, + pub(crate) options: NeonOptions, } #[derive(Debug, Clone)] -pub struct ComputeUserInfo { - pub endpoint: EndpointId, - pub user: RoleName, - pub options: NeonOptions, +pub(crate) struct ComputeUserInfo { + pub(crate) endpoint: EndpointId, + pub(crate) user: RoleName, + pub(crate) options: NeonOptions, } impl ComputeUserInfo { - pub fn endpoint_cache_key(&self) -> EndpointCacheKey { + pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey { self.options.get_cache_key(&self.endpoint) } } -pub enum ComputeCredentialKeys { +pub(crate) enum ComputeCredentialKeys { Password(Vec), AuthKeys(AuthKeys), None, @@ -222,7 +222,7 @@ impl RateBucketInfo { } impl AuthenticationConfig { - pub fn check_rate_limit( + pub(crate) fn check_rate_limit( &self, ctx: &RequestMonitoring, config: &AuthenticationConfig, @@ -403,35 +403,26 @@ async fn authenticate_with_secret( classic::authenticate(ctx, info, client, config, secret).await } -impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { - /// Get compute endpoint name from the credentials. - pub fn get_endpoint(&self) -> Option { - match self { - Self::Console(_, user_info) => user_info.endpoint_id.clone(), - Self::Link(_, ()) => Some("link".into()), - Self::Local(_) => Some("local".into()), - } - } - +impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> { /// Get username from the credentials. - pub fn get_user(&self) -> &str { + pub(crate) fn get_user(&self) -> &str { match self { Self::Console(_, user_info) => &user_info.user, - Self::Link(_, ()) => "link", + Self::Web(_, ()) => "web", Self::Local(_) => "local", } } /// Authenticate the client via the requested backend, possibly using credentials. #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)] - pub async fn authenticate( + pub(crate) async fn authenticate( self, ctx: &RequestMonitoring, client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, endpoint_rate_limiter: Arc, - ) -> auth::Result> { + ) -> auth::Result> { let res = match self { Self::Console(api, user_info) => { info!( @@ -450,15 +441,15 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { endpoint_rate_limiter, ) .await?; - BackendType::Console(api, credentials) + Backend::Console(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Self::Link(url, ()) => { - info!("performing link authentication"); + Self::Web(url, ()) => { + info!("performing web authentication"); - let info = link::authenticate(ctx, &url, client).await?; + let info = web::authenticate(ctx, &url, client).await?; - BackendType::Link(url, info) + Backend::Web(url, info) } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) @@ -470,39 +461,39 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { } } -impl BackendType<'_, ComputeUserInfo, &()> { - pub async fn get_role_secret( +impl Backend<'_, ComputeUserInfo, &()> { + pub(crate) async fn get_role_secret( &self, ctx: &RequestMonitoring, ) -> Result { match self { Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await, - Self::Link(_, ()) => Ok(Cached::new_uncached(None)), + Self::Web(_, ()) => Ok(Cached::new_uncached(None)), Self::Local(_) => Ok(Cached::new_uncached(None)), } } - pub async fn get_allowed_ips_and_secret( + pub(crate) async fn get_allowed_ips_and_secret( &self, ctx: &RequestMonitoring, ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { match self { Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, - Self::Link(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + Self::Web(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), } } } #[async_trait::async_trait] -impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { +impl ComputeConnectBackend for Backend<'_, ComputeCredentials, NodeInfo> { async fn wake_compute( &self, ctx: &RequestMonitoring, ) -> Result { match self { Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await, - Self::Link(_, info) => Ok(Cached::new_uncached(info.clone())), + Self::Web(_, info) => Ok(Cached::new_uncached(info.clone())), Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } @@ -510,21 +501,23 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { fn get_keys(&self) -> &ComputeCredentialKeys { match self { Self::Console(_, creds) => &creds.keys, - Self::Link(_, _) => &ComputeCredentialKeys::None, + Self::Web(_, _) => &ComputeCredentialKeys::None, Self::Local(_) => &ComputeCredentialKeys::None, } } } #[async_trait::async_trait] -impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { +impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> { async fn wake_compute( &self, ctx: &RequestMonitoring, ) -> Result { match self { Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await, - Self::Link(_, ()) => unreachable!("link auth flow doesn't support waking the compute"), + Self::Web(_, ()) => { + unreachable!("web auth flow doesn't support waking the compute") + } Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } @@ -532,7 +525,7 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { fn get_keys(&self) -> &ComputeCredentialKeys { match self { Self::Console(_, creds) => &creds.keys, - Self::Link(_, ()) => &ComputeCredentialKeys::None, + Self::Web(_, ()) => &ComputeCredentialKeys::None, Self::Local(_) => &ComputeCredentialKeys::None, } } diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 56921dd949..e9019ce2cf 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -17,7 +17,7 @@ use tracing::{info, warn}; /// one round trip and *expensive* computations (>= 4096 HMAC iterations). /// These properties are benefical for serverless JS workers, so we /// use this mechanism for websocket connections. -pub async fn authenticate_cleartext( +pub(crate) async fn authenticate_cleartext( ctx: &RequestMonitoring, info: ComputeUserInfo, client: &mut stream::PqStream>, @@ -59,7 +59,7 @@ pub async fn authenticate_cleartext( /// Workaround for clients which don't provide an endpoint (project) name. /// Similar to [`authenticate_cleartext`], but there's a specific password format, /// and passwords are not yet validated (we don't know how to validate them!) -pub async fn password_hack_no_authentication( +pub(crate) async fn password_hack_no_authentication( ctx: &RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 61833e19ed..1f44e4af5d 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -22,27 +22,27 @@ const MAX_RENEW: Duration = Duration::from_secs(3600); const MAX_JWK_BODY_SIZE: usize = 64 * 1024; /// How to get the JWT auth rules -pub trait FetchAuthRules: Clone + Send + Sync + 'static { +pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static { fn fetch_auth_rules( &self, role_name: RoleName, ) -> impl Future>> + Send; } -pub struct AuthRule { - pub id: String, - pub jwks_url: url::Url, - pub audience: Option, +pub(crate) struct AuthRule { + pub(crate) id: String, + pub(crate) jwks_url: url::Url, + pub(crate) audience: Option, } #[derive(Default)] -pub struct JwkCache { +pub(crate) struct JwkCache { client: reqwest::Client, map: DashMap<(EndpointId, RoleName), Arc>, } -pub struct JwkCacheEntry { +pub(crate) struct JwkCacheEntry { /// Should refetch at least every hour to verify when old keys have been removed. /// Should refetch when new key IDs are seen only every 5 minutes or so last_retrieved: Instant, @@ -75,7 +75,7 @@ impl KeySet { } } -pub struct JwkCacheEntryLock { +pub(crate) struct JwkCacheEntryLock { cached: ArcSwapOption, lookup: tokio::sync::Semaphore, } @@ -309,7 +309,7 @@ impl JwkCacheEntryLock { } impl JwkCache { - pub async fn check_jwt( + pub(crate) async fn check_jwt( &self, ctx: &RequestMonitoring, endpoint: EndpointId, @@ -500,6 +500,7 @@ mod tests { use hyper1::service::service_fn; use hyper_util::rt::TokioIo; use rand::rngs::OsRng; + use rsa::pkcs8::DecodePrivateKey; use signature::Signer; use tokio::net::TcpListener; @@ -517,8 +518,8 @@ mod tests { (sk, jwk) } - fn new_rsa_jwk(kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) { - let sk = rsa::RsaPrivateKey::new(&mut OsRng, 2048).unwrap(); + fn new_rsa_jwk(key: &str, kid: String) -> (rsa::RsaPrivateKey, jose_jwk::Jwk) { + let sk = rsa::RsaPrivateKey::from_pkcs8_pem(key).unwrap(); let pk = sk.to_public_key().into(); let jwk = jose_jwk::Jwk { key: jose_jwk::Key::Rsa(pk), @@ -569,10 +570,70 @@ mod tests { format!("{payload}.{sig}") } + // RSA key gen is slow.... + const RS1: &str = "-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNuWBIWTlo+54Y +aifpGInIrpv6LlsbI/2/2CC81Arlx4RsABORklgA9XSGwaCbHTshHsfd1S916JwA +SpjyPQYWfqo6iAV8a4MhjIeJIkRr74prDCSzOGZvIc6VaGeCIb9clf3HSrPHm3hA +cfLMB8/p5MgoxERPDOIn3XYoS9SEEuP7l0LkmEZMerg6W6lDjQRDny0Lb50Jky9X +mDqnYXBhs99ranbwL5vjy0ba6OIeCWFJme5u+rv5C/P0BOYrJfGxIcEoKa8Ukw5s +PlM+qrz9ope1eOuXMNNdyFDReNBUyaM1AwBAayU5rz57crer7K/UIofaJ42T4cMM +nx/SWfBNAgMBAAECggEACqdpBxYn1PoC6/zDaFzu9celKEWyTiuE/qRwvZa1ocS9 +ZOJ0IPvVNud/S2NHsADJiSOQ8joSJScQvSsf1Ju4bv3MTw+wSQtAVUJz2nQ92uEi +5/xPAkEPfP3hNvebNLAOuvrBk8qYmOPCTIQaMNrOt6wzeXkAmJ9wLuRXNCsJLHW+ +KLpf2WdgTYxqK06ZiJERFgJ2r1MsC2IgTydzjOAdEIrtMarerTLqqCpwFrk/l0cz +1O2OAb17ZxmhuzMhjNMin81c8F2fZAGMeOjn92Jl5kUsYw/pG+0S8QKlbveR/fdP +We2tJsgXw2zD0q7OJpp8NXS2yddrZGyysYsof983wQKBgQD2McqNJqo+eWL5zony +UbL19loYw0M15EjhzIuzW1Jk0rPj65yQyzpJ6pqicRuWr34MvzCx+ZHM2b3jSiNu +GES2fnC7xLIKyeRxfqsXF71xz+6UStEGRQX27r1YWEtyQVuBhvlqB+AGWP3PYAC+ +HecZecnZ+vcihJ2K3+l5O3paVQKBgQDV6vKH5h2SY9vgO8obx0P7XSS+djHhmPuU +f8C/Fq6AuRbIA1g04pzuLU2WS9T26eIjgM173uVNg2TuqJveWzz+CAAp6nCR6l24 +DBg49lMGCWrMo4FqPG46QkUqvK8uSj42GkX/e5Rut1Gyu0209emeM6h2d2K15SvY +9563tYSmGQKBgQDwcH5WTi20KA7e07TroJi8GKWzS3gneNUpGQBS4VxdtV4UuXXF +/4TkzafJ/9cm2iurvUmMd6XKP9lw0mY5zp/E70WgTCBp4vUlVsU3H2tYbO+filYL +3ntNx6nKTykX4/a/UJfj0t8as+zli+gNxNx/h+734V9dKdFG4Rl+2fTLpQKBgQCE +qJkTEe+Q0wCOBEYICADupwqcWqwAXWDW7IrZdfVtulqYWwqecVIkmk+dPxWosc4d +ekjz4nyNH0i+gC15LVebqdaAJ/T7aD4KXuW+nXNLMRfcJCGjgipRUruWD0EMEdqW +rqBuGXMpXeH6VxGPgVkJVLvKC6tZZe9VM+pnvteuMQKBgQC8GaL+Lz+al4biyZBf +JE8ekWrIotq/gfUBLP7x70+PB9bNtXtlgmTvjgYg4jiu3KR/ZIYYQ8vfVgkb6tDI +rWGZw86Pzuoi1ppg/pYhKk9qrmCIT4HPEXbHl7ATahu2BOCIU3hybjTh2lB6LbX9 +8LMFlz1QPqSZYN/A/kOcLBfa3A== +-----END PRIVATE KEY----- +"; + const RS2: &str = "-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDipm6FIKSRab3J +HwmK18t7hp+pohllxIDUSPi7S5mIhN/JG2Plq2Lp746E/fuT8dcBF2R4sJlG2L0J +zmxOvBU/i/sQF9s1i4CEfg05k2//gKENIEsF3pMMmrH+mcZi0TTD6rezHpdVxPHk +qWxSyOCtIJV29X+wxPwAB59kQFHzy2ooPB1isZcpE8tO0KthAM+oZ3KuCwE0++cO +IWLeq9aPwyKhtip/xjTMxd1kzdKh592mGSyzr9D0QSWOYFGvgJXANDdiPdhSSOLt +ECWPNPlm2FQvGGvYYBafUqz7VumKHE6x8J6lKdYa2J0ZdDzCIo2IHzlxe+RZNgwy +uAD2jhVxAgMBAAECggEAbsZHWBu3MzcKQiVARbLoygvnN0J5xUqAaMDtiKUPejDv +K1yOu67DXnDuKEP2VL2rhuYG/hHaKE1AP227c9PrUq6424m9YvM2sgrlrdFIuQkG +LeMtp8W7+zoUasp/ssZrUqICfLIj5xCl5UuFHQT/Ar7dLlIYwa3VOLKBDb9+Dnfe +QH5/So4uMXG6vw34JN9jf+eAc8Yt0PeIz62ycvRwdpTJQ0MxZN9ZKpCAQp+VTuXT +zlzNvDMilabEdqUvAyGyz8lBLNl0wdaVrqPqAEWM5U45QXsdFZknWammP7/tijeX +0z+Bi0J0uSEU5X502zm7GArj/NNIiWMcjmDjwUUhwQKBgQD9C2GoqxOxuVPYqwYR ++Jz7f2qMjlSP8adA5Lzuh8UKXDp8JCEQC8ryweLzaOKS9C5MAw+W4W2wd4nJoQI1 +P1dgGvBlfvEeRHMgqWtq7FuTsjSe7e0uSEkC4ngDb4sc0QOpv15cMuEz+4+aFLPL +x29EcHWAaBX+rkid3zpQHFU4eQKBgQDlTCEqRuXwwa3V+Sq+mNWzD9QIGtD87TH/ +FPO/Ij/cK2+GISgFDqhetiGTH4qrvPL0psPT+iH5zGFYcoFmTtwLdWQJdxhxz0bg +iX/AceyX5e1Bm+ThT36sU83NrxKPkrdk6jNmr2iUF1OTzTwUKOYdHOPZqdMPfF4M +4XAaWVT2uQKBgQD4nKcNdU+7LE9Rr+4d1/o8Klp/0BMK/ayK2HE7lc8kt6qKb2DA +iCWUTqPw7Fq3cQrPia5WWhNP7pJEtFkcAaiR9sW7onW5fBz0uR+dhK0QtmR2xWJj +N4fsOp8ZGQ0/eae0rh1CTobucLkM9EwV6VLLlgYL67e4anlUCo8bSEr+WQKBgQCB +uf6RgqcY/RqyklPCnYlZ0zyskS9nyXKd1GbK3j+u+swP4LZZlh9f5j88k33LCA2U +qLzmMwAB6cWxWqcnELqhqPq9+ClWSmTZKDGk2U936NfAZMirSGRsbsVi9wfTPriP +WYlXMSpDjqb0WgsBhNob4npubQxCGKTFOM5Jufy90QKBgB0Lte1jX144uaXx6dtB +rjXNuWNir0Jy31wHnQuCA+XnfUgPcrKmRLm8taMbXgZwxkNvgFkpUWU8aPEK08Ne +X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL +5JiconnI5aLek0QVPoFaVXFa +-----END PRIVATE KEY----- +"; + #[tokio::test] async fn renew() { - let (rs1, jwk1) = new_rsa_jwk("1".into()); - let (rs2, jwk2) = new_rsa_jwk("2".into()); + let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into()); + let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into()); let (ec1, jwk3) = new_ec_jwk("3".into()); let (ec2, jwk4) = new_ec_jwk("4".into()); diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 6d18564dd6..8124f568cf 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -16,16 +16,14 @@ use crate::{ use super::jwt::{AuthRule, FetchAuthRules, JwkCache}; pub struct LocalBackend { - pub jwks_cache: JwkCache, - pub postgres_addr: SocketAddr, - pub node_info: NodeInfo, + pub(crate) jwks_cache: JwkCache, + pub(crate) node_info: NodeInfo, } impl LocalBackend { pub fn new(postgres_addr: SocketAddr) -> Self { LocalBackend { jwks_cache: JwkCache::default(), - postgres_addr, node_info: NodeInfo { config: { let mut cfg = ConnCfg::new(); @@ -47,7 +45,7 @@ impl LocalBackend { } #[derive(Clone, Copy)] -pub struct StaticAuthRules; +pub(crate) struct StaticAuthRules; pub static JWKS_ROLE_MAP: ArcSwapOption = ArcSwapOption::const_empty(); diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/web.rs similarity index 86% rename from proxy/src/auth/backend/link.rs rename to proxy/src/auth/backend/web.rs index 95f4614736..58a4bef62e 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/web.rs @@ -13,7 +13,7 @@ use tokio_postgres::config::SslMode; use tracing::{info, info_span}; #[derive(Debug, Error)] -pub enum LinkAuthError { +pub(crate) enum WebAuthError { #[error(transparent)] WaiterRegister(#[from] waiters::RegisterError), @@ -24,18 +24,18 @@ pub enum LinkAuthError { Io(#[from] std::io::Error), } -impl UserFacingError for LinkAuthError { +impl UserFacingError for WebAuthError { fn to_string_client(&self) -> String { "Internal error".to_string() } } -impl ReportableError for LinkAuthError { +impl ReportableError for WebAuthError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service, - LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service, - LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect, + Self::WaiterRegister(_) => crate::error::ErrorKind::Service, + Self::WaiterWait(_) => crate::error::ErrorKind::Service, + Self::Io(_) => crate::error::ErrorKind::ClientDisconnect, } } } @@ -52,7 +52,7 @@ fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String { ) } -pub fn new_psql_session_id() -> String { +pub(crate) fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } @@ -74,7 +74,7 @@ pub(super) async fn authenticate( } }; - let span = info_span!("link", psql_session_id = &psql_session_id); + let span = info_span!("web", psql_session_id = &psql_session_id); let greeting = hello_message(link_uri, &psql_session_id); // Give user a URL to spawn a new database. @@ -87,7 +87,7 @@ pub(super) async fn authenticate( // Wait for web console response (see `mgmt`). info!(parent: &span, "waiting for console's reply..."); - let db_info = waiter.await.map_err(LinkAuthError::from)?; + let db_info = waiter.await.map_err(WebAuthError::from)?; client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index cb06fcaf55..0e91ae570a 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -16,7 +16,7 @@ use thiserror::Error; use tracing::{info, warn}; #[derive(Debug, Error, PartialEq, Eq, Clone)] -pub enum ComputeUserInfoParseError { +pub(crate) enum ComputeUserInfoParseError { #[error("Parameter '{0}' is missing in startup packet.")] MissingKey(&'static str), @@ -51,20 +51,20 @@ impl ReportableError for ComputeUserInfoParseError { /// Various client credentials which we use for authentication. /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ComputeUserInfoMaybeEndpoint { - pub user: RoleName, - pub endpoint_id: Option, - pub options: NeonOptions, +pub(crate) struct ComputeUserInfoMaybeEndpoint { + pub(crate) user: RoleName, + pub(crate) endpoint_id: Option, + pub(crate) options: NeonOptions, } impl ComputeUserInfoMaybeEndpoint { #[inline] - pub fn endpoint(&self) -> Option<&str> { + pub(crate) fn endpoint(&self) -> Option<&str> { self.endpoint_id.as_deref() } } -pub fn endpoint_sni( +pub(crate) fn endpoint_sni( sni: &str, common_names: &HashSet, ) -> Result, ComputeUserInfoParseError> { @@ -83,7 +83,7 @@ pub fn endpoint_sni( } impl ComputeUserInfoMaybeEndpoint { - pub fn parse( + pub(crate) fn parse( ctx: &RequestMonitoring, params: &StartupMessageParams, sni: Option<&str>, @@ -173,12 +173,12 @@ impl ComputeUserInfoMaybeEndpoint { } } -pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool { +pub(crate) fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool { ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern)) } #[derive(Debug, Clone, Eq, PartialEq)] -pub enum IpPattern { +pub(crate) enum IpPattern { Subnet(ipnet::IpNet), Range(IpAddr, IpAddr), Single(IpAddr), diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index acf7b4f6b6..f7e2b5296e 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -17,17 +17,20 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; /// Every authentication selector is supposed to implement this trait. -pub trait AuthMethod { +pub(crate) trait AuthMethod { /// Any authentication selector should provide initial backend message /// containing auth method name and parameters, e.g. md5 salt. fn first_message(&self, channel_binding: bool) -> BeMessage<'_>; } /// Initial state of [`AuthFlow`]. -pub struct Begin; +pub(crate) struct Begin; /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. -pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a RequestMonitoring); +pub(crate) struct Scram<'a>( + pub(crate) &'a scram::ServerSecret, + pub(crate) &'a RequestMonitoring, +); impl AuthMethod for Scram<'_> { #[inline(always)] @@ -44,7 +47,7 @@ impl AuthMethod for Scram<'_> { /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in /// . -pub struct PasswordHack; +pub(crate) struct PasswordHack; impl AuthMethod for PasswordHack { #[inline(always)] @@ -55,10 +58,10 @@ impl AuthMethod for PasswordHack { /// Use clear-text password auth called `password` in docs /// -pub struct CleartextPassword { - pub pool: Arc, - pub endpoint: EndpointIdInt, - pub secret: AuthSecret, +pub(crate) struct CleartextPassword { + pub(crate) pool: Arc, + pub(crate) endpoint: EndpointIdInt, + pub(crate) secret: AuthSecret, } impl AuthMethod for CleartextPassword { @@ -70,7 +73,7 @@ impl AuthMethod for CleartextPassword { /// This wrapper for [`PqStream`] performs client authentication. #[must_use] -pub struct AuthFlow<'a, S, State> { +pub(crate) struct AuthFlow<'a, S, State> { /// The underlying stream which implements libpq's protocol. stream: &'a mut PqStream>, /// State might contain ancillary data (see [`Self::begin`]). @@ -81,7 +84,7 @@ pub struct AuthFlow<'a, S, State> { /// Initial state of the stream wrapper. impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { /// Create a new wrapper for client authentication. - pub fn new(stream: &'a mut PqStream>) -> Self { + pub(crate) fn new(stream: &'a mut PqStream>) -> Self { let tls_server_end_point = stream.get_ref().tls_server_end_point(); Self { @@ -92,7 +95,7 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { } /// Move to the next step by sending auth method's name & params to client. - pub async fn begin(self, method: M) -> io::Result> { + pub(crate) async fn begin(self, method: M) -> io::Result> { self.stream .write_message(&method.first_message(self.tls_server_end_point.supported())) .await?; @@ -107,7 +110,7 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { impl AuthFlow<'_, S, PasswordHack> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn get_password(self) -> super::Result { + pub(crate) async fn get_password(self) -> super::Result { let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -126,7 +129,7 @@ impl AuthFlow<'_, S, PasswordHack> { impl AuthFlow<'_, S, CleartextPassword> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> super::Result> { + pub(crate) async fn authenticate(self) -> super::Result> { let msg = self.stream.read_password_message().await?; let password = msg .strip_suffix(&[0]) @@ -151,7 +154,7 @@ impl AuthFlow<'_, S, CleartextPassword> { /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> super::Result> { + pub(crate) async fn authenticate(self) -> super::Result> { let Scram(secret, ctx) = self.state; // pause the timer while we communicate with the client diff --git a/proxy/src/auth/password_hack.rs b/proxy/src/auth/password_hack.rs index 2ddf46fe25..8585b8ff48 100644 --- a/proxy/src/auth/password_hack.rs +++ b/proxy/src/auth/password_hack.rs @@ -1,5 +1,5 @@ //! Payload for ad hoc authentication method for clients that don't support SNI. -//! See the `impl` for [`super::backend::BackendType`]. +//! See the `impl` for [`super::backend::Backend`]. //! Read more: . //! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified. @@ -7,13 +7,13 @@ use bstr::ByteSlice; use crate::EndpointId; -pub struct PasswordHackPayload { - pub endpoint: EndpointId, - pub password: Vec, +pub(crate) struct PasswordHackPayload { + pub(crate) endpoint: EndpointId, + pub(crate) password: Vec, } impl PasswordHackPayload { - pub fn parse(bytes: &[u8]) -> Option { + pub(crate) fn parse(bytes: &[u8]) -> Option { // The format is `project=;` or `project=$`. let separators = [";", "$"]; for sep in separators { @@ -30,7 +30,7 @@ impl PasswordHackPayload { } } -pub fn parse_endpoint_param(bytes: &str) -> Option<&str> { +pub(crate) fn parse_endpoint_param(bytes: &str) -> Option<&str> { bytes .strip_prefix("project=") .or_else(|| bytes.strip_prefix("endpoint=")) diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 8acba33bac..08effeff99 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -212,7 +212,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig Ok(Box::leak(Box::new(ProxyConfig { tls_config: None, - auth_backend: proxy::auth::BackendType::Local(proxy::auth::backend::MaybeOwned::Owned( + auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( LocalBackend::new(args.compute), )), metric_collection: None, diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 1f45a33ed5..7706a1f7cd 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -60,11 +60,14 @@ use clap::{Parser, ValueEnum}; static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; #[derive(Clone, Debug, ValueEnum)] -enum AuthBackend { +enum AuthBackendType { Console, #[cfg(feature = "testing")] Postgres, - Link, + // clap only shows the name, not the alias, in usage text. + // TODO: swap name/alias and deprecate "link" + #[value(name("link"), alias("web"))] + Web, } /// Neon proxy/router @@ -77,8 +80,8 @@ struct ProxyCliArgs { /// listen for incoming client connections on ip:port #[clap(short, long, default_value = "127.0.0.1:4432")] proxy: String, - #[clap(value_enum, long, default_value_t = AuthBackend::Link)] - auth_backend: AuthBackend, + #[clap(value_enum, long, default_value_t = AuthBackendType::Web)] + auth_backend: AuthBackendType, /// listen for management callback connection on ip:port #[clap(short, long, default_value = "127.0.0.1:7000")] mgmt: String, @@ -88,7 +91,7 @@ struct ProxyCliArgs { /// listen for incoming wss connections on ip:port #[clap(long)] wss: Option, - /// redirect unauthenticated users to the given uri in case of link auth + /// redirect unauthenticated users to the given uri in case of web auth #[clap(short, long, default_value = "http://localhost:3000/psql_session/")] uri: String, /// cloud API endpoint for authenticating users @@ -470,7 +473,7 @@ async fn main() -> anyhow::Result<()> { )); } - if let auth::BackendType::Console(api, _) = &config.auth_backend { + if let auth::Backend::Console(api, _) = &config.auth_backend { if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { match (redis_notifications_client, regional_redis_client.clone()) { (None, None) => {} @@ -575,7 +578,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } let auth_backend = match &args.auth_backend { - AuthBackend::Console => { + AuthBackendType::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; let project_info_cache_config: ProjectInfoCacheOptions = args.project_info_cache.parse()?; @@ -624,18 +627,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { wake_compute_endpoint_rate_limiter, ); let api = console::provider::ConsoleBackend::Console(api); - auth::BackendType::Console(MaybeOwned::Owned(api), ()) + auth::Backend::Console(MaybeOwned::Owned(api), ()) } #[cfg(feature = "testing")] - AuthBackend::Postgres => { + AuthBackendType::Postgres => { let url = args.auth_endpoint.parse()?; let api = console::provider::mock::Api::new(url); let api = console::provider::ConsoleBackend::Postgres(api); - auth::BackendType::Console(MaybeOwned::Owned(api), ()) + auth::Backend::Console(MaybeOwned::Owned(api), ()) } - AuthBackend::Link => { + AuthBackendType::Web => { let url = args.uri.parse()?; - auth::BackendType::Link(MaybeOwned::Owned(url), ()) + auth::Backend::Web(MaybeOwned::Owned(url), ()) } }; diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs index d1d4087241..6c168144a7 100644 --- a/proxy/src/cache.rs +++ b/proxy/src/cache.rs @@ -1,7 +1,7 @@ -pub mod common; -pub mod endpoints; -pub mod project_info; +pub(crate) mod common; +pub(crate) mod endpoints; +pub(crate) mod project_info; mod timed_lru; -pub use common::{Cache, Cached}; -pub use timed_lru::TimedLru; +pub(crate) use common::{Cache, Cached}; +pub(crate) use timed_lru::TimedLru; diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs index 82c78e3eb2..b5caf94788 100644 --- a/proxy/src/cache/common.rs +++ b/proxy/src/cache/common.rs @@ -3,7 +3,7 @@ use std::ops::{Deref, DerefMut}; /// A generic trait which exposes types of cache's key and value, /// as well as the notion of cache entry invalidation. /// This is useful for [`Cached`]. -pub trait Cache { +pub(crate) trait Cache { /// Entry's key. type Key; @@ -29,21 +29,21 @@ impl Cache for &C { } /// Wrapper for convenient entry invalidation. -pub struct Cached::Value> { +pub(crate) struct Cached::Value> { /// Cache + lookup info. - pub token: Option<(C, C::LookupInfo)>, + pub(crate) token: Option<(C, C::LookupInfo)>, /// The value itself. - pub value: V, + pub(crate) value: V, } impl Cached { /// Place any entry into this wrapper; invalidation will be a no-op. - pub fn new_uncached(value: V) -> Self { + pub(crate) fn new_uncached(value: V) -> Self { Self { token: None, value } } - pub fn take_value(self) -> (Cached, V) { + pub(crate) fn take_value(self) -> (Cached, V) { ( Cached { token: self.token, @@ -53,7 +53,7 @@ impl Cached { ) } - pub fn map(self, f: impl FnOnce(V) -> U) -> Cached { + pub(crate) fn map(self, f: impl FnOnce(V) -> U) -> Cached { Cached { token: self.token, value: f(self.value), @@ -61,7 +61,7 @@ impl Cached { } /// Drop this entry from a cache if it's still there. - pub fn invalidate(self) -> V { + pub(crate) fn invalidate(self) -> V { if let Some((cache, info)) = &self.token { cache.invalidate(info); } @@ -69,7 +69,7 @@ impl Cached { } /// Tell if this entry is actually cached. - pub fn cached(&self) -> bool { + pub(crate) fn cached(&self) -> bool { self.token.is_some() } } diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs index 8c851790c2..f4762232d8 100644 --- a/proxy/src/cache/endpoints.rs +++ b/proxy/src/cache/endpoints.rs @@ -28,7 +28,7 @@ use crate::{ }; #[derive(Deserialize, Debug, Clone)] -pub struct ControlPlaneEventKey { +pub(crate) struct ControlPlaneEventKey { endpoint_created: Option, branch_created: Option, project_created: Option, @@ -56,7 +56,7 @@ pub struct EndpointsCache { } impl EndpointsCache { - pub fn new(config: EndpointCacheConfig) -> Self { + pub(crate) fn new(config: EndpointCacheConfig) -> Self { Self { limiter: Arc::new(Mutex::new(GlobalRateLimiter::new( config.limiter_info.clone(), @@ -68,7 +68,7 @@ impl EndpointsCache { ready: AtomicBool::new(false), } } - pub async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool { + pub(crate) async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool { if !self.ready.load(Ordering::Acquire) { return true; } diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index eda886a7af..ceae74a9a0 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -24,7 +24,7 @@ use crate::{ use super::{Cache, Cached}; #[async_trait] -pub trait ProjectInfoCache { +pub(crate) trait ProjectInfoCache { fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); @@ -37,7 +37,7 @@ struct Entry { } impl Entry { - pub fn new(value: T) -> Self { + pub(crate) fn new(value: T) -> Self { Self { created_at: Instant::now(), value, @@ -64,7 +64,7 @@ impl EndpointInfo { Some(t) => t < created_at, } } - pub fn get_role_secret( + pub(crate) fn get_role_secret( &self, role_name: RoleNameInt, valid_since: Instant, @@ -81,7 +81,7 @@ impl EndpointInfo { None } - pub fn get_allowed_ips( + pub(crate) fn get_allowed_ips( &self, valid_since: Instant, ignore_cache_since: Option, @@ -96,10 +96,10 @@ impl EndpointInfo { } None } - pub fn invalidate_allowed_ips(&mut self) { + pub(crate) fn invalidate_allowed_ips(&mut self) { self.allowed_ips = None; } - pub fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { + pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { self.secret.remove(&role_name); } } @@ -178,7 +178,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } impl ProjectInfoCacheImpl { - pub fn new(config: ProjectInfoCacheOptions) -> Self { + pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self { Self { cache: DashMap::new(), project2ep: DashMap::new(), @@ -189,7 +189,7 @@ impl ProjectInfoCacheImpl { } } - pub fn get_role_secret( + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, @@ -212,7 +212,7 @@ impl ProjectInfoCacheImpl { } Some(Cached::new_uncached(value)) } - pub fn get_allowed_ips( + pub(crate) fn get_allowed_ips( &self, endpoint_id: &EndpointId, ) -> Option>>> { @@ -230,7 +230,7 @@ impl ProjectInfoCacheImpl { } Some(Cached::new_uncached(value)) } - pub fn insert_role_secret( + pub(crate) fn insert_role_secret( &self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, @@ -247,7 +247,7 @@ impl ProjectInfoCacheImpl { entry.secret.insert(role_name, secret.into()); } } - pub fn insert_allowed_ips( + pub(crate) fn insert_allowed_ips( &self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, @@ -319,7 +319,7 @@ impl ProjectInfoCacheImpl { /// Lookup info for project info cache. /// This is used to invalidate cache entries. -pub struct CachedLookupInfo { +pub(crate) struct CachedLookupInfo { /// Search by this key. endpoint_id: EndpointIdInt, lookup_type: LookupType, diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs index 07fad56643..8bb482f7c6 100644 --- a/proxy/src/cache/timed_lru.rs +++ b/proxy/src/cache/timed_lru.rs @@ -39,7 +39,7 @@ use super::{common::Cached, *}; /// /// * It's possible for an entry that has not yet expired entry to be evicted /// before expired items. That's a bit wasteful, but probably fine in practice. -pub struct TimedLru { +pub(crate) struct TimedLru { /// Cache's name for tracing. name: &'static str, @@ -72,7 +72,7 @@ struct Entry { impl TimedLru { /// Construct a new LRU cache with timed entries. - pub fn new( + pub(crate) fn new( name: &'static str, capacity: usize, ttl: Duration, @@ -207,11 +207,11 @@ impl TimedLru { } impl TimedLru { - pub fn insert_ttl(&self, key: K, value: V, ttl: Duration) { + pub(crate) fn insert_ttl(&self, key: K, value: V, ttl: Duration) { self.insert_raw_ttl(key, value, ttl, false); } - pub fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { + pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { let (created_at, old) = self.insert_raw(key.clone(), value); let cached = Cached { @@ -221,22 +221,11 @@ impl TimedLru { (old, cached) } - - pub fn insert(&self, key: K, value: V) -> (Option, Cached<&Self>) { - let (created_at, old) = self.insert_raw(key.clone(), value.clone()); - - let cached = Cached { - token: Some((self, LookupInfo { created_at, key })), - value, - }; - - (old, cached) - } } impl TimedLru { /// Retrieve a cached entry in convenient wrapper. - pub fn get(&self, key: &Q) -> Option> + pub(crate) fn get(&self, key: &Q) -> Option> where K: Borrow + Clone, Q: Hash + Eq + ?Sized, @@ -253,32 +242,10 @@ impl TimedLru { } }) } - - /// Retrieve a cached entry in convenient wrapper, ignoring its TTL. - pub fn get_ignoring_ttl(&self, key: &Q) -> Option> - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let mut cache = self.cache.lock(); - cache - .get(key) - .map(|entry| Cached::new_uncached(entry.value.clone())) - } - - /// Remove an entry from the cache. - pub fn remove(&self, key: &Q) -> Option - where - K: Borrow + Clone, - Q: Hash + Eq + ?Sized, - { - let mut cache = self.cache.lock(); - cache.remove(key).map(|entry| entry.value) - } } /// Lookup information for key invalidation. -pub struct LookupInfo { +pub(crate) struct LookupInfo { /// Time of creation of a cache [`Entry`]. /// We use this during invalidation lookups to prevent eviction of a newer /// entry sharing the same key (it might've been inserted by a different diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index ea8f7b4070..71a2a16af8 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -18,7 +18,7 @@ use crate::{ pub type CancelMap = Arc>>; pub type CancellationHandlerMain = CancellationHandler>>>; -pub type CancellationHandlerMainInternal = Option>>; +pub(crate) type CancellationHandlerMainInternal = Option>>; /// Enables serving `CancelRequest`s. /// @@ -32,7 +32,7 @@ pub struct CancellationHandler

{ } #[derive(Debug, Error)] -pub enum CancelError { +pub(crate) enum CancelError { #[error("{0}")] IO(#[from] std::io::Error), #[error("{0}")] @@ -53,7 +53,7 @@ impl ReportableError for CancelError { impl CancellationHandler

{ /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub fn get_session(self: Arc) -> Session

{ + pub(crate) fn get_session(self: Arc) -> Session

{ // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the @@ -81,7 +81,7 @@ impl CancellationHandler

{ } /// Try to cancel a running query for the corresponding connection. /// If the cancellation key is not found, it will be published to Redis. - pub async fn cancel_session( + pub(crate) async fn cancel_session( &self, key: CancelKeyData, session_id: Uuid, @@ -155,14 +155,14 @@ pub struct CancelClosure { } impl CancelClosure { - pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self { + pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self { Self { socket_addr, cancel_token, } } /// Cancels the query running on user's compute node. - pub async fn try_cancel_query(self) -> Result<(), CancelError> { + pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; self.cancel_token.cancel_query_raw(socket, NoTls).await?; info!("query was cancelled"); @@ -171,7 +171,7 @@ impl CancelClosure { } /// Helper for registering query cancellation tokens. -pub struct Session

{ +pub(crate) struct Session

{ /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. @@ -181,7 +181,7 @@ pub struct Session

{ impl

Session

{ /// Store the cancel token for the given session. /// This enables query cancellation in `crate::proxy::prepare_client_connection`. - pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { + pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); self.cancellation_handler .map diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index b6659f5dd0..8d3cb8ee3c 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -23,7 +23,7 @@ use tracing::{error, info, warn}; const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; #[derive(Debug, Error)] -pub enum ConnectionError { +pub(crate) enum ConnectionError { /// This error doesn't seem to reveal any secrets; for instance, /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such. #[error("{COULD_NOT_CONNECT}: {0}")] @@ -86,22 +86,22 @@ impl ReportableError for ConnectionError { } /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. -pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; +pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>; /// A config for establishing a connection to compute node. /// Eventually, `tokio_postgres` will be replaced with something better. /// Newtype allows us to implement methods on top of it. #[derive(Clone, Default)] -pub struct ConnCfg(Box); +pub(crate) struct ConnCfg(Box); /// Creation and initialization routines. impl ConnCfg { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self::default() } /// Reuse password or auth keys from the other config. - pub fn reuse_password(&mut self, other: Self) { + pub(crate) fn reuse_password(&mut self, other: Self) { if let Some(password) = other.get_password() { self.password(password); } @@ -111,7 +111,7 @@ impl ConnCfg { } } - pub fn get_host(&self) -> Result { + pub(crate) fn get_host(&self) -> Result { match self.0.get_hosts() { [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()), // we should not have multiple address or unix addresses. @@ -122,15 +122,15 @@ impl ConnCfg { } /// Apply startup message params to the connection config. - pub fn set_startup_params(&mut self, params: &StartupMessageParams) { + pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) { // Only set `user` if it's not present in the config. - // Link auth flow takes username from the console's response. + // Web auth flow takes username from the console's response. if let (None, Some(user)) = (self.get_user(), params.get("user")) { self.user(user); } // Only set `dbname` if it's not present in the config. - // Link auth flow takes dbname from the console's response. + // Web auth flow takes dbname from the console's response. if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) { self.dbname(dbname); } @@ -255,25 +255,25 @@ impl ConnCfg { } } -pub struct PostgresConnection { +pub(crate) struct PostgresConnection { /// Socket connected to a compute node. - pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream< + pub(crate) stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream< tokio::net::TcpStream, tokio_postgres_rustls::RustlsStream, >, /// PostgreSQL connection parameters. - pub params: std::collections::HashMap, + pub(crate) params: std::collections::HashMap, /// Query cancellation token. - pub cancel_closure: CancelClosure, + pub(crate) cancel_closure: CancelClosure, /// Labels for proxy's metrics. - pub aux: MetricsAuxInfo, + pub(crate) aux: MetricsAuxInfo, _guage: NumDbConnectionsGuard<'static>, } impl ConnCfg { /// Connect to a corresponding compute node. - pub async fn connect( + pub(crate) async fn connect( &self, ctx: &RequestMonitoring, allow_self_signed_compute: bool, diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 6c42fb8d19..d7fc6eee22 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -25,7 +25,7 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::BackendType<'static, (), ()>, + pub auth_backend: auth::Backend<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, @@ -247,7 +247,7 @@ impl CertResolver { let common_name = pem.subject().to_string(); - // We only use non-wildcard certificates in link proxy so it seems okay to treat them the same as + // We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as // wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so // verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names // and passed None instead, which blows up number of cases downstream code should handle. Proper coding diff --git a/proxy/src/console.rs b/proxy/src/console.rs index ea95e83437..87d8e781aa 100644 --- a/proxy/src/console.rs +++ b/proxy/src/console.rs @@ -10,7 +10,7 @@ pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo}; /// Various cache-related types. pub mod caches { - pub use super::provider::{ApiCaches, NodeInfoCache}; + pub use super::provider::ApiCaches; } /// Various cache-related types. diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index a7ccf076b0..a48c7316f6 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -12,22 +12,22 @@ use crate::RoleName; /// Generic error response with human-readable description. /// Note that we can't always present it to user as is. #[derive(Debug, Deserialize, Clone)] -pub struct ConsoleError { - pub error: Box, +pub(crate) struct ConsoleError { + pub(crate) error: Box, #[serde(skip)] - pub http_status_code: http::StatusCode, - pub status: Option, + pub(crate) http_status_code: http::StatusCode, + pub(crate) status: Option, } impl ConsoleError { - pub fn get_reason(&self) -> Reason { + pub(crate) fn get_reason(&self) -> Reason { self.status .as_ref() .and_then(|s| s.details.error_info.as_ref()) .map_or(Reason::Unknown, |e| e.reason) } - pub fn get_user_facing_message(&self) -> String { + pub(crate) fn get_user_facing_message(&self) -> String { use super::provider::errors::REQUEST_FAILED; self.status .as_ref() @@ -88,27 +88,28 @@ impl CouldRetry for ConsoleError { } #[derive(Debug, Deserialize, Clone)] -pub struct Status { - pub code: Box, - pub message: Box, - pub details: Details, +#[allow(dead_code)] +pub(crate) struct Status { + pub(crate) code: Box, + pub(crate) message: Box, + pub(crate) details: Details, } #[derive(Debug, Deserialize, Clone)] -pub struct Details { - pub error_info: Option, - pub retry_info: Option, - pub user_facing_message: Option, +pub(crate) struct Details { + pub(crate) error_info: Option, + pub(crate) retry_info: Option, + pub(crate) user_facing_message: Option, } #[derive(Copy, Clone, Debug, Deserialize)] -pub struct ErrorInfo { - pub reason: Reason, +pub(crate) struct ErrorInfo { + pub(crate) reason: Reason, // Schema could also have `metadata` field, but it's not structured. Skip it for now. } #[derive(Clone, Copy, Debug, Deserialize, Default)] -pub enum Reason { +pub(crate) enum Reason { /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles. #[serde(rename = "ROLE_PROTECTED")] RoleProtected, @@ -168,7 +169,7 @@ pub enum Reason { } impl Reason { - pub fn is_not_found(&self) -> bool { + pub(crate) fn is_not_found(self) -> bool { matches!( self, Reason::ResourceNotFound @@ -178,7 +179,7 @@ impl Reason { ) } - pub fn can_retry(&self) -> bool { + pub(crate) fn can_retry(self) -> bool { match self { // do not retry role protected errors // not a transitive error @@ -208,22 +209,23 @@ impl Reason { } #[derive(Copy, Clone, Debug, Deserialize)] -pub struct RetryInfo { - pub retry_delay_ms: u64, +#[allow(dead_code)] +pub(crate) struct RetryInfo { + pub(crate) retry_delay_ms: u64, } #[derive(Debug, Deserialize, Clone)] -pub struct UserFacingMessage { - pub message: Box, +pub(crate) struct UserFacingMessage { + pub(crate) message: Box, } /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`]. /// Returned by the `/proxy_get_role_secret` API method. #[derive(Deserialize)] -pub struct GetRoleSecret { - pub role_secret: Box, - pub allowed_ips: Option>, - pub project_id: Option, +pub(crate) struct GetRoleSecret { + pub(crate) role_secret: Box, + pub(crate) allowed_ips: Option>, + pub(crate) project_id: Option, } // Manually implement debug to omit sensitive info. @@ -236,21 +238,21 @@ impl fmt::Debug for GetRoleSecret { /// Response which holds compute node's `host:port` pair. /// Returned by the `/proxy_wake_compute` API method. #[derive(Debug, Deserialize)] -pub struct WakeCompute { - pub address: Box, - pub aux: MetricsAuxInfo, +pub(crate) struct WakeCompute { + pub(crate) address: Box, + pub(crate) aux: MetricsAuxInfo, } -/// Async response which concludes the link auth flow. +/// Async response which concludes the web auth flow. /// Also known as `kickResponse` in the console. #[derive(Debug, Deserialize)] -pub struct KickSession<'a> { +pub(crate) struct KickSession<'a> { /// Session ID is assigned by the proxy. - pub session_id: &'a str, + pub(crate) session_id: &'a str, /// Compute node connection params. #[serde(deserialize_with = "KickSession::parse_db_info")] - pub result: DatabaseInfo, + pub(crate) result: DatabaseInfo, } impl KickSession<'_> { @@ -273,15 +275,15 @@ impl KickSession<'_> { /// Compute node connection params. #[derive(Deserialize)] -pub struct DatabaseInfo { - pub host: Box, - pub port: u16, - pub dbname: Box, - pub user: Box, +pub(crate) struct DatabaseInfo { + pub(crate) host: Box, + pub(crate) port: u16, + pub(crate) dbname: Box, + pub(crate) user: Box, /// Console always provides a password, but it might /// be inconvenient for debug with local PG instance. - pub password: Option>, - pub aux: MetricsAuxInfo, + pub(crate) password: Option>, + pub(crate) aux: MetricsAuxInfo, } // Manually implement debug to omit sensitive info. @@ -299,12 +301,12 @@ impl fmt::Debug for DatabaseInfo { /// Various labels for prometheus metrics. /// Also known as `ProxyMetricsAuxInfo` in the console. #[derive(Debug, Deserialize, Clone)] -pub struct MetricsAuxInfo { - pub endpoint_id: EndpointIdInt, - pub project_id: ProjectIdInt, - pub branch_id: BranchIdInt, +pub(crate) struct MetricsAuxInfo { + pub(crate) endpoint_id: EndpointIdInt, + pub(crate) project_id: ProjectIdInt, + pub(crate) branch_id: BranchIdInt, #[serde(default)] - pub cold_start_info: ColdStartInfo, + pub(crate) cold_start_info: ColdStartInfo, } #[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)] @@ -331,7 +333,7 @@ pub enum ColdStartInfo { } impl ColdStartInfo { - pub fn as_str(&self) -> &'static str { + pub(crate) fn as_str(self) -> &'static str { match self { ColdStartInfo::Unknown => "unknown", ColdStartInfo::Warm => "warm", diff --git a/proxy/src/console/mgmt.rs b/proxy/src/console/mgmt.rs index 82d5033aab..2ed4f5f206 100644 --- a/proxy/src/console/mgmt.rs +++ b/proxy/src/console/mgmt.rs @@ -14,18 +14,18 @@ use tracing::{error, info, info_span, Instrument}; static CPLANE_WAITERS: Lazy> = Lazy::new(Default::default); /// Give caller an opportunity to wait for the cloud's reply. -pub fn get_waiter( +pub(crate) fn get_waiter( psql_session_id: impl Into, ) -> Result, waiters::RegisterError> { CPLANE_WAITERS.register(psql_session_id.into()) } -pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> { +pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> { CPLANE_WAITERS.notify(psql_session_id, msg) } /// Console management API listener task. -/// It spawns console response handlers needed for the link auth. +/// It spawns console response handlers needed for the web auth. pub async fn task_main(listener: TcpListener) -> anyhow::Result { scopeguard::defer! { info!("mgmt has shut down"); @@ -74,7 +74,7 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> { } /// A message received by `mgmt` when a compute node is ready. -pub type ComputeReady = DatabaseInfo; +pub(crate) type ComputeReady = DatabaseInfo; // TODO: replace with an http-based protocol. struct MgmtHandler; diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 4794527410..12a6e2f12a 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -23,7 +23,7 @@ use std::{hash::Hash, sync::Arc, time::Duration}; use tokio::time::Instant; use tracing::info; -pub mod errors { +pub(crate) mod errors { use crate::{ console::messages::{self, ConsoleError, Reason}, error::{io_error, ErrorKind, ReportableError, UserFacingError}, @@ -34,11 +34,11 @@ pub mod errors { use super::ApiLockError; /// A go-to error message which doesn't leak any detail. - pub const REQUEST_FAILED: &str = "Console request failed"; + pub(crate) const REQUEST_FAILED: &str = "Console request failed"; /// Common console API error. #[derive(Debug, Error)] - pub enum ApiError { + pub(crate) enum ApiError { /// Error returned by the console itself. #[error("{REQUEST_FAILED} with {0}")] Console(ConsoleError), @@ -50,7 +50,7 @@ pub mod errors { impl ApiError { /// Returns HTTP status code if it's the reason for failure. - pub fn get_reason(&self) -> messages::Reason { + pub(crate) fn get_reason(&self) -> messages::Reason { match self { ApiError::Console(e) => e.get_reason(), ApiError::Transport(_) => messages::Reason::Unknown, @@ -146,7 +146,7 @@ pub mod errors { } #[derive(Debug, Error)] - pub enum GetAuthInfoError { + pub(crate) enum GetAuthInfoError { // We shouldn't include the actual secret here. #[error("Console responded with a malformed auth secret")] BadSecret, @@ -183,7 +183,7 @@ pub mod errors { } #[derive(Debug, Error)] - pub enum WakeComputeError { + pub(crate) enum WakeComputeError { #[error("Console responded with a malformed compute address: {0}")] BadComputeAddress(Box), @@ -247,7 +247,7 @@ pub mod errors { /// Auth secret which is managed by the cloud. #[derive(Clone, Eq, PartialEq, Debug)] -pub enum AuthSecret { +pub(crate) enum AuthSecret { #[cfg(any(test, feature = "testing"))] /// Md5 hash of user's password. Md5([u8; 16]), @@ -257,32 +257,32 @@ pub enum AuthSecret { } #[derive(Default)] -pub struct AuthInfo { - pub secret: Option, +pub(crate) struct AuthInfo { + pub(crate) secret: Option, /// List of IP addresses allowed for the autorization. - pub allowed_ips: Vec, + pub(crate) allowed_ips: Vec, /// Project ID. This is used for cache invalidation. - pub project_id: Option, + pub(crate) project_id: Option, } /// Info for establishing a connection to a compute node. /// This is what we get after auth succeeded, but not before! #[derive(Clone)] -pub struct NodeInfo { +pub(crate) struct NodeInfo { /// Compute node connection params. /// It's sad that we have to clone this, but this will improve /// once we migrate to a bespoke connection logic. - pub config: compute::ConnCfg, + pub(crate) config: compute::ConnCfg, /// Labels for proxy's metrics. - pub aux: MetricsAuxInfo, + pub(crate) aux: MetricsAuxInfo, /// Whether we should accept self-signed certificates (for testing) - pub allow_self_signed_compute: bool, + pub(crate) allow_self_signed_compute: bool, } impl NodeInfo { - pub async fn connect( + pub(crate) async fn connect( &self, ctx: &RequestMonitoring, timeout: Duration, @@ -296,12 +296,12 @@ impl NodeInfo { ) .await } - pub fn reuse_settings(&mut self, other: Self) { + pub(crate) fn reuse_settings(&mut self, other: Self) { self.allow_self_signed_compute = other.allow_self_signed_compute; self.config.reuse_password(other.config); } - pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) { + pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) { match keys { ComputeCredentialKeys::Password(password) => self.config.password(password), ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), @@ -310,10 +310,10 @@ impl NodeInfo { } } -pub type NodeInfoCache = TimedLru>>; -pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; -pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; -pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; +pub(crate) type NodeInfoCache = TimedLru>>; +pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; +pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; +pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. @@ -350,6 +350,7 @@ pub enum ConsoleBackend { Postgres(mock::Api), /// Internal testing #[cfg(test)] + #[allow(private_interfaces)] Test(Box), } @@ -402,7 +403,7 @@ impl Api for ConsoleBackend { /// Various caches for [`console`](super). pub struct ApiCaches { /// Cache for the `wake_compute` API method. - pub node_info: NodeInfoCache, + pub(crate) node_info: NodeInfoCache, /// Cache which stores project_id -> endpoint_ids mapping. pub project_info: Arc, /// List of all valid endpoints. @@ -439,7 +440,7 @@ pub struct ApiLocks { } #[derive(Debug, thiserror::Error)] -pub enum ApiLockError { +pub(crate) enum ApiLockError { #[error("timeout acquiring resource permit")] TimeoutError(#[from] tokio::time::error::Elapsed), } @@ -471,7 +472,7 @@ impl ApiLocks { }) } - pub async fn get_permit(&self, key: &K) -> Result { + pub(crate) async fn get_permit(&self, key: &K) -> Result { if self.config.initial_limit == 0 { return Ok(WakeComputePermit { permit: Token::disabled(), @@ -531,18 +532,18 @@ impl ApiLocks { } } -pub struct WakeComputePermit { +pub(crate) struct WakeComputePermit { permit: Token, } impl WakeComputePermit { - pub fn should_check_cache(&self) -> bool { + pub(crate) fn should_check_cache(&self) -> bool { !self.permit.is_disabled() } - pub fn release(self, outcome: Outcome) { + pub(crate) fn release(self, outcome: Outcome) { self.permit.release(outcome); } - pub fn release_result(self, res: Result) -> Result { + pub(crate) fn release_result(self, res: Result) -> Result { match res { Ok(_) => self.release(Outcome::Success), Err(_) => self.release(Outcome::Overload), diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 4e8b7a9365..08b87cd87a 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -48,7 +48,7 @@ impl Api { Self { endpoint } } - pub fn url(&self) -> &str { + pub(crate) fn url(&self) -> &str { self.endpoint.as_str() } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index a6c0e233fc..33eda72e65 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -25,8 +25,8 @@ use tracing::{debug, error, info, info_span, warn, Instrument}; pub struct Api { endpoint: http::Endpoint, pub caches: &'static ApiCaches, - pub locks: &'static ApiLocks, - pub wake_compute_endpoint_rate_limiter: Arc, + pub(crate) locks: &'static ApiLocks, + pub(crate) wake_compute_endpoint_rate_limiter: Arc, jwt: String, } @@ -51,7 +51,7 @@ impl Api { } } - pub fn url(&self) -> &str { + pub(crate) fn url(&self) -> &str { self.endpoint.url().as_str() } diff --git a/proxy/src/context.rs b/proxy/src/context.rs index cafbdedc15..72e1fa1cee 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -22,8 +22,9 @@ use self::parquet::RequestData; pub mod parquet; -pub static LOG_CHAN: OnceCell> = OnceCell::new(); -pub static LOG_CHAN_DISCONNECT: OnceCell> = OnceCell::new(); +pub(crate) static LOG_CHAN: OnceCell> = OnceCell::new(); +pub(crate) static LOG_CHAN_DISCONNECT: OnceCell> = + OnceCell::new(); /// Context data for a single request to connect to a database. /// @@ -38,12 +39,12 @@ pub struct RequestMonitoring( ); struct RequestMonitoringInner { - pub peer_addr: IpAddr, - pub session_id: Uuid, - pub protocol: Protocol, + pub(crate) peer_addr: IpAddr, + pub(crate) session_id: Uuid, + pub(crate) protocol: Protocol, first_packet: chrono::DateTime, region: &'static str, - pub span: Span, + pub(crate) span: Span, // filled in as they are discovered project: Option, @@ -63,15 +64,15 @@ struct RequestMonitoringInner { sender: Option>, // This sender is only used to log the length of session in case of success. disconnect_sender: Option>, - pub latency_timer: LatencyTimer, + pub(crate) latency_timer: LatencyTimer, // Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane. rejected: Option, disconnect_timestamp: Option>, } #[derive(Clone, Debug)] -pub enum AuthMethod { - // aka link aka passwordless +pub(crate) enum AuthMethod { + // aka passwordless, fka link Web, ScramSha256, ScramSha256Plus, @@ -125,11 +126,11 @@ impl RequestMonitoring { } #[cfg(test)] - pub fn test() -> Self { + pub(crate) fn test() -> Self { RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test") } - pub fn console_application_name(&self) -> String { + pub(crate) fn console_application_name(&self) -> String { let this = self.0.try_lock().expect("should not deadlock"); format!( "{}/{}", @@ -138,19 +139,19 @@ impl RequestMonitoring { ) } - pub fn set_rejected(&self, rejected: bool) { + pub(crate) fn set_rejected(&self, rejected: bool) { let mut this = self.0.try_lock().expect("should not deadlock"); this.rejected = Some(rejected); } - pub fn set_cold_start_info(&self, info: ColdStartInfo) { + pub(crate) fn set_cold_start_info(&self, info: ColdStartInfo) { self.0 .try_lock() .expect("should not deadlock") .set_cold_start_info(info); } - pub fn set_db_options(&self, options: StartupMessageParams) { + pub(crate) fn set_db_options(&self, options: StartupMessageParams) { let mut this = self.0.try_lock().expect("should not deadlock"); this.set_application(options.get("application_name").map(SmolStr::from)); if let Some(user) = options.get("user") { @@ -163,7 +164,7 @@ impl RequestMonitoring { this.pg_options = Some(options); } - pub fn set_project(&self, x: MetricsAuxInfo) { + pub(crate) fn set_project(&self, x: MetricsAuxInfo) { let mut this = self.0.try_lock().expect("should not deadlock"); if this.endpoint_id.is_none() { this.set_endpoint_id(x.endpoint_id.as_str().into()); @@ -173,33 +174,33 @@ impl RequestMonitoring { this.set_cold_start_info(x.cold_start_info); } - pub fn set_project_id(&self, project_id: ProjectIdInt) { + pub(crate) fn set_project_id(&self, project_id: ProjectIdInt) { let mut this = self.0.try_lock().expect("should not deadlock"); this.project = Some(project_id); } - pub fn set_endpoint_id(&self, endpoint_id: EndpointId) { + pub(crate) fn set_endpoint_id(&self, endpoint_id: EndpointId) { self.0 .try_lock() .expect("should not deadlock") .set_endpoint_id(endpoint_id); } - pub fn set_dbname(&self, dbname: DbName) { + pub(crate) fn set_dbname(&self, dbname: DbName) { self.0 .try_lock() .expect("should not deadlock") .set_dbname(dbname); } - pub fn set_user(&self, user: RoleName) { + pub(crate) fn set_user(&self, user: RoleName) { self.0 .try_lock() .expect("should not deadlock") .set_user(user); } - pub fn set_auth_method(&self, auth_method: AuthMethod) { + pub(crate) fn set_auth_method(&self, auth_method: AuthMethod) { let mut this = self.0.try_lock().expect("should not deadlock"); this.auth_method = Some(auth_method); } @@ -211,7 +212,7 @@ impl RequestMonitoring { .has_private_peer_addr() } - pub fn set_error_kind(&self, kind: ErrorKind) { + pub(crate) fn set_error_kind(&self, kind: ErrorKind) { let mut this = self.0.try_lock().expect("should not deadlock"); // Do not record errors from the private address to metrics. if !this.has_private_peer_addr() { @@ -237,30 +238,30 @@ impl RequestMonitoring { .log_connect(); } - pub fn protocol(&self) -> Protocol { + pub(crate) fn protocol(&self) -> Protocol { self.0.try_lock().expect("should not deadlock").protocol } - pub fn span(&self) -> Span { + pub(crate) fn span(&self) -> Span { self.0.try_lock().expect("should not deadlock").span.clone() } - pub fn session_id(&self) -> Uuid { + pub(crate) fn session_id(&self) -> Uuid { self.0.try_lock().expect("should not deadlock").session_id } - pub fn peer_addr(&self) -> IpAddr { + pub(crate) fn peer_addr(&self) -> IpAddr { self.0.try_lock().expect("should not deadlock").peer_addr } - pub fn cold_start_info(&self) -> ColdStartInfo { + pub(crate) fn cold_start_info(&self) -> ColdStartInfo { self.0 .try_lock() .expect("should not deadlock") .cold_start_info } - pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> { + pub(crate) fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> { LatencyTimerPause { ctx: self, start: tokio::time::Instant::now(), @@ -268,7 +269,7 @@ impl RequestMonitoring { } } - pub fn success(&self) { + pub(crate) fn success(&self) { self.0 .try_lock() .expect("should not deadlock") @@ -277,7 +278,7 @@ impl RequestMonitoring { } } -pub struct LatencyTimerPause<'a> { +pub(crate) struct LatencyTimerPause<'a> { ctx: &'a RequestMonitoring, start: tokio::time::Instant, waiting_for: Waiting, diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index e5962b35fa..c6f83fd069 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -62,8 +62,8 @@ pub struct ParquetUploadArgs { // But after FAILED_UPLOAD_WARN_THRESHOLD retries, we start to log it at WARN // level instead, as repeated failures can mean a more serious problem. If it // fails more than FAILED_UPLOAD_RETRIES times, we give up -pub const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3; -pub const FAILED_UPLOAD_MAX_RETRIES: u32 = 10; +pub(crate) const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3; +pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10; // the parquet crate leaves a lot to be desired... // what follows is an attempt to write parquet files with minimal allocs. @@ -73,7 +73,7 @@ pub const FAILED_UPLOAD_MAX_RETRIES: u32 = 10; // * after each rowgroup write, we check the length of the file and upload to s3 if large enough #[derive(parquet_derive::ParquetRecordWriter)] -pub struct RequestData { +pub(crate) struct RequestData { region: &'static str, protocol: &'static str, /// Must be UTC. The derive macro doesn't like the timezones @@ -613,40 +613,6 @@ mod tests { tmpdir.close().unwrap(); } - #[tokio::test] - async fn verify_parquet_min_compression() { - let tmpdir = camino_tempfile::tempdir().unwrap(); - - let config = ParquetConfig { - propeties: Arc::new( - WriterProperties::builder() - .set_compression(parquet::basic::Compression::ZSTD(ZstdLevel::default())) - .build(), - ), - rows_per_group: 2_000, - file_size: 1_000_000, - max_duration: time::Duration::from_secs(20 * 60), - test_remote_failures: 0, - }; - - let rx = random_stream(50_000); - let file_stats = run_test(tmpdir.path(), config, rx).await; - - // with compression, there are fewer files with more rows per file - assert_eq!( - file_stats, - [ - (1223214, 5, 10000), - (1229364, 5, 10000), - (1231158, 5, 10000), - (1230520, 5, 10000), - (1221798, 5, 10000) - ] - ); - - tmpdir.close().unwrap(); - } - #[tokio::test] async fn verify_parquet_strong_compression() { let tmpdir = camino_tempfile::tempdir().unwrap(); diff --git a/proxy/src/error.rs b/proxy/src/error.rs index fdfe50a494..53f9f75c5b 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -3,12 +3,12 @@ use std::{error::Error as StdError, fmt, io}; use measured::FixedCardinalityLabel; /// Upcast (almost) any error into an opaque [`io::Error`]. -pub fn io_error(e: impl Into>) -> io::Error { +pub(crate) fn io_error(e: impl Into>) -> io::Error { io::Error::new(io::ErrorKind::Other, e) } /// A small combinator for pluggable error logging. -pub fn log_error(e: E) -> E { +pub(crate) fn log_error(e: E) -> E { tracing::error!("{e}"); e } @@ -19,7 +19,7 @@ pub fn log_error(e: E) -> E { /// NOTE: This trait should not be implemented for [`anyhow::Error`], since it /// is way too convenient and tends to proliferate all across the codebase, /// ultimately leading to accidental leaks of sensitive data. -pub trait UserFacingError: ReportableError { +pub(crate) trait UserFacingError: ReportableError { /// Format the error for client, stripping all sensitive info. /// /// Although this might be a no-op for many types, it's highly @@ -64,7 +64,7 @@ pub enum ErrorKind { } impl ErrorKind { - pub fn to_metric_label(&self) -> &'static str { + pub(crate) fn to_metric_label(self) -> &'static str { match self { ErrorKind::User => "user", ErrorKind::ClientDisconnect => "clientdisconnect", @@ -78,7 +78,7 @@ impl ErrorKind { } } -pub trait ReportableError: fmt::Display + Send + 'static { +pub(crate) trait ReportableError: fmt::Display + Send + 'static { fn get_error_kind(&self) -> ErrorKind; } diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 1f1dd8c415..fee634f67f 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -12,9 +12,9 @@ use http_body_util::BodyExt; use hyper1::body::Body; use serde::de::DeserializeOwned; -pub use reqwest::{Request, Response, StatusCode}; -pub use reqwest_middleware::{ClientWithMiddleware, Error}; -pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +pub(crate) use reqwest::{Request, Response}; +pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error}; +pub(crate) use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use crate::{ metrics::{ConsoleRequest, Metrics}, @@ -35,7 +35,7 @@ pub fn new_client() -> ClientWithMiddleware { .build() } -pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware { +pub(crate) fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware { let timeout_client = reqwest::ClientBuilder::new() .timeout(default_timout) .build() @@ -77,20 +77,20 @@ impl Endpoint { } #[inline(always)] - pub fn url(&self) -> &ApiUrl { + pub(crate) fn url(&self) -> &ApiUrl { &self.endpoint } /// Return a [builder](RequestBuilder) for a `GET` request, /// appending a single `path` segment to the base endpoint URL. - pub fn get(&self, path: &str) -> RequestBuilder { + pub(crate) fn get(&self, path: &str) -> RequestBuilder { let mut url = self.endpoint.clone(); url.path_segments_mut().push(path); self.client.get(url.into_inner()) } /// Execute a [request](reqwest::Request). - pub async fn execute(&self, request: Request) -> Result { + pub(crate) async fn execute(&self, request: Request) -> Result { let _timer = Metrics::get() .proxy .console_request_latency @@ -102,7 +102,7 @@ impl Endpoint { } } -pub async fn parse_json_body_with_limit( +pub(crate) async fn parse_json_body_with_limit( mut b: impl Body + Unpin, limit: usize, ) -> anyhow::Result { diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index d418caa511..e5144cfe2e 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -29,10 +29,10 @@ impl std::fmt::Display for InternedString { } impl InternedString { - pub fn as_str(&self) -> &'static str { + pub(crate) fn as_str(&self) -> &'static str { Id::get_interner().inner.resolve(&self.inner) } - pub fn get(s: &str) -> Option { + pub(crate) fn get(s: &str) -> Option { Id::get_interner().get(s) } } @@ -78,7 +78,7 @@ impl serde::Serialize for InternedString { } impl StringInterner { - pub fn new() -> Self { + pub(crate) fn new() -> Self { StringInterner { inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher( Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()), @@ -90,26 +90,24 @@ impl StringInterner { } } - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - pub fn len(&self) -> usize { + #[cfg(test)] + fn len(&self) -> usize { self.inner.len() } - pub fn current_memory_usage(&self) -> usize { + #[cfg(test)] + fn current_memory_usage(&self) -> usize { self.inner.current_memory_usage() } - pub fn get_or_intern(&self, s: &str) -> InternedString { + pub(crate) fn get_or_intern(&self, s: &str) -> InternedString { InternedString { inner: self.inner.get_or_intern(s), _id: PhantomData, } } - pub fn get(&self, s: &str) -> Option> { + pub(crate) fn get(&self, s: &str) -> Option> { Some(InternedString { inner: self.inner.get(s)?, _id: PhantomData, @@ -132,14 +130,14 @@ impl Default for StringInterner { } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -pub struct RoleNameTag; +pub(crate) struct RoleNameTag; impl InternId for RoleNameTag { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } -pub type RoleNameInt = InternedString; +pub(crate) type RoleNameInt = InternedString; impl From<&RoleName> for RoleNameInt { fn from(value: &RoleName) -> Self { RoleNameTag::get_interner().get_or_intern(value) @@ -150,7 +148,7 @@ impl From<&RoleName> for RoleNameInt { pub struct EndpointIdTag; impl InternId for EndpointIdTag { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } @@ -170,7 +168,7 @@ impl From for EndpointIdInt { pub struct BranchIdTag; impl InternId for BranchIdTag { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } @@ -190,7 +188,7 @@ impl From for BranchIdInt { pub struct ProjectIdTag; impl InternId for ProjectIdTag { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } @@ -217,7 +215,7 @@ mod tests { struct MyId; impl InternId for MyId { fn get_interner() -> &'static StringInterner { - pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + pub(crate) static ROLE_NAMES: OnceLock> = OnceLock::new(); ROLE_NAMES.get_or_init(Default::default) } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 1e14ca59ec..8d7e586b3d 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -157,7 +157,8 @@ macro_rules! smol_str_wrapper { pub struct $name(smol_str::SmolStr); impl $name { - pub fn as_str(&self) -> &str { + #[allow(unused)] + pub(crate) fn as_str(&self) -> &str { self.0.as_str() } } @@ -252,19 +253,19 @@ smol_str_wrapper!(Host); // Endpoints are a bit tricky. Rare they might be branches or projects. impl EndpointId { - pub fn is_endpoint(&self) -> bool { + pub(crate) fn is_endpoint(&self) -> bool { self.0.starts_with("ep-") } - pub fn is_branch(&self) -> bool { + pub(crate) fn is_branch(&self) -> bool { self.0.starts_with("br-") } - pub fn is_project(&self) -> bool { - !self.is_endpoint() && !self.is_branch() - } - pub fn as_branch(&self) -> BranchId { + // pub(crate) fn is_project(&self) -> bool { + // !self.is_endpoint() && !self.is_branch() + // } + pub(crate) fn as_branch(&self) -> BranchId { BranchId(self.0.clone()) } - pub fn as_project(&self) -> ProjectId { + pub(crate) fn as_project(&self) -> ProjectId { ProjectId(self.0.clone()) } } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index ccef88231b..2da7eac580 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -4,8 +4,8 @@ use lasso::ThreadedRodeo; use measured::{ label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet}, metric::{histogram::Thresholds, name::MetricName}, - Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec, - LabelGroup, MetricGroup, + Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup, + MetricGroup, }; use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec}; @@ -548,6 +548,7 @@ pub enum RedisEventsCount { } pub struct ThreadPoolWorkers(usize); +#[derive(Copy, Clone)] pub struct ThreadPoolWorkerId(pub usize); impl LabelValue for ThreadPoolWorkerId { @@ -613,9 +614,6 @@ impl FixedCardinalitySet for ThreadPoolWorkers { #[derive(MetricGroup)] #[metric(new(workers: usize))] pub struct ThreadPoolMetrics { - pub injector_queue_depth: Gauge, - #[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))] - pub worker_queue_depth: GaugeVec, #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] pub worker_task_turns_total: CounterVec, #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] diff --git a/proxy/src/parse.rs b/proxy/src/parse.rs index 0d03574901..8c0f251066 100644 --- a/proxy/src/parse.rs +++ b/proxy/src/parse.rs @@ -2,14 +2,14 @@ use std::ffi::CStr; -pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { +pub(crate) fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { let cstr = CStr::from_bytes_until_nul(bytes).ok()?; let (_, other) = bytes.split_at(cstr.to_bytes_with_nul().len()); Some((cstr, other)) } /// See . -pub fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { +pub(crate) fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { (bytes.len() >= N).then(|| { let (head, tail) = bytes.split_at(N); (head.try_into().unwrap(), tail) diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 1dd4563514..17764f78d1 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -13,9 +13,9 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; pin_project! { /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough - pub struct ChainRW { + pub(crate) struct ChainRW { #[pin] - pub inner: T, + pub(crate) inner: T, buf: BytesMut, } } @@ -60,7 +60,7 @@ const HEADER: [u8; 12] = [ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, ]; -pub async fn read_proxy_protocol( +pub(crate) async fn read_proxy_protocol( mut read: T, ) -> std::io::Result<(ChainRW, Option)> { let mut buf = BytesMut::with_capacity(128); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index aa1025a29f..ff199ac701 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,12 +1,12 @@ #[cfg(test)] mod tests; -pub mod connect_compute; +pub(crate) mod connect_compute; mod copy_bidirectional; -pub mod handshake; -pub mod passthrough; -pub mod retry; -pub mod wake_compute; +pub(crate) mod handshake; +pub(crate) mod passthrough; +pub(crate) mod retry; +pub(crate) mod wake_compute; pub use copy_bidirectional::copy_bidirectional_client_compute; pub use copy_bidirectional::ErrorSource; @@ -170,21 +170,21 @@ pub async fn task_main( Ok(()) } -pub enum ClientMode { +pub(crate) enum ClientMode { Tcp, Websockets { hostname: Option }, } /// Abstracts the logic of handling TCP vs WS clients impl ClientMode { - pub fn allow_cleartext(&self) -> bool { + pub(crate) fn allow_cleartext(&self) -> bool { match self { ClientMode::Tcp => false, ClientMode::Websockets { .. } => true, } } - pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { + pub(crate) fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { match self { ClientMode::Tcp => config.allow_self_signed_compute, ClientMode::Websockets { .. } => false, @@ -213,7 +213,7 @@ impl ClientMode { // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation, // we cannot be sure the client even understands our error message // 3. PrepareClient: The client disconnected, so we can't tell them anyway... -pub enum ClientRequestError { +pub(crate) enum ClientRequestError { #[error("{0}")] Cancellation(#[from] cancellation::CancelError), #[error("{0}")] @@ -238,7 +238,7 @@ impl ReportableError for ClientRequestError { } } -pub async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, ctx: &RequestMonitoring, cancellation_handler: Arc, @@ -340,9 +340,9 @@ pub async fn handle_client( client: stream, aux: node.aux.clone(), compute: node, - req: request_gauge, - conn: conn_gauge, - cancel: session, + _req: request_gauge, + _conn: conn_gauge, + _cancel: session, })) } @@ -377,20 +377,20 @@ async fn prepare_client_connection

( } #[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct NeonOptions(Vec<(SmolStr, SmolStr)>); +pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>); impl NeonOptions { - pub fn parse_params(params: &StartupMessageParams) -> Self { + pub(crate) fn parse_params(params: &StartupMessageParams) -> Self { params .options_raw() .map(Self::parse_from_iter) .unwrap_or_default() } - pub fn parse_options_raw(options: &str) -> Self { + pub(crate) fn parse_options_raw(options: &str) -> Self { Self::parse_from_iter(StartupMessageParams::parse_options_raw(options)) } - pub fn is_ephemeral(&self) -> bool { + pub(crate) fn is_ephemeral(&self) -> bool { // Currently, neon endpoint options are all reserved for ephemeral endpoints. !self.0.is_empty() } @@ -404,7 +404,7 @@ impl NeonOptions { Self(options) } - pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { + pub(crate) fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { // prefix + format!(" {k}:{v}") // kinda jank because SmolStr is immutable std::iter::once(prefix) @@ -415,7 +415,7 @@ impl NeonOptions { /// DeepObject format /// `paramName[prop1]=value1¶mName[prop2]=value2&...` - pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> { + pub(crate) fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> { self.0 .iter() .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone())) @@ -423,7 +423,7 @@ impl NeonOptions { } } -pub fn neon_option(bytes: &str) -> Option<(&str, &str)> { +pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> { static RE: OnceCell = OnceCell::new(); let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap()); diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 6305dc204e..613548d4a0 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -25,7 +25,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. #[tracing::instrument(name = "invalidate_cache", skip_all)] -pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { +pub(crate) fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -41,7 +41,7 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { } #[async_trait] -pub trait ConnectMechanism { +pub(crate) trait ConnectMechanism { type Connection; type ConnectError: ReportableError; type Error: From; @@ -56,7 +56,7 @@ pub trait ConnectMechanism { } #[async_trait] -pub trait ComputeConnectBackend { +pub(crate) trait ComputeConnectBackend { async fn wake_compute( &self, ctx: &RequestMonitoring, @@ -65,12 +65,12 @@ pub trait ComputeConnectBackend { fn get_keys(&self) -> &ComputeCredentialKeys; } -pub struct TcpMechanism<'a> { +pub(crate) struct TcpMechanism<'a> { /// KV-dictionary with PostgreSQL connection params. - pub params: &'a StartupMessageParams, + pub(crate) params: &'a StartupMessageParams, /// connect_to_compute concurrency lock - pub locks: &'static ApiLocks, + pub(crate) locks: &'static ApiLocks, } #[async_trait] @@ -98,7 +98,7 @@ impl ConnectMechanism for TcpMechanism<'_> { /// Try to connect to the compute node, retrying if necessary. #[tracing::instrument(skip_all)] -pub async fn connect_to_compute( +pub(crate) async fn connect_to_compute( ctx: &RequestMonitoring, mechanism: &M, user_info: &B, diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index f8c8e8bc4b..4ebda013ac 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -14,7 +14,7 @@ enum TransferState { } #[derive(Debug)] -pub enum ErrorDirection { +pub(crate) enum ErrorDirection { Read(io::Error), Write(io::Error), } diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 27a72f8072..5996b11c11 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -18,7 +18,7 @@ use crate::{ }; #[derive(Error, Debug)] -pub enum HandshakeError { +pub(crate) enum HandshakeError { #[error("data is sent before server replied with EncryptionResponse")] EarlyData, @@ -57,7 +57,7 @@ impl ReportableError for HandshakeError { } } -pub enum HandshakeData { +pub(crate) enum HandshakeData { Startup(PqStream>, StartupMessageParams), Cancel(CancelKeyData), } @@ -67,7 +67,7 @@ pub enum HandshakeData { /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. #[tracing::instrument(skip_all)] -pub async fn handshake( +pub(crate) async fn handshake( ctx: &RequestMonitoring, stream: S, mut tls: Option<&TlsConfig>, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 9942fac383..c17108de0a 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -14,7 +14,7 @@ use super::copy_bidirectional::ErrorSource; /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(skip_all)] -pub async fn proxy_pass( +pub(crate) async fn proxy_pass( client: impl AsyncRead + AsyncWrite + Unpin, compute: impl AsyncRead + AsyncWrite + Unpin, aux: MetricsAuxInfo, @@ -57,18 +57,18 @@ pub async fn proxy_pass( Ok(()) } -pub struct ProxyPassthrough { - pub client: Stream, - pub compute: PostgresConnection, - pub aux: MetricsAuxInfo, +pub(crate) struct ProxyPassthrough { + pub(crate) client: Stream, + pub(crate) compute: PostgresConnection, + pub(crate) aux: MetricsAuxInfo, - pub req: NumConnectionRequestsGuard<'static>, - pub conn: NumClientConnectionsGuard<'static>, - pub cancel: cancellation::Session

, + pub(crate) _req: NumConnectionRequestsGuard<'static>, + pub(crate) _conn: NumClientConnectionsGuard<'static>, + pub(crate) _cancel: cancellation::Session

, } impl ProxyPassthrough { - pub async fn proxy_pass(self) -> Result<(), ErrorSource> { + pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { tracing::error!(?err, "could not cancel the query in the database"); diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 644b183a91..15895d37e6 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -2,18 +2,18 @@ use crate::{compute, config::RetryConfig}; use std::{error::Error, io}; use tokio::time; -pub trait CouldRetry { +pub(crate) trait CouldRetry { /// Returns true if the error could be retried fn could_retry(&self) -> bool; } -pub trait ShouldRetryWakeCompute { +pub(crate) trait ShouldRetryWakeCompute { /// Returns true if we need to invalidate the cache for this node. /// If false, we can continue retrying with the current node cache. fn should_retry_wake_compute(&self) -> bool; } -pub fn should_retry(err: &impl CouldRetry, num_retries: u32, config: RetryConfig) -> bool { +pub(crate) fn should_retry(err: &impl CouldRetry, num_retries: u32, config: RetryConfig) -> bool { num_retries < config.max_retries && err.could_retry() } @@ -101,7 +101,7 @@ impl ShouldRetryWakeCompute for compute::ConnectionError { } } -pub fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration { +pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration { config .base_delay .mul_f64(config.backoff_factor.powi((num_retries as i32) - 1)) diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 21c0641a7f..4264dbae0f 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -11,14 +11,14 @@ use crate::auth::backend::{ ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend, }; use crate::config::{CertResolver, RetryConfig}; -use crate::console::caches::NodeInfoCache; use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status}; -use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; +use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend, NodeInfoCache}; use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::error::ErrorKind; -use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId}; +use crate::{sasl, scram, BranchId, EndpointId, ProjectId}; use anyhow::{bail, Context}; use async_trait::async_trait; +use http::StatusCode; use retry::{retry_after, ShouldRetryWakeCompute}; use rstest::rstest; use rustls::pki_types; @@ -491,7 +491,7 @@ impl TestBackend for TestConnectMechanism { ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)), ConnectAction::WakeFail => { let err = console::errors::ApiError::Console(ConsoleError { - http_status_code: http::StatusCode::BAD_REQUEST, + http_status_code: StatusCode::BAD_REQUEST, error: "TEST".into(), status: None, }); @@ -500,7 +500,7 @@ impl TestBackend for TestConnectMechanism { } ConnectAction::WakeRetry => { let err = console::errors::ApiError::Console(ConsoleError { - http_status_code: http::StatusCode::BAD_REQUEST, + http_status_code: StatusCode::BAD_REQUEST, error: "TEST".into(), status: Some(Status { code: "error".into(), @@ -525,9 +525,6 @@ impl TestBackend for TestConnectMechanism { { unimplemented!("not used in tests") } - fn get_role_secret(&self) -> Result { - unimplemented!("not used in tests") - } } fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { @@ -547,8 +544,8 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> auth::BackendType<'static, ComputeCredentials, &()> { - let user_info = auth::BackendType::Console( +) -> auth::Backend<'static, ComputeCredentials, &()> { + let user_info = auth::Backend::Console( MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), ComputeCredentials { info: ComputeUserInfo { diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 71f07f4682..33a2162bc7 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -102,7 +102,7 @@ async fn proxy_mitm( } /// taken from tokio-postgres -pub async fn connect_tls(mut stream: S, tls: T) -> T::Stream +pub(crate) async fn connect_tls(mut stream: S, tls: T) -> T::Stream where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 5b06e8f054..9b8ac6d29d 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -12,7 +12,7 @@ use tracing::{error, info, warn}; use super::connect_compute::ComputeConnectBackend; -pub async fn wake_compute( +pub(crate) async fn wake_compute( num_retries: &mut u32, ctx: &RequestMonitoring, api: &B, diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index 222cd431d2..6e38f89458 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -1,10 +1,14 @@ +mod leaky_bucket; mod limit_algorithm; mod limiter; -pub use limit_algorithm::{ - aimd::Aimd, DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, -}; -pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; -mod leaky_bucket; -pub use leaky_bucket::{ - EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter, LeakyBucketState, + +#[cfg(test)] +pub(crate) use limit_algorithm::aimd::Aimd; + +pub(crate) use limit_algorithm::{ + DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, }; +pub(crate) use limiter::GlobalRateLimiter; + +pub use leaky_bucket::{EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter}; +pub use limiter::{BucketRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index f184e18f4c..bf4d85f2e4 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -8,6 +8,7 @@ use dashmap::DashMap; use rand::{thread_rng, Rng}; use tokio::time::Instant; use tracing::info; +use utils::leaky_bucket::LeakyBucketState; use crate::intern::EndpointIdInt; @@ -16,7 +17,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter; pub struct LeakyBucketRateLimiter { map: DashMap, - config: LeakyBucketConfig, + config: utils::leaky_bucket::LeakyBucketConfig, access_count: AtomicUsize, } @@ -29,25 +30,25 @@ impl LeakyBucketRateLimiter { pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self { Self { map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards), - config, + config: config.into(), access_count: AtomicUsize::new(0), } } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, n: u32) -> bool { let now = Instant::now(); if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 { self.do_gc(now); } - let mut entry = self.map.entry(key).or_insert_with(|| LeakyBucketState { - time: now, - filled: 0.0, - }); + let mut entry = self + .map + .entry(key) + .or_insert_with(|| LeakyBucketState { empty_at: now }); - entry.check(&self.config, now, n as f64) + entry.add_tokens(&self.config, now, n as f64).is_ok() } fn do_gc(&self, now: Instant) { @@ -59,7 +60,7 @@ impl LeakyBucketRateLimiter { let shard = thread_rng().gen_range(0..n); self.map.shards()[shard] .write() - .retain(|_, value| !value.get_mut().update(&self.config, now)); + .retain(|_, value| !value.get().bucket_is_empty(now)); } } @@ -68,53 +69,18 @@ pub struct LeakyBucketConfig { pub max: f64, } -pub struct LeakyBucketState { - filled: f64, - time: Instant, -} - +#[cfg(test)] impl LeakyBucketConfig { - pub fn new(rps: f64, max: f64) -> Self { + pub(crate) fn new(rps: f64, max: f64) -> Self { assert!(rps > 0.0, "rps must be positive"); assert!(max > 0.0, "max must be positive"); Self { rps, max } } } -impl LeakyBucketState { - pub fn new() -> Self { - Self { - filled: 0.0, - time: Instant::now(), - } - } - - /// updates the timer and returns true if the bucket is empty - fn update(&mut self, info: &LeakyBucketConfig, now: Instant) -> bool { - let drain = now.duration_since(self.time); - let drain = drain.as_secs_f64() * info.rps; - - self.filled = (self.filled - drain).clamp(0.0, info.max); - self.time = now; - - self.filled == 0.0 - } - - pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool { - self.update(info, now); - - if self.filled + n > info.max { - return false; - } - self.filled += n; - - true - } -} - -impl Default for LeakyBucketState { - fn default() -> Self { - Self::new() +impl From for utils::leaky_bucket::LeakyBucketConfig { + fn from(config: LeakyBucketConfig) -> Self { + utils::leaky_bucket::LeakyBucketConfig::new(config.rps, config.max) } } @@ -124,48 +90,55 @@ mod tests { use std::time::Duration; use tokio::time::Instant; + use utils::leaky_bucket::LeakyBucketState; - use super::{LeakyBucketConfig, LeakyBucketState}; + use super::LeakyBucketConfig; #[tokio::test(start_paused = true)] async fn check() { - let info = LeakyBucketConfig::new(500.0, 2000.0); - let mut bucket = LeakyBucketState::new(); + let config: utils::leaky_bucket::LeakyBucketConfig = + LeakyBucketConfig::new(500.0, 2000.0).into(); + assert_eq!(config.cost, Duration::from_millis(2)); + assert_eq!(config.bucket_width, Duration::from_secs(4)); + + let mut bucket = LeakyBucketState { + empty_at: Instant::now(), + }; // should work for 2000 requests this second for _ in 0..2000 { - assert!(bucket.check(&info, Instant::now(), 1.0)); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap(); } - assert!(!bucket.check(&info, Instant::now(), 1.0)); - assert_eq!(bucket.filled, 2000.0); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err(); + assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width); // in 1ms we should drain 0.5 tokens. // make sure we don't lose any tokens tokio::time::advance(Duration::from_millis(1)).await; - assert!(!bucket.check(&info, Instant::now(), 1.0)); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err(); tokio::time::advance(Duration::from_millis(1)).await; - assert!(bucket.check(&info, Instant::now(), 1.0)); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap(); // in 10ms we should drain 5 tokens tokio::time::advance(Duration::from_millis(10)).await; for _ in 0..5 { - assert!(bucket.check(&info, Instant::now(), 1.0)); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap(); } - assert!(!bucket.check(&info, Instant::now(), 1.0)); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err(); // in 10s we should drain 5000 tokens // but cap is only 2000 tokio::time::advance(Duration::from_secs(10)).await; for _ in 0..2000 { - assert!(bucket.check(&info, Instant::now(), 1.0)); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap(); } - assert!(!bucket.check(&info, Instant::now(), 1.0)); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err(); // should sustain 500rps for _ in 0..2000 { tokio::time::advance(Duration::from_millis(10)).await; for _ in 0..5 { - assert!(bucket.check(&info, Instant::now(), 1.0)); + bucket.add_tokens(&config, Instant::now(), 1.0).unwrap(); } } } diff --git a/proxy/src/rate_limiter/limit_algorithm.rs b/proxy/src/rate_limiter/limit_algorithm.rs index bc16837f65..25607b7e10 100644 --- a/proxy/src/rate_limiter/limit_algorithm.rs +++ b/proxy/src/rate_limiter/limit_algorithm.rs @@ -8,13 +8,13 @@ use tokio::{ use self::aimd::Aimd; -pub mod aimd; +pub(crate) mod aimd; /// Whether a job succeeded or failed as a result of congestion/overload. /// /// Errors not considered to be caused by overload should be ignored. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Outcome { +pub(crate) enum Outcome { /// The job succeeded, or failed in a way unrelated to overload. Success, /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal @@ -23,14 +23,14 @@ pub enum Outcome { } /// An algorithm for controlling a concurrency limit. -pub trait LimitAlgorithm: Send + Sync + 'static { +pub(crate) trait LimitAlgorithm: Send + Sync + 'static { /// Update the concurrency limit in response to a new job completion. fn update(&self, old_limit: usize, sample: Sample) -> usize; } /// The result of a job (or jobs), including the [`Outcome`] (loss) and latency (delay). #[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub struct Sample { +pub(crate) struct Sample { pub(crate) latency: Duration, /// Jobs in flight when the sample was taken. pub(crate) in_flight: usize, @@ -39,7 +39,7 @@ pub struct Sample { #[derive(Clone, Copy, Debug, Default, serde::Deserialize, PartialEq)] #[serde(rename_all = "snake_case")] -pub enum RateLimitAlgorithm { +pub(crate) enum RateLimitAlgorithm { #[default] Fixed, Aimd { @@ -48,7 +48,7 @@ pub enum RateLimitAlgorithm { }, } -pub struct Fixed; +pub(crate) struct Fixed; impl LimitAlgorithm for Fixed { fn update(&self, old_limit: usize, _sample: Sample) -> usize { @@ -59,12 +59,12 @@ impl LimitAlgorithm for Fixed { #[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)] pub struct RateLimiterConfig { #[serde(flatten)] - pub algorithm: RateLimitAlgorithm, - pub initial_limit: usize, + pub(crate) algorithm: RateLimitAlgorithm, + pub(crate) initial_limit: usize, } impl RateLimiterConfig { - pub fn create_rate_limit_algorithm(self) -> Box { + pub(crate) fn create_rate_limit_algorithm(self) -> Box { match self.algorithm { RateLimitAlgorithm::Fixed => Box::new(Fixed), RateLimitAlgorithm::Aimd { conf } => Box::new(conf), @@ -72,7 +72,7 @@ impl RateLimiterConfig { } } -pub struct LimiterInner { +pub(crate) struct LimiterInner { alg: Box, available: usize, limit: usize, @@ -114,7 +114,7 @@ impl LimiterInner { /// /// The limit will be automatically adjusted based on observed latency (delay) and/or failures /// caused by overload (loss). -pub struct DynamicLimiter { +pub(crate) struct DynamicLimiter { config: RateLimiterConfig, inner: Mutex, // to notify when a token is available @@ -124,7 +124,7 @@ pub struct DynamicLimiter { /// A concurrency token, required to run a job. /// /// Release the token back to the [`DynamicLimiter`] after the job is complete. -pub struct Token { +pub(crate) struct Token { start: Instant, limiter: Option>, } @@ -133,14 +133,14 @@ pub struct Token { /// /// Not guaranteed to be consistent under high concurrency. #[derive(Debug, Clone, Copy)] -pub struct LimiterState { +#[cfg(test)] +struct LimiterState { limit: usize, - in_flight: usize, } impl DynamicLimiter { /// Create a limiter with a given limit control algorithm. - pub fn new(config: RateLimiterConfig) -> Arc { + pub(crate) fn new(config: RateLimiterConfig) -> Arc { let ready = Notify::new(); ready.notify_one(); @@ -157,7 +157,10 @@ impl DynamicLimiter { } /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available. - pub async fn acquire_timeout(self: &Arc, duration: Duration) -> Result { + pub(crate) async fn acquire_timeout( + self: &Arc, + duration: Duration, + ) -> Result { tokio::time::timeout(duration, self.acquire()).await? } @@ -208,12 +211,10 @@ impl DynamicLimiter { } /// The current state of the limiter. - pub fn state(&self) -> LimiterState { + #[cfg(test)] + fn state(&self) -> LimiterState { let inner = self.inner.lock(); - LimiterState { - limit: inner.limit, - in_flight: inner.in_flight, - } + LimiterState { limit: inner.limit } } } @@ -224,22 +225,22 @@ impl Token { limiter: Some(limiter), } } - pub fn disabled() -> Self { + pub(crate) fn disabled() -> Self { Self { start: Instant::now(), limiter: None, } } - pub fn is_disabled(&self) -> bool { + pub(crate) fn is_disabled(&self) -> bool { self.limiter.is_none() } - pub fn release(mut self, outcome: Outcome) { + pub(crate) fn release(mut self, outcome: Outcome) { self.release_mut(Some(outcome)); } - pub fn release_mut(&mut self, outcome: Option) { + pub(crate) fn release_mut(&mut self, outcome: Option) { if let Some(limiter) = self.limiter.take() { limiter.release_inner(self.start, outcome); } @@ -252,13 +253,10 @@ impl Drop for Token { } } +#[cfg(test)] impl LimiterState { /// The current concurrency limit. - pub fn limit(&self) -> usize { + fn limit(self) -> usize { self.limit } - /// The number of jobs in flight. - pub fn in_flight(&self) -> usize { - self.in_flight - } } diff --git a/proxy/src/rate_limiter/limit_algorithm/aimd.rs b/proxy/src/rate_limiter/limit_algorithm/aimd.rs index d669492fa6..86b56e38fb 100644 --- a/proxy/src/rate_limiter/limit_algorithm/aimd.rs +++ b/proxy/src/rate_limiter/limit_algorithm/aimd.rs @@ -10,17 +10,17 @@ use super::{LimitAlgorithm, Outcome, Sample}; /// /// Reduces available concurrency by a factor when load-based errors are detected. #[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)] -pub struct Aimd { +pub(crate) struct Aimd { /// Minimum limit for AIMD algorithm. - pub min: usize, + pub(crate) min: usize, /// Maximum limit for AIMD algorithm. - pub max: usize, + pub(crate) max: usize, /// Decrease AIMD decrease by value in case of error. - pub dec: f32, + pub(crate) dec: f32, /// Increase AIMD increase by value in case of success. - pub inc: usize, + pub(crate) inc: usize, /// A threshold below which the limit won't be increased. - pub utilisation: f32, + pub(crate) utilisation: f32, } impl LimitAlgorithm for Aimd { diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 5db4efed37..be529f174d 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -17,13 +17,13 @@ use tracing::info; use crate::intern::EndpointIdInt; -pub struct GlobalRateLimiter { +pub(crate) struct GlobalRateLimiter { data: Vec, info: Vec, } impl GlobalRateLimiter { - pub fn new(info: Vec) -> Self { + pub(crate) fn new(info: Vec) -> Self { Self { data: vec![ RateBucket { @@ -37,7 +37,7 @@ impl GlobalRateLimiter { } /// Check that number of connections is below `max_rps` rps. - pub fn check(&mut self) -> bool { + pub(crate) fn check(&mut self) -> bool { let now = Instant::now(); let should_allow_request = self @@ -96,9 +96,9 @@ impl RateBucket { #[derive(Clone, Copy, PartialEq)] pub struct RateBucketInfo { - pub interval: Duration, + pub(crate) interval: Duration, // requests per interval - pub max_rpi: u32, + pub(crate) max_rpi: u32, } impl std::fmt::Display for RateBucketInfo { @@ -192,7 +192,7 @@ impl BucketRateLimiter { } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub fn check(&self, key: K, n: u32) -> bool { + pub(crate) fn check(&self, key: K, n: u32) -> bool { // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map. // worst case memory usage is about: // = 2 * 2048 * 64 * (48B + 72B) @@ -228,7 +228,7 @@ impl BucketRateLimiter { /// Clean the map. Simple strategy: remove all entries in a random shard. /// At worst, we'll double the effective max_rps during the cleanup. /// But that way deletion does not aquire mutex on each entry access. - pub fn do_gc(&self) { + pub(crate) fn do_gc(&self) { info!( "cleaning up bucket rate limiter, current size = {}", self.map.len() diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index c9a946fa4a..95bdfc0965 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -109,7 +109,7 @@ impl RedisPublisherClient { let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?; Ok(()) } - pub async fn try_connect(&mut self) -> anyhow::Result<()> { + pub(crate) async fn try_connect(&mut self) -> anyhow::Result<()> { match self.client.connect().await { Ok(()) => {} Err(e) => { diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index c78ee166f1..7d222e2dec 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -81,7 +81,7 @@ impl ConnectionWithCredentialsProvider { redis::cmd("PING").query_async(con).await } - pub async fn connect(&mut self) -> anyhow::Result<()> { + pub(crate) async fn connect(&mut self) -> anyhow::Result<()> { let _guard = self.mutex.lock().await; if let Some(con) = self.con.as_mut() { match Self::ping(con).await { @@ -149,7 +149,7 @@ impl ConnectionWithCredentialsProvider { // PubSub does not support credentials refresh. // Requires manual reconnection every 12h. - pub async fn get_async_pubsub(&self) -> anyhow::Result { + pub(crate) async fn get_async_pubsub(&self) -> anyhow::Result { Ok(self.get_client().await?.get_async_pubsub().await?) } @@ -187,7 +187,10 @@ impl ConnectionWithCredentialsProvider { } /// Sends an already encoded (packed) command into the TCP socket and /// reads the single response from it. - pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult { + pub(crate) async fn send_packed_command( + &mut self, + cmd: &redis::Cmd, + ) -> RedisResult { // Clone connection to avoid having to lock the ArcSwap in write mode let con = self.con.as_mut().ok_or(redis::RedisError::from(( redis::ErrorKind::IoError, @@ -199,7 +202,7 @@ impl ConnectionWithCredentialsProvider { /// Sends multiple already encoded (packed) command into the TCP socket /// and reads `count` responses from it. This is used to implement /// pipelining. - pub async fn send_packed_commands( + pub(crate) async fn send_packed_commands( &mut self, cmd: &redis::Pipeline, offset: usize, diff --git a/proxy/src/redis/elasticache.rs b/proxy/src/redis/elasticache.rs index eded8250af..d118c8f412 100644 --- a/proxy/src/redis/elasticache.rs +++ b/proxy/src/redis/elasticache.rs @@ -51,7 +51,7 @@ impl CredentialsProvider { credentials_provider, } } - pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { + pub(crate) async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { let aws_credentials = self .credentials_provider .provide_credentials() diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 31c0e62c2c..36a3443603 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -58,9 +58,9 @@ pub(crate) struct PasswordUpdate { } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] pub(crate) struct CancelSession { - pub region_id: Option, - pub cancel_key_data: CancelKeyData, - pub session_id: Uuid, + pub(crate) region_id: Option, + pub(crate) cancel_key_data: CancelKeyData, + pub(crate) session_id: Uuid, } fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result @@ -89,7 +89,7 @@ impl Clone for MessageHandler { } impl MessageHandler { - pub fn new( + pub(crate) fn new( cache: Arc, cancellation_handler: Arc>, region_id: String, @@ -100,10 +100,10 @@ impl MessageHandler { region_id, } } - pub async fn increment_active_listeners(&self) { + pub(crate) async fn increment_active_listeners(&self) { self.cache.increment_active_listeners().await; } - pub async fn decrement_active_listeners(&self) { + pub(crate) async fn decrement_active_listeners(&self) { self.cache.decrement_active_listeners().await; } #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))] diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index 60207fc824..0a36694359 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -14,13 +14,13 @@ use crate::error::{ReportableError, UserFacingError}; use std::io; use thiserror::Error; -pub use channel_binding::ChannelBinding; -pub use messages::FirstMessage; -pub use stream::{Outcome, SaslStream}; +pub(crate) use channel_binding::ChannelBinding; +pub(crate) use messages::FirstMessage; +pub(crate) use stream::{Outcome, SaslStream}; /// Fine-grained auth errors help in writing tests. #[derive(Error, Debug)] -pub enum Error { +pub(crate) enum Error { #[error("Channel binding failed: {0}")] ChannelBindingFailed(&'static str), @@ -64,11 +64,11 @@ impl ReportableError for Error { } /// A convenient result type for SASL exchange. -pub type Result = std::result::Result; +pub(crate) type Result = std::result::Result; /// A result of one SASL exchange. #[must_use] -pub enum Step { +pub(crate) enum Step { /// We should continue exchanging messages. Continue(T, String), /// The client has been authenticated successfully. @@ -78,7 +78,7 @@ pub enum Step { } /// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. -pub trait Mechanism: Sized { +pub(crate) trait Mechanism: Sized { /// What's produced as a result of successful authentication. type Output; diff --git a/proxy/src/sasl/channel_binding.rs b/proxy/src/sasl/channel_binding.rs index 6e2d3057ce..fdd011448e 100644 --- a/proxy/src/sasl/channel_binding.rs +++ b/proxy/src/sasl/channel_binding.rs @@ -2,7 +2,7 @@ /// Channel binding flag (possibly with params). #[derive(Debug, PartialEq, Eq)] -pub enum ChannelBinding { +pub(crate) enum ChannelBinding { /// Client doesn't support channel binding. NotSupportedClient, /// Client thinks server doesn't support channel binding. @@ -12,7 +12,10 @@ pub enum ChannelBinding { } impl ChannelBinding { - pub fn and_then(self, f: impl FnOnce(T) -> Result) -> Result, E> { + pub(crate) fn and_then( + self, + f: impl FnOnce(T) -> Result, + ) -> Result, E> { Ok(match self { Self::NotSupportedClient => ChannelBinding::NotSupportedClient, Self::NotSupportedServer => ChannelBinding::NotSupportedServer, @@ -23,7 +26,7 @@ impl ChannelBinding { impl<'a> ChannelBinding<&'a str> { // NB: FromStr doesn't work with lifetimes - pub fn parse(input: &'a str) -> Option { + pub(crate) fn parse(input: &'a str) -> Option { Some(match input { "n" => Self::NotSupportedClient, "y" => Self::NotSupportedServer, @@ -34,7 +37,7 @@ impl<'a> ChannelBinding<&'a str> { impl ChannelBinding { /// Encode channel binding data as base64 for subsequent checks. - pub fn encode<'a, E>( + pub(crate) fn encode<'a, E>( &self, get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>, ) -> Result, E> { diff --git a/proxy/src/sasl/messages.rs b/proxy/src/sasl/messages.rs index 2b5ae1785d..6c9a42b2db 100644 --- a/proxy/src/sasl/messages.rs +++ b/proxy/src/sasl/messages.rs @@ -5,16 +5,16 @@ use pq_proto::{BeAuthenticationSaslMessage, BeMessage}; /// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage). #[derive(Debug)] -pub struct FirstMessage<'a> { +pub(crate) struct FirstMessage<'a> { /// Authentication method, e.g. `"SCRAM-SHA-256"`. - pub method: &'a str, + pub(crate) method: &'a str, /// Initial client message. - pub message: &'a str, + pub(crate) message: &'a str, } impl<'a> FirstMessage<'a> { // NB: FromStr doesn't work with lifetimes - pub fn parse(bytes: &'a [u8]) -> Option { + pub(crate) fn parse(bytes: &'a [u8]) -> Option { let (method_cstr, tail) = split_cstr(bytes)?; let method = method_cstr.to_str().ok()?; diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index 9115b0f61a..b6becd28e1 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -7,7 +7,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; /// Abstracts away all peculiarities of the libpq's protocol. -pub struct SaslStream<'a, S> { +pub(crate) struct SaslStream<'a, S> { /// The underlying stream. stream: &'a mut PqStream, /// Current password message we received from client. @@ -17,7 +17,7 @@ pub struct SaslStream<'a, S> { } impl<'a, S> SaslStream<'a, S> { - pub fn new(stream: &'a mut PqStream, first: &'a str) -> Self { + pub(crate) fn new(stream: &'a mut PqStream, first: &'a str) -> Self { Self { stream, current: bytes::Bytes::new(), @@ -53,7 +53,7 @@ impl SaslStream<'_, S> { /// It's much easier to match on those two variants /// than to peek into a noisy protocol error type. #[must_use = "caller must explicitly check for success"] -pub enum Outcome { +pub(crate) enum Outcome { /// Authentication succeeded and produced some value. Success(R), /// Authentication failed (reason attached). @@ -63,7 +63,7 @@ pub enum Outcome { impl SaslStream<'_, S> { /// Perform SASL message exchange according to the underlying algorithm /// until user is either authenticated or denied access. - pub async fn authenticate( + pub(crate) async fn authenticate( mut self, mut mechanism: M, ) -> super::Result> { diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index 145e727a74..d058f1c3f8 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -15,9 +15,9 @@ mod secret; mod signature; pub mod threadpool; -pub use exchange::{exchange, Exchange}; -pub use key::ScramKey; -pub use secret::ServerSecret; +pub(crate) use exchange::{exchange, Exchange}; +pub(crate) use key::ScramKey; +pub(crate) use secret::ServerSecret; use hmac::{Hmac, Mac}; use sha2::{Digest, Sha256}; @@ -26,8 +26,8 @@ const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; /// A list of supported SCRAM methods. -pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256]; -pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256]; +pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256]; +pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256]; /// Decode base64 into array without any heap allocations fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N]> { diff --git a/proxy/src/scram/countmin.rs b/proxy/src/scram/countmin.rs index 944bb3c83e..64ee0135e1 100644 --- a/proxy/src/scram/countmin.rs +++ b/proxy/src/scram/countmin.rs @@ -2,7 +2,7 @@ use std::hash::Hash; /// estimator of hash jobs per second. /// -pub struct CountMinSketch { +pub(crate) struct CountMinSketch { // one for each depth hashers: Vec, width: usize, @@ -20,7 +20,7 @@ impl CountMinSketch { /// actual <= estimate /// estimate <= actual + ε * N with probability 1 - δ /// where N is the cardinality of the stream - pub fn with_params(epsilon: f64, delta: f64) -> Self { + pub(crate) fn with_params(epsilon: f64, delta: f64) -> Self { CountMinSketch::new( (std::f64::consts::E / epsilon).ceil() as usize, (1.0_f64 / delta).ln().ceil() as usize, @@ -49,7 +49,7 @@ impl CountMinSketch { } } - pub fn inc_and_return(&mut self, t: &T, x: u32) -> u32 { + pub(crate) fn inc_and_return(&mut self, t: &T, x: u32) -> u32 { let mut min = u32::MAX; for row in 0..self.depth { let col = (self.hashers[row].hash_one(t) as usize) % self.width; @@ -61,7 +61,7 @@ impl CountMinSketch { min } - pub fn reset(&mut self) { + pub(crate) fn reset(&mut self) { self.buckets.clear(); self.buckets.resize(self.width * self.depth, 0); } @@ -83,10 +83,10 @@ mod tests { let mut ids = vec![]; for _ in 0..n { - // number of insert operations - let n = rng.gen_range(1..100); // number to insert at once - let m = rng.gen_range(1..4096); + let n = rng.gen_range(1..4096); + // number of insert operations + let m = rng.gen_range(1..100); let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid(); ids.push((id, n, m)); @@ -102,17 +102,11 @@ mod tests { let mut ids2 = ids.clone(); while !ids2.is_empty() { ids2.shuffle(&mut rng); - - let mut i = 0; - while i < ids2.len() { - sketch.inc_and_return(&ids2[i].0, ids2[i].1); - ids2[i].2 -= 1; - if ids2[i].2 == 0 { - ids2.remove(i); - } else { - i += 1; - } - } + ids2.retain_mut(|id| { + sketch.inc_and_return(&id.0, id.1); + id.2 -= 1; + id.2 > 0 + }); } let mut within_p = 0; @@ -144,8 +138,8 @@ mod tests { // probably numbers are too small to truly represent the probabilities. assert_eq!(eval_precision(100, 4096.0, 0.90), 100); assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000); - assert_eq!(eval_precision(100, 4096.0, 0.1), 98); - assert_eq!(eval_precision(1000, 4096.0, 0.1), 991); + assert_eq!(eval_precision(100, 4096.0, 0.1), 96); + assert_eq!(eval_precision(1000, 4096.0, 0.1), 988); } // returns memory usage in bytes, and the time complexity per insert. diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index f2494379a5..786cbcaa19 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -56,14 +56,14 @@ enum ExchangeState { } /// Server's side of SCRAM auth algorithm. -pub struct Exchange<'a> { +pub(crate) struct Exchange<'a> { state: ExchangeState, secret: &'a ServerSecret, tls_server_end_point: config::TlsServerEndPoint, } impl<'a> Exchange<'a> { - pub fn new( + pub(crate) fn new( secret: &'a ServerSecret, nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], tls_server_end_point: config::TlsServerEndPoint, @@ -86,8 +86,7 @@ async fn derive_client_key( ) -> ScramKey { let salted_password = pool .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) - .await - .expect("job should not be cancelled"); + .await; let make_key = |name| { let key = Hmac::::new_from_slice(&salted_password) @@ -101,7 +100,7 @@ async fn derive_client_key( make_key(b"Client Key").into() } -pub async fn exchange( +pub(crate) async fn exchange( pool: &ThreadPool, endpoint: EndpointIdInt, secret: &ServerSecret, diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index 32a3dbd203..fe55ff493b 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -3,14 +3,14 @@ use subtle::ConstantTimeEq; /// Faithfully taken from PostgreSQL. -pub const SCRAM_KEY_LEN: usize = 32; +pub(crate) const SCRAM_KEY_LEN: usize = 32; /// One of the keys derived from the user's password. /// We use the same structure for all keys, i.e. /// `ClientKey`, `StoredKey`, and `ServerKey`. #[derive(Clone, Default, Eq, Debug)] #[repr(transparent)] -pub struct ScramKey { +pub(crate) struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], } @@ -27,11 +27,11 @@ impl ConstantTimeEq for ScramKey { } impl ScramKey { - pub fn sha256(&self) -> Self { + pub(crate) fn sha256(&self) -> Self { super::sha256([self.as_ref()]).into() } - pub fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { + pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { self.bytes } } diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs index 54157e450d..fd9e77764c 100644 --- a/proxy/src/scram/messages.rs +++ b/proxy/src/scram/messages.rs @@ -8,7 +8,7 @@ use std::fmt; use std::ops::Range; /// Faithfully taken from PostgreSQL. -pub const SCRAM_RAW_NONCE_LEN: usize = 18; +pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18; /// Although we ignore all extensions, we still have to validate the message. fn validate_sasl_extensions<'a>(parts: impl Iterator) -> Option<()> { @@ -27,18 +27,18 @@ fn validate_sasl_extensions<'a>(parts: impl Iterator) -> Option< } #[derive(Debug)] -pub struct ClientFirstMessage<'a> { +pub(crate) struct ClientFirstMessage<'a> { /// `client-first-message-bare`. - pub bare: &'a str, + pub(crate) bare: &'a str, /// Channel binding mode. - pub cbind_flag: ChannelBinding<&'a str>, + pub(crate) cbind_flag: ChannelBinding<&'a str>, /// Client nonce. - pub nonce: &'a str, + pub(crate) nonce: &'a str, } impl<'a> ClientFirstMessage<'a> { // NB: FromStr doesn't work with lifetimes - pub fn parse(input: &'a str) -> Option { + pub(crate) fn parse(input: &'a str) -> Option { let mut parts = input.split(','); let cbind_flag = ChannelBinding::parse(parts.next()?)?; @@ -77,7 +77,7 @@ impl<'a> ClientFirstMessage<'a> { } /// Build a response to [`ClientFirstMessage`]. - pub fn build_server_first_message( + pub(crate) fn build_server_first_message( &self, nonce: &[u8; SCRAM_RAW_NONCE_LEN], salt_base64: &str, @@ -101,20 +101,20 @@ impl<'a> ClientFirstMessage<'a> { } #[derive(Debug)] -pub struct ClientFinalMessage<'a> { +pub(crate) struct ClientFinalMessage<'a> { /// `client-final-message-without-proof`. - pub without_proof: &'a str, + pub(crate) without_proof: &'a str, /// Channel binding data (base64). - pub channel_binding: &'a str, + pub(crate) channel_binding: &'a str, /// Combined client & server nonce. - pub nonce: &'a str, + pub(crate) nonce: &'a str, /// Client auth proof. - pub proof: [u8; SCRAM_KEY_LEN], + pub(crate) proof: [u8; SCRAM_KEY_LEN], } impl<'a> ClientFinalMessage<'a> { // NB: FromStr doesn't work with lifetimes - pub fn parse(input: &'a str) -> Option { + pub(crate) fn parse(input: &'a str) -> Option { let (without_proof, proof) = input.rsplit_once(',')?; let mut parts = without_proof.split(','); @@ -135,7 +135,7 @@ impl<'a> ClientFinalMessage<'a> { } /// Build a response to [`ClientFinalMessage`]. - pub fn build_server_final_message( + pub(crate) fn build_server_final_message( &self, signature_builder: SignatureBuilder<'_>, server_key: &ScramKey, @@ -153,7 +153,7 @@ impl<'a> ClientFinalMessage<'a> { /// We need to keep a convenient representation of this /// message for the next authentication step. -pub struct OwnedServerFirstMessage { +pub(crate) struct OwnedServerFirstMessage { /// Owned `server-first-message`. message: String, /// Slice into `message`. @@ -163,13 +163,13 @@ pub struct OwnedServerFirstMessage { impl OwnedServerFirstMessage { /// Extract combined nonce from the message. #[inline(always)] - pub fn nonce(&self) -> &str { + pub(crate) fn nonce(&self) -> &str { &self.message[self.nonce.clone()] } /// Get reference to a text representation of the message. #[inline(always)] - pub fn as_str(&self) -> &str { + pub(crate) fn as_str(&self) -> &str { &self.message } } diff --git a/proxy/src/scram/pbkdf2.rs b/proxy/src/scram/pbkdf2.rs index f690cc7738..4cf76c8452 100644 --- a/proxy/src/scram/pbkdf2.rs +++ b/proxy/src/scram/pbkdf2.rs @@ -4,7 +4,7 @@ use hmac::{ }; use sha2::Sha256; -pub struct Pbkdf2 { +pub(crate) struct Pbkdf2 { hmac: Hmac, prev: GenericArray, hi: GenericArray, @@ -13,7 +13,7 @@ pub struct Pbkdf2 { // inspired from impl Pbkdf2 { - pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { + pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { let hmac = Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); @@ -33,11 +33,11 @@ impl Pbkdf2 { } } - pub fn cost(&self) -> u32 { + pub(crate) fn cost(&self) -> u32 { (self.iterations).clamp(0, 4096) } - pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> { + pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> { let Self { hmac, prev, @@ -75,7 +75,7 @@ mod tests { let salt = b"sodium chloride"; let pass = b"Ne0n_!5_50_C007"; - let mut job = Pbkdf2::start(pass, salt, 600000); + let mut job = Pbkdf2::start(pass, salt, 60000); let hash = loop { let std::task::Poll::Ready(hash) = job.turn() else { continue; @@ -83,7 +83,7 @@ mod tests { break hash; }; - let expected = pbkdf2_hmac_array::(pass, salt, 600000); + let expected = pbkdf2_hmac_array::(pass, salt, 60000); assert_eq!(hash, expected); } } diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index a08cb943c3..8c6a08d432 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -8,22 +8,22 @@ use super::key::ScramKey; /// Server secret is produced from user's password, /// and is used throughout the authentication process. #[derive(Clone, Eq, PartialEq, Debug)] -pub struct ServerSecret { +pub(crate) struct ServerSecret { /// Number of iterations for `PBKDF2` function. - pub iterations: u32, + pub(crate) iterations: u32, /// Salt used to hash user's password. - pub salt_base64: String, + pub(crate) salt_base64: String, /// Hashed `ClientKey`. - pub stored_key: ScramKey, + pub(crate) stored_key: ScramKey, /// Used by client to verify server's signature. - pub server_key: ScramKey, + pub(crate) server_key: ScramKey, /// Should auth fail no matter what? /// This is exactly the case for mocked secrets. - pub doomed: bool, + pub(crate) doomed: bool, } impl ServerSecret { - pub fn parse(input: &str) -> Option { + pub(crate) fn parse(input: &str) -> Option { // SCRAM-SHA-256$:$: let s = input.strip_prefix("SCRAM-SHA-256$")?; let (params, keys) = s.split_once('$')?; @@ -42,7 +42,7 @@ impl ServerSecret { Some(secret) } - pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice { + pub(crate) fn is_password_invalid(&self, client_key: &ScramKey) -> Choice { // constant time to not leak partial key match client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8) } @@ -50,7 +50,7 @@ impl ServerSecret { /// To avoid revealing information to an attacker, we use a /// mocked server secret even if the user doesn't exist. /// See `auth-scram.c : mock_scram_secret` for details. - pub fn mock(nonce: [u8; 32]) -> Self { + pub(crate) fn mock(nonce: [u8; 32]) -> Self { Self { // this doesn't reveal much information as we're going to use // iteration count 1 for our generated passwords going forward. @@ -66,7 +66,7 @@ impl ServerSecret { /// Build a new server secret from the prerequisites. /// XXX: We only use this function in tests. #[cfg(test)] - pub async fn build(password: &str) -> Option { + pub(crate) async fn build(password: &str) -> Option { Self::parse(&postgres_protocol::password::scram_sha_256(password.as_bytes()).await) } } diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs index 1c2811d757..d3255cf2ca 100644 --- a/proxy/src/scram/signature.rs +++ b/proxy/src/scram/signature.rs @@ -4,14 +4,14 @@ use super::key::{ScramKey, SCRAM_KEY_LEN}; /// A collection of message parts needed to derive the client's signature. #[derive(Debug)] -pub struct SignatureBuilder<'a> { - pub client_first_message_bare: &'a str, - pub server_first_message: &'a str, - pub client_final_message_without_proof: &'a str, +pub(crate) struct SignatureBuilder<'a> { + pub(crate) client_first_message_bare: &'a str, + pub(crate) server_first_message: &'a str, + pub(crate) client_final_message_without_proof: &'a str, } impl SignatureBuilder<'_> { - pub fn build(&self, key: &ScramKey) -> Signature { + pub(crate) fn build(&self, key: &ScramKey) -> Signature { let parts = [ self.client_first_message_bare.as_bytes(), b",", @@ -28,13 +28,13 @@ impl SignatureBuilder<'_> { /// produces `ClientKey` that we need for authentication. #[derive(Debug)] #[repr(transparent)] -pub struct Signature { +pub(crate) struct Signature { bytes: [u8; SCRAM_KEY_LEN], } impl Signature { /// Derive `ClientKey` from client's signature and proof. - pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey { + pub(crate) fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey { // This is how the proof is calculated: // // 1. sha256(ClientKey) -> StoredKey diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index 8fbaecf93d..2702aeebfe 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -4,17 +4,20 @@ //! 1. Fairness per endpoint. //! 2. Yield support for high iteration counts. -use std::sync::{ - atomic::{AtomicU64, Ordering}, - Arc, +use std::{ + cell::RefCell, + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Weak, + }, + task::{Context, Poll}, }; -use crossbeam_deque::{Injector, Stealer, Worker}; -use itertools::Itertools; -use parking_lot::{Condvar, Mutex}; +use futures::FutureExt; use rand::Rng; use rand::{rngs::SmallRng, SeedableRng}; -use tokio::sync::oneshot; use crate::{ intern::EndpointIdInt, @@ -25,273 +28,164 @@ use crate::{ use super::pbkdf2::Pbkdf2; pub struct ThreadPool { - queue: Injector, - stealers: Vec>, - parkers: Vec<(Condvar, Mutex)>, - /// bitpacked representation. - /// lower 8 bits = number of sleeping threads - /// next 8 bits = number of idle threads (searching for work) - counters: AtomicU64, - + runtime: Option, pub metrics: Arc, } -#[derive(PartialEq)] -enum ThreadState { - Parked, - Active, +/// How often to reset the sketch values +const SKETCH_RESET_INTERVAL: u64 = 1021; + +thread_local! { + static STATE: RefCell> = const { RefCell::new(None) }; } impl ThreadPool { pub fn new(n_workers: u8) -> Arc { - let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec(); - let stealers = workers.iter().map(|w| w.stealer()).collect_vec(); + // rayon would be nice here, but yielding in rayon does not work well afaict. - let parkers = (0..n_workers) - .map(|_| (Condvar::new(), Mutex::new(ThreadState::Active))) - .collect_vec(); + Arc::new_cyclic(|pool| { + let pool = pool.clone(); + let worker_id = AtomicUsize::new(0); - let pool = Arc::new(Self { - queue: Injector::new(), - stealers, - parkers, - // threads start searching for work - counters: AtomicU64::new((n_workers as u64) << 8), - metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)), - }); + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(n_workers as usize) + .on_thread_start(move || { + STATE.with_borrow_mut(|state| { + *state = Some(ThreadRt { + pool: pool.clone(), + id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)), + rng: SmallRng::from_entropy(), + // used to determine whether we should temporarily skip tasks for fairness. + // 99% of estimates will overcount by no more than 4096 samples + countmin: CountMinSketch::with_params( + 1.0 / (SKETCH_RESET_INTERVAL as f64), + 0.01, + ), + tick: 0, + }); + }); + }) + .build() + .unwrap(); - for (i, worker) in workers.into_iter().enumerate() { - let pool = Arc::clone(&pool); - std::thread::spawn(move || thread_rt(pool, worker, i)); - } - - pool + Self { + runtime: Some(runtime), + metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)), + } + }) } - pub fn spawn_job( - &self, - endpoint: EndpointIdInt, - pbkdf2: Pbkdf2, - ) -> oneshot::Receiver<[u8; 32]> { - let (tx, rx) = oneshot::channel(); - - let queue_was_empty = self.queue.is_empty(); - - self.metrics.injector_queue_depth.inc(); - self.queue.push(JobSpec { - response: tx, - pbkdf2, - endpoint, - }); - - // inspired from - let counts = self.counters.load(Ordering::SeqCst); - let num_awake_but_idle = (counts >> 8) & 0xff; - let num_sleepers = counts & 0xff; - - // If the queue is non-empty, then we always wake up a worker - // -- clearly the existing idle jobs aren't enough. Otherwise, - // check to see if we have enough idle workers. - if !queue_was_empty || num_awake_but_idle == 0 { - let num_to_wake = Ord::min(1, num_sleepers); - self.wake_any_threads(num_to_wake); - } - - rx - } - - #[cold] - fn wake_any_threads(&self, mut num_to_wake: u64) { - if num_to_wake > 0 { - for i in 0..self.parkers.len() { - if self.wake_specific_thread(i) { - num_to_wake -= 1; - if num_to_wake == 0 { - return; - } - } - } - } - } - - fn wake_specific_thread(&self, index: usize) -> bool { - let (condvar, lock) = &self.parkers[index]; - - let mut state = lock.lock(); - if *state == ThreadState::Parked { - condvar.notify_one(); - - // When the thread went to sleep, it will have incremented - // this value. When we wake it, its our job to decrement - // it. We could have the thread do it, but that would - // introduce a delay between when the thread was - // *notified* and when this counter was decremented. That - // might mislead people with new work into thinking that - // there are sleeping threads that they should try to - // wake, when in fact there is nothing left for them to - // do. - self.counters.fetch_sub(1, Ordering::SeqCst); - *state = ThreadState::Active; - - true - } else { - false - } - } - - fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker) -> Option { - // announce thread as idle - self.counters.fetch_add(256, Ordering::SeqCst); - - // try steal from the global queue - loop { - match self.queue.steal_batch_and_pop(worker) { - crossbeam_deque::Steal::Success(job) => { - self.metrics - .injector_queue_depth - .set(self.queue.len() as i64); - // no longer idle - self.counters.fetch_sub(256, Ordering::SeqCst); - return Some(job); - } - crossbeam_deque::Steal::Retry => continue, - crossbeam_deque::Steal::Empty => break, - } - } - - // try steal from our neighbours - loop { - let mut retry = false; - let start = rng.gen_range(0..self.stealers.len()); - let job = (start..self.stealers.len()) - .chain(0..start) - .filter(|i| *i != skip) - .find_map( - |victim| match self.stealers[victim].steal_batch_and_pop(worker) { - crossbeam_deque::Steal::Success(job) => Some(job), - crossbeam_deque::Steal::Empty => None, - crossbeam_deque::Steal::Retry => { - retry = true; - None - } - }, - ); - if job.is_some() { - // no longer idle - self.counters.fetch_sub(256, Ordering::SeqCst); - return job; - } - if !retry { - return None; - } - } + pub(crate) fn spawn_job(&self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2) -> JobHandle { + JobHandle( + self.runtime + .as_ref() + .unwrap() + .spawn(JobSpec { pbkdf2, endpoint }), + ) } } -fn thread_rt(pool: Arc, worker: Worker, index: usize) { - /// interval when we should steal from the global queue - /// so that tail latencies are managed appropriately - const STEAL_INTERVAL: usize = 61; +impl Drop for ThreadPool { + fn drop(&mut self) { + self.runtime.take().unwrap().shutdown_background(); + } +} - /// How often to reset the sketch values - const SKETCH_RESET_INTERVAL: usize = 1021; +struct ThreadRt { + pool: Weak, + id: ThreadPoolWorkerId, + rng: SmallRng, + countmin: CountMinSketch, + tick: u64, +} - let mut rng = SmallRng::from_entropy(); +impl ThreadRt { + fn should_run(&mut self, job: &JobSpec) -> bool { + let rate = self + .countmin + .inc_and_return(&job.endpoint, job.pbkdf2.cost()); - // used to determine whether we should temporarily skip tasks for fairness. - // 99% of estimates will overcount by no more than 4096 samples - let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01); - - let (condvar, lock) = &pool.parkers[index]; - - 'wait: loop { - // wait for notification of work - { - let mut lock = lock.lock(); - - // queue is empty - pool.metrics - .worker_queue_depth - .set(ThreadPoolWorkerId(index), 0); - - // subtract 1 from idle count, add 1 to sleeping count. - pool.counters.fetch_sub(255, Ordering::SeqCst); - - *lock = ThreadState::Parked; - condvar.wait(&mut lock); - } - - for i in 0.. { - let Some(mut job) = worker - .pop() - .or_else(|| pool.steal(&mut rng, index, &worker)) - else { - continue 'wait; - }; - - pool.metrics - .worker_queue_depth - .set(ThreadPoolWorkerId(index), worker.len() as i64); - - // receiver is closed, cancel the task - if !job.response.is_closed() { - let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost()); - - const P: f64 = 2000.0; - // probability decreases as rate increases. - // lower probability, higher chance of being skipped - // - // estimates (rate in terms of 4096 rounds): - // rate = 0 => probability = 100% - // rate = 10 => probability = 71.3% - // rate = 50 => probability = 62.1% - // rate = 500 => probability = 52.3% - // rate = 1021 => probability = 49.8% - // - // My expectation is that the pool queue will only begin backing up at ~1000rps - // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above - // are in requests per second. - let probability = P.ln() / (P + rate as f64).ln(); - if pool.queue.len() > 32 || rng.gen_bool(probability) { - pool.metrics - .worker_task_turns_total - .inc(ThreadPoolWorkerId(index)); - - match job.pbkdf2.turn() { - std::task::Poll::Ready(result) => { - let _ = job.response.send(result); - } - std::task::Poll::Pending => worker.push(job), - } - } else { - pool.metrics - .worker_task_skips_total - .inc(ThreadPoolWorkerId(index)); - - // skip for now - worker.push(job); - } - } - - // if we get stuck with a few long lived jobs in the queue - // it's better to try and steal from the queue too for fairness - if i % STEAL_INTERVAL == 0 { - let _ = pool.queue.steal_batch(&worker); - } - - if i % SKETCH_RESET_INTERVAL == 0 { - sketch.reset(); - } - } + const P: f64 = 2000.0; + // probability decreases as rate increases. + // lower probability, higher chance of being skipped + // + // estimates (rate in terms of 4096 rounds): + // rate = 0 => probability = 100% + // rate = 10 => probability = 71.3% + // rate = 50 => probability = 62.1% + // rate = 500 => probability = 52.3% + // rate = 1021 => probability = 49.8% + // + // My expectation is that the pool queue will only begin backing up at ~1000rps + // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above + // are in requests per second. + let probability = P.ln() / (P + rate as f64).ln(); + self.rng.gen_bool(probability) } } struct JobSpec { - response: oneshot::Sender<[u8; 32]>, pbkdf2: Pbkdf2, endpoint: EndpointIdInt, } +impl Future for JobSpec { + type Output = [u8; 32]; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + STATE.with_borrow_mut(|state| { + let state = state.as_mut().expect("should be set on thread startup"); + + state.tick = state.tick.wrapping_add(1); + if state.tick % SKETCH_RESET_INTERVAL == 0 { + state.countmin.reset(); + } + + if state.should_run(&self) { + if let Some(pool) = state.pool.upgrade() { + pool.metrics.worker_task_turns_total.inc(state.id); + } + + match self.pbkdf2.turn() { + Poll::Ready(result) => Poll::Ready(result), + // more to do, we shall requeue + Poll::Pending => { + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } else { + if let Some(pool) = state.pool.upgrade() { + pool.metrics.worker_task_skips_total.inc(state.id); + } + + cx.waker().wake_by_ref(); + Poll::Pending + } + }) + } +} + +pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>); + +impl Future for JobHandle { + type Output = [u8; 32]; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.0.poll_unpin(cx) { + Poll::Ready(Ok(ok)) => Poll::Ready(ok), + Poll::Ready(Err(err)) => std::panic::resume_unwind(err.into_panic()), + Poll::Pending => Poll::Pending, + } + } +} + +impl Drop for JobHandle { + fn drop(&mut self) { + self.0.abort(); + } +} + #[cfg(test)] mod tests { use crate::EndpointId; @@ -308,8 +202,7 @@ mod tests { let salt = [0x55; 32]; let actual = pool .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096)) - .await - .unwrap(); + .await; let expected = [ 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242, diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index d9a9019746..84f98cb8ad 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -25,8 +25,6 @@ use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder; use rand::rngs::StdRng; use rand::SeedableRng; -pub use reqwest_middleware::{ClientWithMiddleware, Error}; -pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::timeout; use tokio_rustls::TlsAcceptor; @@ -50,7 +48,7 @@ use tokio_util::sync::CancellationToken; use tracing::{error, info, warn, Instrument}; use utils::http::error::ApiError; -pub const SERVERLESS_DRIVER_SNI: &str = "api"; +pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub async fn task_main( config: &'static ProxyConfig, @@ -178,9 +176,9 @@ pub async fn task_main( Ok(()) } -pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {} +pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {} impl AsyncReadWrite for T {} -pub type AsyncRW = Pin>; +pub(crate) type AsyncRW = Pin>; #[async_trait] trait MaybeTlsAcceptor: Send + Sync + 'static { diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 9cc271c588..f24e0478be 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -29,14 +29,14 @@ use crate::{ use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool}; -pub struct PoolingBackend { - pub pool: Arc>, - pub config: &'static ProxyConfig, - pub endpoint_rate_limiter: Arc, +pub(crate) struct PoolingBackend { + pub(crate) pool: Arc>, + pub(crate) config: &'static ProxyConfig, + pub(crate) endpoint_rate_limiter: Arc, } impl PoolingBackend { - pub async fn authenticate_with_password( + pub(crate) async fn authenticate_with_password( &self, ctx: &RequestMonitoring, config: &AuthenticationConfig, @@ -98,20 +98,20 @@ impl PoolingBackend { }) } - pub async fn authenticate_with_jwt( + pub(crate) async fn authenticate_with_jwt( &self, ctx: &RequestMonitoring, user_info: &ComputeUserInfo, jwt: &str, ) -> Result { match &self.config.auth_backend { - crate::auth::BackendType::Console(_, ()) => { + crate::auth::Backend::Console(_, ()) => { Err(AuthError::auth_failed("JWT login is not yet supported")) } - crate::auth::BackendType::Link(_, ()) => Err(AuthError::auth_failed( - "JWT login over link proxy is not supported", + crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed( + "JWT login over web auth proxy is not supported", )), - crate::auth::BackendType::Local(cache) => { + crate::auth::Backend::Local(cache) => { cache .jwks_cache .check_jwt( @@ -135,7 +135,7 @@ impl PoolingBackend { // we reuse the code from the usual proxy and we need to prepare few structures // that this code expects. #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] - pub async fn connect_to_compute( + pub(crate) async fn connect_to_compute( &self, ctx: &RequestMonitoring, conn_info: ConnInfo, @@ -175,7 +175,7 @@ impl PoolingBackend { } #[derive(Debug, thiserror::Error)] -pub enum HttpConnError { +pub(crate) enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), #[error("could not connection to compute")] diff --git a/proxy/src/serverless/cancel_set.rs b/proxy/src/serverless/cancel_set.rs index 390df7f4f7..7659745473 100644 --- a/proxy/src/serverless/cancel_set.rs +++ b/proxy/src/serverless/cancel_set.rs @@ -22,7 +22,7 @@ pub struct CancelSet { hasher: Hasher, } -pub struct CancelShard { +pub(crate) struct CancelShard { tokens: IndexMap, } @@ -40,7 +40,7 @@ impl CancelSet { } } - pub fn take(&self) -> Option { + pub(crate) fn take(&self) -> Option { for _ in 0..4 { if let Some(token) = self.take_raw(thread_rng().gen()) { return Some(token); @@ -50,12 +50,12 @@ impl CancelSet { None } - pub fn take_raw(&self, rng: usize) -> Option { + pub(crate) fn take_raw(&self, rng: usize) -> Option { NonZeroUsize::new(self.shards.len()) .and_then(|len| self.shards[rng % len].lock().take(rng / len)) } - pub fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> { + pub(crate) fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> { let shard = NonZeroUsize::new(self.shards.len()).map(|len| { let hash = self.hasher.hash_one(id) as usize; let shard = &self.shards[hash % len]; @@ -88,7 +88,7 @@ impl CancelShard { } } -pub struct CancelGuard<'a> { +pub(crate) struct CancelGuard<'a> { shard: Option<&'a Mutex>, id: Uuid, } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 476083d71e..bea599e9b9 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -30,25 +30,25 @@ use tracing::{info, info_span, Instrument}; use super::backend::HttpConnError; #[derive(Debug, Clone)] -pub struct ConnInfo { - pub user_info: ComputeUserInfo, - pub dbname: DbName, - pub auth: AuthData, +pub(crate) struct ConnInfo { + pub(crate) user_info: ComputeUserInfo, + pub(crate) dbname: DbName, + pub(crate) auth: AuthData, } #[derive(Debug, Clone)] -pub enum AuthData { +pub(crate) enum AuthData { Password(SmallVec<[u8; 16]>), Jwt(String), } impl ConnInfo { // hm, change to hasher to avoid cloning? - pub fn db_and_user(&self) -> (DbName, RoleName) { + pub(crate) fn db_and_user(&self) -> (DbName, RoleName) { (self.dbname.clone(), self.user_info.user.clone()) } - pub fn endpoint_cache_key(&self) -> Option { + pub(crate) fn endpoint_cache_key(&self) -> Option { // We don't want to cache http connections for ephemeral endpoints. if self.user_info.options.is_ephemeral() { None @@ -79,7 +79,7 @@ struct ConnPoolEntry { // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. -pub struct EndpointConnPool { +pub(crate) struct EndpointConnPool { pools: HashMap<(DbName, RoleName), DbUserConnPool>, total_conns: usize, max_conns: usize, @@ -198,7 +198,7 @@ impl Drop for EndpointConnPool { } } -pub struct DbUserConnPool { +pub(crate) struct DbUserConnPool { conns: Vec>, } @@ -241,7 +241,7 @@ impl DbUserConnPool { } } -pub struct GlobalConnPool { +pub(crate) struct GlobalConnPool { // endpoint -> per-endpoint connection pool // // That should be a fairly conteded map, so return reference to the per-endpoint @@ -282,7 +282,7 @@ pub struct GlobalConnPoolOptions { } impl GlobalConnPool { - pub fn new(config: &'static crate::config::HttpConfig) -> Arc { + pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { let shards = config.pool_options.pool_shards; Arc::new(Self { global_pool: DashMap::with_shard_amount(shards), @@ -293,21 +293,21 @@ impl GlobalConnPool { } #[cfg(test)] - pub fn get_global_connections_count(&self) -> usize { + pub(crate) fn get_global_connections_count(&self) -> usize { self.global_connections_count .load(atomic::Ordering::Relaxed) } - pub fn get_idle_timeout(&self) -> Duration { + pub(crate) fn get_idle_timeout(&self) -> Duration { self.config.pool_options.idle_timeout } - pub fn shutdown(&self) { + pub(crate) fn shutdown(&self) { // drops all strong references to endpoint-pools self.global_pool.clear(); } - pub async fn gc_worker(&self, mut rng: impl Rng) { + pub(crate) async fn gc_worker(&self, mut rng: impl Rng) { let epoch = self.config.pool_options.gc_epoch; let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32); loop { @@ -381,7 +381,7 @@ impl GlobalConnPool { } } - pub fn get( + pub(crate) fn get( self: &Arc, ctx: &RequestMonitoring, conn_info: &ConnInfo, @@ -468,7 +468,7 @@ impl GlobalConnPool { } } -pub fn poll_client( +pub(crate) fn poll_client( global_pool: Arc>, ctx: &RequestMonitoring, conn_info: ConnInfo, @@ -596,7 +596,7 @@ impl Drop for ClientInner { } } -pub trait ClientInnerExt: Sync + Send + 'static { +pub(crate) trait ClientInnerExt: Sync + Send + 'static { fn is_closed(&self) -> bool; fn get_process_id(&self) -> i32; } @@ -611,13 +611,13 @@ impl ClientInnerExt for tokio_postgres::Client { } impl ClientInner { - pub fn is_closed(&self) -> bool { + pub(crate) fn is_closed(&self) -> bool { self.inner.is_closed() } } impl Client { - pub fn metrics(&self) -> Arc { + pub(crate) fn metrics(&self) -> Arc { let aux = &self.inner.as_ref().unwrap().aux; USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id, @@ -626,14 +626,14 @@ impl Client { } } -pub struct Client { +pub(crate) struct Client { span: Span, inner: Option>, conn_info: ConnInfo, pool: Weak>>, } -pub struct Discard<'a, C: ClientInnerExt> { +pub(crate) struct Discard<'a, C: ClientInnerExt> { conn_info: &'a ConnInfo, pool: &'a mut Weak>>, } @@ -651,7 +651,7 @@ impl Client { pool, } } - pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) { + pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) { let Self { inner, pool, @@ -664,13 +664,13 @@ impl Client { } impl Discard<'_, C> { - pub fn check_idle(&mut self, status: ReadyForQueryStatus) { + pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { let conn_info = &self.conn_info; if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { info!("pool: throwing away connection '{conn_info}' because connection is not idle"); } } - pub fn discard(&mut self) { + pub(crate) fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index 701ab58f63..abf0ffe290 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -11,7 +11,7 @@ use serde::Serialize; use utils::http::error::ApiError; /// Like [`ApiError::into_response`] -pub fn api_error_into_response(this: ApiError) -> Response> { +pub(crate) fn api_error_into_response(this: ApiError) -> Response> { match this { ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status( format!("{err:#?}"), // use debug printing so that we give the cause @@ -59,7 +59,7 @@ pub fn api_error_into_response(this: ApiError) -> Response> { /// Same as [`utils::http::error::HttpErrorBody`] #[derive(Serialize)] struct HttpErrorBody { - pub msg: String, + pub(crate) msg: String, } impl HttpErrorBody { @@ -80,7 +80,7 @@ impl HttpErrorBody { } /// Same as [`utils::http::json::json_response`] -pub fn json_response( +pub(crate) fn json_response( status: StatusCode, data: T, ) -> Result>, ApiError> { diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 3776971fa1..9f328a0e1d 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -8,7 +8,7 @@ use tokio_postgres::Row; // Convert json non-string types to strings, so that they can be passed to Postgres // as parameters. // -pub fn json_to_pg_text(json: Vec) -> Vec> { +pub(crate) fn json_to_pg_text(json: Vec) -> Vec> { json.iter().map(json_value_to_pg_text).collect() } @@ -61,7 +61,7 @@ fn json_array_to_pg_array(value: &Value) -> Option { } #[derive(Debug, thiserror::Error)] -pub enum JsonConversionError { +pub(crate) enum JsonConversionError { #[error("internal error compute returned invalid data: {0}")] AsTextError(tokio_postgres::Error), #[error("parse int error: {0}")] @@ -77,7 +77,7 @@ pub enum JsonConversionError { // // Convert postgres row with text-encoded values to JSON object // -pub fn pg_text_row_to_json( +pub(crate) fn pg_text_row_to_json( row: &Row, columns: &[Type], raw_output: bool, diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 9143469eea..5b36f5e91d 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -110,7 +110,7 @@ where } #[derive(Debug, thiserror::Error)] -pub enum ConnInfoError { +pub(crate) enum ConnInfoError { #[error("invalid header: {0}")] InvalidHeader(&'static HeaderName), #[error("invalid connection string: {0}")] @@ -246,7 +246,7 @@ fn get_conn_info( } // TODO: return different http error codes -pub async fn handle( +pub(crate) async fn handle( config: &'static ProxyConfig, ctx: RequestMonitoring, request: Request, @@ -359,7 +359,7 @@ pub async fn handle( } #[derive(Debug, thiserror::Error)] -pub enum SqlOverHttpError { +pub(crate) enum SqlOverHttpError { #[error("{0}")] ReadPayload(#[from] ReadPayloadError), #[error("{0}")] @@ -413,7 +413,7 @@ impl UserFacingError for SqlOverHttpError { } #[derive(Debug, thiserror::Error)] -pub enum ReadPayloadError { +pub(crate) enum ReadPayloadError { #[error("could not read the HTTP request body: {0}")] Read(#[from] hyper1::Error), #[error("could not parse the HTTP request body: {0}")] @@ -430,7 +430,7 @@ impl ReportableError for ReadPayloadError { } #[derive(Debug, thiserror::Error)] -pub enum SqlOverHttpCancel { +pub(crate) enum SqlOverHttpCancel { #[error("query was cancelled")] Postgres, #[error("query was cancelled while stuck trying to connect to the database")] diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 4fba4d141c..3d257223b8 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -27,7 +27,7 @@ use tracing::warn; pin_project! { /// This is a wrapper around a [`WebSocketStream`] that /// implements [`AsyncRead`] and [`AsyncWrite`]. - pub struct WebSocketRw { + pub(crate) struct WebSocketRw { #[pin] stream: WebSocketServer, recv: Bytes, @@ -36,7 +36,7 @@ pin_project! { } impl WebSocketRw { - pub fn new(stream: WebSocketServer) -> Self { + pub(crate) fn new(stream: WebSocketServer) -> Self { Self { stream, recv: Bytes::new(), @@ -127,7 +127,7 @@ impl AsyncBufRead for WebSocketRw { } } -pub async fn serve_websocket( +pub(crate) async fn serve_websocket( config: &'static ProxyConfig, ctx: RequestMonitoring, websocket: OnUpgrade, diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index ef13f5fc1a..332dc27787 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -35,7 +35,7 @@ impl PqStream { } /// Get a shared reference to the underlying stream. - pub fn get_ref(&self) -> &S { + pub(crate) fn get_ref(&self) -> &S { self.framed.get_ref() } } @@ -62,7 +62,7 @@ impl PqStream { .ok_or_else(err_connection) } - pub async fn read_password_message(&mut self) -> io::Result { + pub(crate) async fn read_password_message(&mut self) -> io::Result { match self.read_message().await? { FeMessage::PasswordMessage(msg) => Ok(msg), bad => Err(io::Error::new( @@ -99,7 +99,10 @@ impl ReportableError for ReportedError { impl PqStream { /// Write the message into an internal buffer, but don't flush the underlying stream. - pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> { + pub(crate) fn write_message_noflush( + &mut self, + message: &BeMessage<'_>, + ) -> io::Result<&mut Self> { self.framed .write_message(message) .map_err(ProtocolError::into_io_error)?; @@ -114,7 +117,7 @@ impl PqStream { } /// Flush the output buffer into the underlying stream. - pub async fn flush(&mut self) -> io::Result<&mut Self> { + pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> { self.framed.flush().await?; Ok(self) } @@ -146,7 +149,7 @@ impl PqStream { /// Write the error message using [`Self::write_message`], then re-throw it. /// Trait [`UserFacingError`] acts as an allowlist for error types. - pub async fn throw_error(&mut self, error: E) -> Result + pub(crate) async fn throw_error(&mut self, error: E) -> Result where E: UserFacingError + Into, { @@ -200,7 +203,7 @@ impl Stream { } } - pub fn tls_server_end_point(&self) -> TlsServerEndPoint { + pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint { match self { Stream::Raw { .. } => TlsServerEndPoint::Undefined, Stream::Tls { diff --git a/proxy/src/url.rs b/proxy/src/url.rs index 202fe8de1f..28ac7efdfc 100644 --- a/proxy/src/url.rs +++ b/proxy/src/url.rs @@ -7,12 +7,12 @@ pub struct ApiUrl(url::Url); impl ApiUrl { /// Consume the wrapper and return inner [url](url::Url). - pub fn into_inner(self) -> url::Url { + pub(crate) fn into_inner(self) -> url::Url { self.0 } /// See [`url::Url::path_segments_mut`]. - pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut<'_> { + pub(crate) fn path_segments_mut(&mut self) -> url::PathSegmentsMut<'_> { // We've already verified that it works during construction. self.0.path_segments_mut().expect("bad API url") } diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 4cf6da7e2d..aa8c7ba319 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -43,12 +43,12 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60); /// so while the project-id is unique across regions the whole pipeline will work correctly /// because we enrich the event with project_id in the control-plane endpoint. #[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)] -pub struct Ids { - pub endpoint_id: EndpointIdInt, - pub branch_id: BranchIdInt, +pub(crate) struct Ids { + pub(crate) endpoint_id: EndpointIdInt, + pub(crate) branch_id: BranchIdInt, } -pub trait MetricCounterRecorder { +pub(crate) trait MetricCounterRecorder { /// Record that some bytes were sent from the proxy to the client fn record_egress(&self, bytes: u64); /// Record that some connections were opened @@ -92,7 +92,7 @@ impl MetricCounterReporter for MetricBackupCounter { } #[derive(Debug)] -pub struct MetricCounter { +pub(crate) struct MetricCounter { transmitted: AtomicU64, opened_connections: AtomicUsize, backup: Arc, @@ -173,14 +173,14 @@ impl Clearable for C { type FastHasher = std::hash::BuildHasherDefault; #[derive(Default)] -pub struct Metrics { +pub(crate) struct Metrics { endpoints: DashMap, FastHasher>, backup_endpoints: DashMap, FastHasher>, } impl Metrics { /// Register a new byte metrics counter for this endpoint - pub fn register(&self, ids: Ids) -> Arc { + pub(crate) fn register(&self, ids: Ids) -> Arc { let backup = if let Some(entry) = self.backup_endpoints.get(&ids) { entry.clone() } else { @@ -215,7 +215,7 @@ impl Metrics { } } -pub static USAGE_METRICS: Lazy = Lazy::new(Metrics::default); +pub(crate) static USAGE_METRICS: Lazy = Lazy::new(Metrics::default); pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result { info!("metrics collector config: {config:?}"); diff --git a/proxy/src/waiters.rs b/proxy/src/waiters.rs index 9f78242ed3..86d0f9e8b2 100644 --- a/proxy/src/waiters.rs +++ b/proxy/src/waiters.rs @@ -7,13 +7,13 @@ use thiserror::Error; use tokio::sync::oneshot; #[derive(Debug, Error)] -pub enum RegisterError { +pub(crate) enum RegisterError { #[error("Waiter `{0}` already registered")] Occupied(String), } #[derive(Debug, Error)] -pub enum NotifyError { +pub(crate) enum NotifyError { #[error("Notify failed: waiter `{0}` not registered")] NotFound(String), @@ -22,12 +22,12 @@ pub enum NotifyError { } #[derive(Debug, Error)] -pub enum WaitError { +pub(crate) enum WaitError { #[error("Wait failed: channel hangup")] Hangup, } -pub struct Waiters(pub(self) Mutex>>); +pub(crate) struct Waiters(pub(self) Mutex>>); impl Default for Waiters { fn default() -> Self { @@ -36,7 +36,7 @@ impl Default for Waiters { } impl Waiters { - pub fn register(&self, key: String) -> Result, RegisterError> { + pub(crate) fn register(&self, key: String) -> Result, RegisterError> { let (tx, rx) = oneshot::channel(); self.0 @@ -53,7 +53,7 @@ impl Waiters { }) } - pub fn notify(&self, key: &str, value: T) -> Result<(), NotifyError> + pub(crate) fn notify(&self, key: &str, value: T) -> Result<(), NotifyError> where T: Send + Sync, { @@ -79,7 +79,7 @@ impl<'a, T> Drop for DropKey<'a, T> { } pin_project! { - pub struct Waiter<'a, T> { + pub(crate) struct Waiter<'a, T> { #[pin] receiver: oneshot::Receiver, guard: DropKey<'a, T>, diff --git a/safekeeper/src/control_file.rs b/safekeeper/src/control_file.rs index c551cd3122..8b252b4ab4 100644 --- a/safekeeper/src/control_file.rs +++ b/safekeeper/src/control_file.rs @@ -7,6 +7,7 @@ use tokio::fs::File; use tokio::io::AsyncWriteExt; use utils::crashsafe::durable_rename; +use std::future::Future; use std::io::Read; use std::ops::Deref; use std::path::Path; @@ -31,10 +32,9 @@ pub const CHECKSUM_SIZE: usize = size_of::(); /// Storage should keep actual state inside of it. It should implement Deref /// trait to access state fields and have persist method for updating that state. -#[async_trait::async_trait] pub trait Storage: Deref { /// Persist safekeeper state on disk and update internal state. - async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()>; + fn persist(&mut self, s: &TimelinePersistentState) -> impl Future> + Send; /// Timestamp of last persist. fn last_persist_at(&self) -> Instant; @@ -188,7 +188,6 @@ impl TimelinePersistentState { } } -#[async_trait::async_trait] impl Storage for FileStorage { /// Persists state durably to the underlying storage. /// diff --git a/safekeeper/src/http/openapi_spec.yaml b/safekeeper/src/http/openapi_spec.yaml index a617e0310c..70999853c2 100644 --- a/safekeeper/src/http/openapi_spec.yaml +++ b/safekeeper/src/http/openapi_spec.yaml @@ -86,42 +86,6 @@ paths: default: $ref: "#/components/responses/GenericError" - /v1/tenant/{tenant_id}/timeline/{source_timeline_id}/copy: - parameters: - - name: tenant_id - in: path - required: true - schema: - type: string - format: hex - - name: source_timeline_id - in: path - required: true - schema: - type: string - format: hex - - post: - tags: - - "Timeline" - summary: Register new timeline as copy of existing timeline - description: "" - operationId: v1CopyTenantTimeline - requestBody: - content: - application/json: - schema: - $ref: "#/components/schemas/TimelineCopyRequest" - responses: - "201": - description: Timeline created - # TODO: return timeline info? - "403": - $ref: "#/components/responses/ForbiddenError" - default: - $ref: "#/components/responses/GenericError" - - /v1/tenant/{tenant_id}/timeline/{timeline_id}: parameters: - name: tenant_id @@ -179,6 +143,40 @@ paths: default: $ref: "#/components/responses/GenericError" + /v1/tenant/{tenant_id}/timeline/{source_timeline_id}/copy: + parameters: + - name: tenant_id + in: path + required: true + schema: + type: string + format: hex + - name: source_timeline_id + in: path + required: true + schema: + type: string + format: hex + + post: + tags: + - "Timeline" + summary: Register new timeline as copy of existing timeline + description: "" + operationId: v1CopyTenantTimeline + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/TimelineCopyRequest" + responses: + "201": + description: Timeline created + # TODO: return timeline info? + "403": + $ref: "#/components/responses/ForbiddenError" + default: + $ref: "#/components/responses/GenericError" /v1/record_safekeeper_info/{tenant_id}/{timeline_id}: parameters: diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index d11815f6ef..91ffa95c21 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -114,7 +114,55 @@ fn check_permission(request: &Request, tenant_id: Option) -> Res }) } +/// Deactivates all timelines for the tenant and removes its data directory. +/// See `timeline_delete_handler`. +async fn tenant_delete_handler(mut request: Request) -> Result, ApiError> { + let tenant_id = parse_request_param(&request, "tenant_id")?; + let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false); + check_permission(&request, Some(tenant_id))?; + ensure_no_body(&mut request).await?; + // FIXME: `delete_force_all_for_tenant` can return an error for multiple different reasons; + // Using an `InternalServerError` should be fixed when the types support it + let delete_info = GlobalTimelines::delete_force_all_for_tenant(&tenant_id, only_local) + .await + .map_err(ApiError::InternalServerError)?; + json_response( + StatusCode::OK, + delete_info + .iter() + .map(|(ttid, resp)| (format!("{}", ttid.timeline_id), *resp)) + .collect::>(), + ) +} + +async fn timeline_create_handler(mut request: Request) -> Result, ApiError> { + let request_data: TimelineCreateRequest = json_request(&mut request).await?; + + let ttid = TenantTimelineId { + tenant_id: request_data.tenant_id, + timeline_id: request_data.timeline_id, + }; + check_permission(&request, Some(ttid.tenant_id))?; + + let server_info = ServerInfo { + pg_version: request_data.pg_version, + system_id: request_data.system_id.unwrap_or(0), + wal_seg_size: request_data.wal_seg_size.unwrap_or(WAL_SEGMENT_SIZE as u32), + }; + let local_start_lsn = request_data.local_start_lsn.unwrap_or_else(|| { + request_data + .commit_lsn + .segment_lsn(server_info.wal_seg_size as usize) + }); + GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn) + .await + .map_err(ApiError::InternalServerError)?; + + json_response(StatusCode::OK, ()) +} + /// List all (not deleted) timelines. +/// Note: it is possible to do the same with debug_dump. async fn timeline_list_handler(request: Request) -> Result, ApiError> { check_permission(&request, None)?; let res: Vec = GlobalTimelines::get_all() @@ -174,30 +222,21 @@ async fn timeline_status_handler(request: Request) -> Result) -> Result, ApiError> { - let request_data: TimelineCreateRequest = json_request(&mut request).await?; - - let ttid = TenantTimelineId { - tenant_id: request_data.tenant_id, - timeline_id: request_data.timeline_id, - }; +/// Deactivates the timeline and removes its data directory. +async fn timeline_delete_handler(mut request: Request) -> Result, ApiError> { + let ttid = TenantTimelineId::new( + parse_request_param(&request, "tenant_id")?, + parse_request_param(&request, "timeline_id")?, + ); + let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false); check_permission(&request, Some(ttid.tenant_id))?; - - let server_info = ServerInfo { - pg_version: request_data.pg_version, - system_id: request_data.system_id.unwrap_or(0), - wal_seg_size: request_data.wal_seg_size.unwrap_or(WAL_SEGMENT_SIZE as u32), - }; - let local_start_lsn = request_data.local_start_lsn.unwrap_or_else(|| { - request_data - .commit_lsn - .segment_lsn(server_info.wal_seg_size as usize) - }); - GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn) + ensure_no_body(&mut request).await?; + // FIXME: `delete_force` can fail from both internal errors and bad requests. Add better + // error handling here when we're able to. + let resp = GlobalTimelines::delete(&ttid, only_local) .await .map_err(ApiError::InternalServerError)?; - - json_response(StatusCode::OK, ()) + json_response(StatusCode::OK, resp) } /// Pull timeline from peer safekeeper instances. @@ -279,6 +318,46 @@ async fn timeline_copy_handler(mut request: Request) -> Result, +) -> Result, ApiError> { + check_permission(&request, None)?; + + let ttid = TenantTimelineId::new( + parse_request_param(&request, "tenant_id")?, + parse_request_param(&request, "timeline_id")?, + ); + + let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?; + + let patch_request: patch_control_file::Request = json_request(&mut request).await?; + let response = patch_control_file::handle_request(tli, patch_request) + .await + .map_err(ApiError::InternalServerError)?; + + json_response(StatusCode::OK, response) +} + +/// Force persist control file. +async fn timeline_checkpoint_handler(request: Request) -> Result, ApiError> { + check_permission(&request, None)?; + + let ttid = TenantTimelineId::new( + parse_request_param(&request, "tenant_id")?, + parse_request_param(&request, "timeline_id")?, + ); + + let tli = GlobalTimelines::get(ttid)?; + tli.write_shared_state() + .await + .sk + .state_mut() + .flush() + .await + .map_err(ApiError::InternalServerError)?; + json_response(StatusCode::OK, ()) +} + async fn timeline_digest_handler(request: Request) -> Result, ApiError> { let ttid = TenantTimelineId::new( parse_request_param(&request, "tenant_id")?, @@ -310,64 +389,6 @@ async fn timeline_digest_handler(request: Request) -> Result) -> Result, ApiError> { - check_permission(&request, None)?; - - let ttid = TenantTimelineId::new( - parse_request_param(&request, "tenant_id")?, - parse_request_param(&request, "timeline_id")?, - ); - - let tli = GlobalTimelines::get(ttid)?; - tli.write_shared_state() - .await - .sk - .state_mut() - .flush() - .await - .map_err(ApiError::InternalServerError)?; - json_response(StatusCode::OK, ()) -} - -/// Deactivates the timeline and removes its data directory. -async fn timeline_delete_handler(mut request: Request) -> Result, ApiError> { - let ttid = TenantTimelineId::new( - parse_request_param(&request, "tenant_id")?, - parse_request_param(&request, "timeline_id")?, - ); - let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false); - check_permission(&request, Some(ttid.tenant_id))?; - ensure_no_body(&mut request).await?; - // FIXME: `delete_force` can fail from both internal errors and bad requests. Add better - // error handling here when we're able to. - let resp = GlobalTimelines::delete(&ttid, only_local) - .await - .map_err(ApiError::InternalServerError)?; - json_response(StatusCode::OK, resp) -} - -/// Deactivates all timelines for the tenant and removes its data directory. -/// See `timeline_delete_handler`. -async fn tenant_delete_handler(mut request: Request) -> Result, ApiError> { - let tenant_id = parse_request_param(&request, "tenant_id")?; - let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false); - check_permission(&request, Some(tenant_id))?; - ensure_no_body(&mut request).await?; - // FIXME: `delete_force_all_for_tenant` can return an error for multiple different reasons; - // Using an `InternalServerError` should be fixed when the types support it - let delete_info = GlobalTimelines::delete_force_all_for_tenant(&tenant_id, only_local) - .await - .map_err(ApiError::InternalServerError)?; - json_response( - StatusCode::OK, - delete_info - .iter() - .map(|(ttid, resp)| (format!("{}", ttid.timeline_id), *resp)) - .collect::>(), - ) -} - /// Used only in tests to hand craft required data. async fn record_safekeeper_info(mut request: Request) -> Result, ApiError> { let ttid = TenantTimelineId::new( @@ -509,26 +530,6 @@ async fn dump_debug_handler(mut request: Request) -> Result Ok(response) } -async fn patch_control_file_handler( - mut request: Request, -) -> Result, ApiError> { - check_permission(&request, None)?; - - let ttid = TenantTimelineId::new( - parse_request_param(&request, "tenant_id")?, - parse_request_param(&request, "timeline_id")?, - ); - - let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?; - - let patch_request: patch_control_file::Request = json_request(&mut request).await?; - let response = patch_control_file::handle_request(tli, patch_request) - .await - .map_err(ApiError::InternalServerError)?; - - json_response(StatusCode::OK, response) -} - /// Safekeeper http router. pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder { let mut router = endpoint::make_router(); @@ -568,6 +569,9 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder failpoints_handler(r, cancel).await }) }) + .delete("/v1/tenant/:tenant_id", |r| { + request_span(r, tenant_delete_handler) + }) // Will be used in the future instead of implicit timeline creation .post("/v1/tenant/timeline", |r| { request_span(r, timeline_create_handler) @@ -581,16 +585,13 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder .delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| { request_span(r, timeline_delete_handler) }) - .delete("/v1/tenant/:tenant_id", |r| { - request_span(r, tenant_delete_handler) + .post("/v1/pull_timeline", |r| { + request_span(r, timeline_pull_handler) }) .get( "/v1/tenant/:tenant_id/timeline/:timeline_id/snapshot/:destination_id", |r| request_span(r, timeline_snapshot_handler), ) - .post("/v1/pull_timeline", |r| { - request_span(r, timeline_pull_handler) - }) .post( "/v1/tenant/:tenant_id/timeline/:source_timeline_id/copy", |r| request_span(r, timeline_copy_handler), @@ -603,14 +604,13 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder "/v1/tenant/:tenant_id/timeline/:timeline_id/checkpoint", |r| request_span(r, timeline_checkpoint_handler), ) - // for tests + .get("/v1/tenant/:tenant_id/timeline/:timeline_id/digest", |r| { + request_span(r, timeline_digest_handler) + }) .post("/v1/record_safekeeper_info/:tenant_id/:timeline_id", |r| { request_span(r, record_safekeeper_info) }) .get("/v1/debug_dump", |r| request_span(r, dump_debug_handler)) - .get("/v1/tenant/:tenant_id/timeline/:timeline_id/digest", |r| { - request_span(r, timeline_digest_handler) - }) } #[cfg(test)] diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index 0814d9ba67..486954c7b9 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -971,7 +971,6 @@ mod tests { persisted_state: TimelinePersistentState, } - #[async_trait::async_trait] impl control_file::Storage for InMemoryState { async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> { self.persisted_state = s.clone(); @@ -1003,7 +1002,6 @@ mod tests { lsn: Lsn, } - #[async_trait::async_trait] impl wal_storage::Storage for DummyWalStore { fn flush_lsn(&self) -> Lsn { self.lsn diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index ded8571a3e..6fd7c91a68 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -15,6 +15,7 @@ use postgres_ffi::v14::xlog_utils::{IsPartialXLogFileName, IsXLogFileName, XLogF use postgres_ffi::{dispatch_pgversion, XLogSegNo, PG_TLI}; use remote_storage::RemotePath; use std::cmp::{max, min}; +use std::future::Future; use std::io::{self, SeekFrom}; use std::pin::Pin; use tokio::fs::{self, remove_file, File, OpenOptions}; @@ -35,7 +36,6 @@ use postgres_ffi::XLOG_BLCKSZ; use pq_proto::SystemId; use utils::{id::TenantTimelineId, lsn::Lsn}; -#[async_trait::async_trait] pub trait Storage { /// LSN of last durably stored WAL record. fn flush_lsn(&self) -> Lsn; @@ -44,16 +44,19 @@ pub trait Storage { /// the segment and short header at the page of given LSN. This is only used /// for timeline initialization because compute will stream data only since /// init_lsn. Other segment headers are included in compute stream. - async fn initialize_first_segment(&mut self, init_lsn: Lsn) -> Result<()>; + fn initialize_first_segment( + &mut self, + init_lsn: Lsn, + ) -> impl Future> + Send; /// Write piece of WAL from buf to disk, but not necessarily sync it. - async fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()>; + fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> impl Future> + Send; /// Truncate WAL at specified LSN, which must be the end of WAL record. - async fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()>; + fn truncate_wal(&mut self, end_pos: Lsn) -> impl Future> + Send; /// Durably store WAL on disk, up to the last written WAL record. - async fn flush_wal(&mut self) -> Result<()>; + fn flush_wal(&mut self) -> impl Future> + Send; /// Remove all segments <= given segno. Returns function doing that as we /// want to perform it without timeline lock. @@ -325,7 +328,6 @@ impl PhysicalStorage { } } -#[async_trait::async_trait] impl Storage for PhysicalStorage { /// flush_lsn returns LSN of last durably stored WAL record. fn flush_lsn(&self) -> Lsn { diff --git a/safekeeper/tests/walproposer_sim/safekeeper_disk.rs b/safekeeper/tests/walproposer_sim/safekeeper_disk.rs index c2db9de78a..6b31edb1f2 100644 --- a/safekeeper/tests/walproposer_sim/safekeeper_disk.rs +++ b/safekeeper/tests/walproposer_sim/safekeeper_disk.rs @@ -83,7 +83,6 @@ impl DiskStateStorage { } } -#[async_trait::async_trait] impl control_file::Storage for DiskStateStorage { /// Persist safekeeper state on disk and update internal state. async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> { @@ -175,7 +174,6 @@ impl DiskWALStorage { } } -#[async_trait::async_trait] impl wal_storage::Storage for DiskWALStorage { /// LSN of last durably stored WAL record. fn flush_lsn(&self) -> Lsn { diff --git a/storage_controller/client/Cargo.toml b/storage_controller/client/Cargo.toml index c3bfe2bfd2..e7a4264fd0 100644 --- a/storage_controller/client/Cargo.toml +++ b/storage_controller/client/Cargo.toml @@ -8,7 +8,6 @@ license.workspace = true pageserver_api.workspace = true pageserver_client.workspace = true thiserror.workspace = true -async-trait.workspace = true reqwest.workspace = true utils.workspace = true serde.workspace = true diff --git a/storage_controller/migrations/2024-08-27-184400_pageserver_az/down.sql b/storage_controller/migrations/2024-08-27-184400_pageserver_az/down.sql new file mode 100644 index 0000000000..22df81c83c --- /dev/null +++ b/storage_controller/migrations/2024-08-27-184400_pageserver_az/down.sql @@ -0,0 +1 @@ +ALTER TABLE nodes DROP availability_zone_id; diff --git a/storage_controller/migrations/2024-08-27-184400_pageserver_az/up.sql b/storage_controller/migrations/2024-08-27-184400_pageserver_az/up.sql new file mode 100644 index 0000000000..7112f92bf2 --- /dev/null +++ b/storage_controller/migrations/2024-08-27-184400_pageserver_az/up.sql @@ -0,0 +1 @@ +ALTER TABLE nodes ADD availability_zone_id VARCHAR; diff --git a/storage_controller/src/node.rs b/storage_controller/src/node.rs index 61a44daca9..73cecc491d 100644 --- a/storage_controller/src/node.rs +++ b/storage_controller/src/node.rs @@ -36,6 +36,8 @@ pub(crate) struct Node { listen_pg_addr: String, listen_pg_port: u16, + availability_zone_id: Option, + // This cancellation token means "stop any RPCs in flight to this node, and don't start // any more". It is not related to process shutdown. #[serde(skip)] @@ -61,6 +63,10 @@ impl Node { self.id } + pub(crate) fn get_availability_zone_id(&self) -> Option<&str> { + self.availability_zone_id.as_deref() + } + pub(crate) fn get_scheduling(&self) -> NodeSchedulingPolicy { self.scheduling } @@ -72,7 +78,18 @@ impl Node { /// Does this registration request match `self`? This is used when deciding whether a registration /// request should be allowed to update an existing record with the same node ID. pub(crate) fn registration_match(&self, register_req: &NodeRegisterRequest) -> bool { - self.id == register_req.node_id + let az_ids_match = { + match ( + self.availability_zone_id.as_deref(), + register_req.availability_zone_id.as_deref(), + ) { + (Some(current_az), Some(register_req_az)) => current_az == register_req_az, + _ => true, + } + }; + + az_ids_match + && self.id == register_req.node_id && self.listen_http_addr == register_req.listen_http_addr && self.listen_http_port == register_req.listen_http_port && self.listen_pg_addr == register_req.listen_pg_addr @@ -173,6 +190,7 @@ impl Node { listen_http_port: u16, listen_pg_addr: String, listen_pg_port: u16, + availability_zone_id: Option, ) -> Self { Self { id, @@ -182,6 +200,7 @@ impl Node { listen_pg_port, scheduling: NodeSchedulingPolicy::Active, availability: NodeAvailability::Offline, + availability_zone_id, cancel: CancellationToken::new(), } } @@ -194,6 +213,7 @@ impl Node { listen_http_port: self.listen_http_port as i32, listen_pg_addr: self.listen_pg_addr.clone(), listen_pg_port: self.listen_pg_port as i32, + availability_zone_id: self.availability_zone_id.clone(), } } @@ -208,6 +228,7 @@ impl Node { listen_http_port: np.listen_http_port as u16, listen_pg_addr: np.listen_pg_addr, listen_pg_port: np.listen_pg_port as u16, + availability_zone_id: np.availability_zone_id, cancel: CancellationToken::new(), } } diff --git a/storage_controller/src/persistence.rs b/storage_controller/src/persistence.rs index 1a905753a1..a842079ce7 100644 --- a/storage_controller/src/persistence.rs +++ b/storage_controller/src/persistence.rs @@ -103,6 +103,7 @@ pub(crate) enum DatabaseOperation { ListMetadataHealthOutdated, GetLeader, UpdateLeader, + SetNodeAzId, } #[must_use] @@ -315,6 +316,31 @@ impl Persistence { } } + pub(crate) async fn set_node_availability_zone_id( + &self, + input_node_id: NodeId, + input_az_id: String, + ) -> DatabaseResult<()> { + use crate::schema::nodes::dsl::*; + let updated = self + .with_measured_conn(DatabaseOperation::SetNodeAzId, move |conn| { + let updated = diesel::update(nodes) + .filter(node_id.eq(input_node_id.0 as i64)) + .set((availability_zone_id.eq(input_az_id.clone()),)) + .execute(conn)?; + Ok(updated) + }) + .await?; + + if updated != 1 { + Err(DatabaseError::Logical(format!( + "Node {node_id:?} not found for setting az id", + ))) + } else { + Ok(()) + } + } + /// At startup, load the high level state for shards, such as their config + policy. This will /// be enriched at runtime with state discovered on pageservers. pub(crate) async fn list_tenant_shards(&self) -> DatabaseResult> { @@ -974,6 +1000,7 @@ pub(crate) struct NodePersistence { pub(crate) listen_http_port: i32, pub(crate) listen_pg_addr: String, pub(crate) listen_pg_port: i32, + pub(crate) availability_zone_id: Option, } /// Tenant metadata health status that are stored durably. diff --git a/storage_controller/src/reconciler.rs b/storage_controller/src/reconciler.rs index 94db879ade..102a3124d2 100644 --- a/storage_controller/src/reconciler.rs +++ b/storage_controller/src/reconciler.rs @@ -12,6 +12,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; +use utils::backoff::exponential_backoff; use utils::failpoint_support; use utils::generation::Generation; use utils::id::{NodeId, TimelineId}; @@ -568,6 +569,7 @@ impl Reconciler { // During a live migration it is unhelpful to proceed if we couldn't notify compute: if we detach // the origin without notifying compute, we will render the tenant unavailable. + let mut notify_attempts = 0; while let Err(e) = self.compute_notify().await { match e { NotifyError::Fatal(_) => return Err(ReconcileError::Notify(e)), @@ -578,6 +580,17 @@ impl Reconciler { ); } } + + exponential_backoff( + notify_attempts, + // Generous waits: control plane operations which might be blocking us usually complete on the order + // of hundreds to thousands of milliseconds, so no point busy polling. + 1.0, + 10.0, + &self.cancel, + ) + .await; + notify_attempts += 1; } // Downgrade the origin to secondary. If the tenant's policy is PlacementPolicy::Attached(0), then diff --git a/storage_controller/src/scheduler.rs b/storage_controller/src/scheduler.rs index 060e3cc6ca..ef4da6861c 100644 --- a/storage_controller/src/scheduler.rs +++ b/storage_controller/src/scheduler.rs @@ -528,6 +528,7 @@ pub(crate) mod test_utils { 80 + i as u16, format!("pghost-{i}"), 5432 + i as u16, + None, ); node.set_availability(NodeAvailability::Active(test_utilization::simple(0, 0))); assert!(node.is_available()); diff --git a/storage_controller/src/schema.rs b/storage_controller/src/schema.rs index 77ba47e114..1e8379500c 100644 --- a/storage_controller/src/schema.rs +++ b/storage_controller/src/schema.rs @@ -25,6 +25,7 @@ diesel::table! { listen_http_port -> Int4, listen_pg_addr -> Varchar, listen_pg_port -> Int4, + availability_zone_id -> Nullable, } } diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 7daa1e4f5f..1f221a9b45 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -1257,6 +1257,7 @@ impl Service { 123, "".to_string(), 123, + None, ); scheduler.node_upsert(&node); @@ -4683,29 +4684,84 @@ impl Service { ) .await; - { + if register_req.availability_zone_id.is_none() { + tracing::warn!( + "Node {} registering without specific availability zone id", + register_req.node_id + ); + } + + enum RegistrationStatus { + Matched(Node), + Mismatched, + New, + } + + let registration_status = { let locked = self.inner.read().unwrap(); if let Some(node) = locked.nodes.get(®ister_req.node_id) { - // Note that we do not do a total equality of the struct, because we don't require - // the availability/scheduling states to agree for a POST to be idempotent. if node.registration_match(®ister_req) { - tracing::info!( - "Node {} re-registered with matching address", - register_req.node_id - ); - return Ok(()); + RegistrationStatus::Matched(node.clone()) } else { - // TODO: decide if we want to allow modifying node addresses without removing and re-adding - // the node. Safest/simplest thing is to refuse it, and usually we deploy with - // a fixed address through the lifetime of a node. - tracing::warn!( - "Node {} tried to register with different address", - register_req.node_id - ); - return Err(ApiError::Conflict( - "Node is already registered with different address".to_string(), - )); + RegistrationStatus::Mismatched } + } else { + RegistrationStatus::New + } + }; + + match registration_status { + RegistrationStatus::Matched(node) => { + tracing::info!( + "Node {} re-registered with matching address", + register_req.node_id + ); + + if node.get_availability_zone_id().is_none() { + if let Some(az_id) = register_req.availability_zone_id.clone() { + tracing::info!("Extracting availability zone id from registration request for node {}: {}", + register_req.node_id, az_id); + + // Persist to the database and update in memory state. See comment below + // on ordering. + self.persistence + .set_node_availability_zone_id(register_req.node_id, az_id) + .await?; + let node_with_az = Node::new( + register_req.node_id, + register_req.listen_http_addr, + register_req.listen_http_port, + register_req.listen_pg_addr, + register_req.listen_pg_port, + register_req.availability_zone_id, + ); + + let mut locked = self.inner.write().unwrap(); + let mut new_nodes = (*locked.nodes).clone(); + + locked.scheduler.node_upsert(&node_with_az); + new_nodes.insert(register_req.node_id, node_with_az); + + locked.nodes = Arc::new(new_nodes); + } + } + + return Ok(()); + } + RegistrationStatus::Mismatched => { + // TODO: decide if we want to allow modifying node addresses without removing and re-adding + // the node. Safest/simplest thing is to refuse it, and usually we deploy with + // a fixed address through the lifetime of a node. + tracing::warn!( + "Node {} tried to register with different address", + register_req.node_id + ); + return Err(ApiError::Conflict( + "Node is already registered with different address".to_string(), + )); + } + RegistrationStatus::New => { + // fallthrough } } @@ -4742,6 +4798,7 @@ impl Service { register_req.listen_http_port, register_req.listen_pg_addr, register_req.listen_pg_port, + register_req.availability_zone_id, ); // TODO: idempotency if the node already exists in the database diff --git a/storage_scrubber/src/checks.rs b/storage_scrubber/src/checks.rs index 08b0f06ebf..15dfb101b5 100644 --- a/storage_scrubber/src/checks.rs +++ b/storage_scrubber/src/checks.rs @@ -1,6 +1,7 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use anyhow::Context; +use itertools::Itertools; use pageserver::tenant::layer_map::LayerMap; use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata; use pageserver_api::shard::ShardIndex; @@ -47,6 +48,56 @@ impl TimelineAnalysis { } } +/// Checks whether a layer map is valid (i.e., is a valid result of the current compaction algorithm if nothing goes wrong). +/// The function checks if we can split the LSN range of a delta layer only at the LSNs of the delta layers. For example, +/// +/// ```plain +/// | | | | +/// | 1 | | 2 | | 3 | +/// | | | | | | +/// ``` +/// +/// This is not a valid layer map because the LSN range of layer 1 intersects with the LSN range of layer 2. 1 and 2 should have +/// the same LSN range. +/// +/// The exception is that when layer 2 only contains a single key, it could be split over the LSN range. For example, +/// +/// ```plain +/// | | | 2 | | | +/// | 1 | |-------| | 3 | +/// | | | 4 | | | +/// +/// If layer 2 and 4 contain the same single key, this is also a valid layer map. +fn check_valid_layermap(metadata: &HashMap) -> Option { + let mut lsn_split_point = BTreeSet::new(); // TODO: use a better data structure (range tree / range set?) + let mut all_delta_layers = Vec::new(); + for (name, _) in metadata.iter() { + if let LayerName::Delta(layer) = name { + if layer.key_range.start.next() != layer.key_range.end { + all_delta_layers.push(layer.clone()); + } + } + } + for layer in &all_delta_layers { + let lsn_range = &layer.lsn_range; + lsn_split_point.insert(lsn_range.start); + lsn_split_point.insert(lsn_range.end); + } + for layer in &all_delta_layers { + let lsn_range = layer.lsn_range.clone(); + let intersects = lsn_split_point.range(lsn_range).collect_vec(); + if intersects.len() > 1 { + let err = format!( + "layer violates the layer map LSN split assumption: layer {} intersects with LSN [{}]", + layer, + intersects.into_iter().map(|lsn| lsn.to_string()).join(", ") + ); + return Some(err); + } + } + None +} + pub(crate) async fn branch_cleanup_and_check_errors( remote_client: &GenericRemoteStorage, id: &TenantShardTimelineId, @@ -126,6 +177,12 @@ pub(crate) async fn branch_cleanup_and_check_errors( } } + if let Some(err) = check_valid_layermap(&index_part.layer_metadata) { + result.errors.push(format!( + "index_part.json contains invalid layer map structure: {err}" + )); + } + for (layer, metadata) in index_part.layer_metadata { if metadata.file_size == 0 { result.errors.push(format!( diff --git a/storage_scrubber/src/lib.rs b/storage_scrubber/src/lib.rs index 112f052e07..3c21d2f8cf 100644 --- a/storage_scrubber/src/lib.rs +++ b/storage_scrubber/src/lib.rs @@ -36,7 +36,7 @@ use serde::{Deserialize, Serialize}; use storage_controller_client::control_api; use tokio::io::AsyncReadExt; use tokio_util::sync::CancellationToken; -use tracing::error; +use tracing::{error, warn}; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use utils::fs_ext; @@ -466,7 +466,7 @@ async fn list_objects_with_retries( return Err(e) .with_context(|| format!("Failed to list objects {MAX_RETRIES} times")); } - error!( + warn!( "list_objects_v2 query failed: bucket_name={}, prefix={}, delimiter={}, error={}", s3_target.bucket_name, s3_target.prefix_in_bucket, diff --git a/storage_scrubber/src/main.rs b/storage_scrubber/src/main.rs index 3935e513e3..c5961753c5 100644 --- a/storage_scrubber/src/main.rs +++ b/storage_scrubber/src/main.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, bail}; +use anyhow::{anyhow, bail, Context}; use camino::Utf8PathBuf; use pageserver_api::controller_api::{MetadataHealthUpdateRequest, MetadataHealthUpdateResponse}; use pageserver_api::shard::TenantShardId; @@ -7,6 +7,7 @@ use storage_controller_client::control_api; use storage_scrubber::garbage::{find_garbage, purge_garbage, PurgeMode}; use storage_scrubber::pageserver_physical_gc::GcMode; use storage_scrubber::scan_pageserver_metadata::scan_pageserver_metadata; +use storage_scrubber::scan_safekeeper_metadata::DatabaseOrList; use storage_scrubber::tenant_snapshot::SnapshotDownloader; use storage_scrubber::{find_large_objects, ControllerClientConfig}; use storage_scrubber::{ @@ -76,6 +77,9 @@ enum Command { /// For safekeeper node_kind only, table in the db with debug dump #[arg(long, default_value = None)] dump_db_table: Option, + /// For safekeeper node_kind only, json list of timelines and their lsn info + #[arg(long, default_value = None)] + timeline_lsns: Option, }, TenantSnapshot { #[arg(long = "tenant-id")] @@ -155,20 +159,22 @@ async fn main() -> anyhow::Result<()> { post_to_storcon, dump_db_connstr, dump_db_table, + timeline_lsns, } => { if let NodeKind::Safekeeper = node_kind { - let dump_db_connstr = - dump_db_connstr.ok_or(anyhow::anyhow!("dump_db_connstr not specified"))?; - let dump_db_table = - dump_db_table.ok_or(anyhow::anyhow!("dump_db_table not specified"))?; - - let summary = scan_safekeeper_metadata( - bucket_config.clone(), - tenant_ids.iter().map(|tshid| tshid.tenant_id).collect(), - dump_db_connstr, - dump_db_table, - ) - .await?; + let db_or_list = match (timeline_lsns, dump_db_connstr) { + (Some(timeline_lsns), _) => { + let timeline_lsns = serde_json::from_str(&timeline_lsns).context("parsing timeline_lsns")?; + DatabaseOrList::List(timeline_lsns) + } + (None, Some(dump_db_connstr)) => { + let dump_db_table = dump_db_table.ok_or_else(|| anyhow::anyhow!("dump_db_table not specified"))?; + let tenant_ids = tenant_ids.iter().map(|tshid| tshid.tenant_id).collect(); + DatabaseOrList::Database { tenant_ids, connstr: dump_db_connstr, table: dump_db_table } + } + (None, None) => anyhow::bail!("neither `timeline_lsns` specified, nor `dump_db_connstr` and `dump_db_table`"), + }; + let summary = scan_safekeeper_metadata(bucket_config.clone(), db_or_list).await?; if json { println!("{}", serde_json::to_string(&summary).unwrap()) } else { diff --git a/storage_scrubber/src/scan_safekeeper_metadata.rs b/storage_scrubber/src/scan_safekeeper_metadata.rs index 1a9f3d0ef5..15f3665fac 100644 --- a/storage_scrubber/src/scan_safekeeper_metadata.rs +++ b/storage_scrubber/src/scan_safekeeper_metadata.rs @@ -7,7 +7,7 @@ use postgres_ffi::{XLogFileName, PG_TLI}; use remote_storage::GenericRemoteStorage; use serde::Serialize; use tokio_postgres::types::PgLsn; -use tracing::{error, info, trace}; +use tracing::{debug, error, info}; use utils::{ id::{TenantId, TenantTimelineId, TimelineId}, lsn::Lsn, @@ -54,6 +54,23 @@ impl MetadataSummary { } } +#[derive(serde::Deserialize)] +pub struct TimelineLsnData { + tenant_id: String, + timeline_id: String, + timeline_start_lsn: Lsn, + backup_lsn: Lsn, +} + +pub enum DatabaseOrList { + Database { + tenant_ids: Vec, + connstr: String, + table: String, + }, + List(Vec), +} + /// Scan the safekeeper metadata in an S3 bucket, reporting errors and /// statistics. /// @@ -63,68 +80,39 @@ impl MetadataSummary { /// the project wasn't deleted in the meanwhile. pub async fn scan_safekeeper_metadata( bucket_config: BucketConfig, - tenant_ids: Vec, - dump_db_connstr: String, - dump_db_table: String, + db_or_list: DatabaseOrList, ) -> anyhow::Result { info!( - "checking bucket {}, region {}, dump_db_table {}", - bucket_config.bucket, bucket_config.region, dump_db_table + "checking bucket {}, region {}", + bucket_config.bucket, bucket_config.region ); - // Use rustls (Neon requires TLS) - let root_store = TLS_ROOTS.get_or_try_init(load_certs)?.clone(); - let client_config = rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - let tls_connector = tokio_postgres_rustls::MakeRustlsConnect::new(client_config); - let (client, connection) = tokio_postgres::connect(&dump_db_connstr, tls_connector).await?; - // The connection object performs the actual communication with the database, - // so spawn it off to run on its own. - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("connection error: {}", e); - } - }); - - let tenant_filter_clause = if !tenant_ids.is_empty() { - format!( - "and tenant_id in ({})", - tenant_ids - .iter() - .map(|t| format!("'{}'", t)) - .collect::>() - .join(", ") - ) - } else { - "".to_owned() - }; - let query = format!( - "select tenant_id, timeline_id, min(timeline_start_lsn), max(backup_lsn) from \"{}\" where not is_cancelled {} group by tenant_id, timeline_id;", - dump_db_table, tenant_filter_clause, - ); - info!("query is {}", query); - let timelines = client.query(&query, &[]).await?; - info!("loaded {} timelines", timelines.len()); let (remote_client, target) = init_remote(bucket_config, NodeKind::Safekeeper).await?; let console_config = ConsoleConfig::from_env()?; let cloud_admin_api_client = CloudAdminApiClient::new(console_config); - let checks = futures::stream::iter(timelines.iter().map(Ok)).map_ok(|row| { - let tenant_id = TenantId::from_str(row.get(0)).expect("failed to parse tenant_id"); - let timeline_id = TimelineId::from_str(row.get(1)).expect("failed to parse tenant_id"); - let timeline_start_lsn_pg: PgLsn = row.get(2); - let timeline_start_lsn: Lsn = Lsn(u64::from(timeline_start_lsn_pg)); - let backup_lsn_pg: PgLsn = row.get(3); - let backup_lsn: Lsn = Lsn(u64::from(backup_lsn_pg)); + let timelines = match db_or_list { + DatabaseOrList::Database { + tenant_ids, + connstr, + table, + } => load_timelines_from_db(tenant_ids, connstr, table).await?, + DatabaseOrList::List(list) => list, + }; + info!("loaded {} timelines", timelines.len()); + + let checks = futures::stream::iter(timelines.into_iter().map(Ok)).map_ok(|timeline| { + let tenant_id = TenantId::from_str(&timeline.tenant_id).expect("failed to parse tenant_id"); + let timeline_id = + TimelineId::from_str(&timeline.timeline_id).expect("failed to parse tenant_id"); let ttid = TenantTimelineId::new(tenant_id, timeline_id); check_timeline( &remote_client, &target, &cloud_admin_api_client, ttid, - timeline_start_lsn, - backup_lsn, + timeline.timeline_start_lsn, + timeline.backup_lsn, ) }); // Run multiple check_timeline's concurrently. @@ -163,11 +151,9 @@ async fn check_timeline( timeline_start_lsn: Lsn, backup_lsn: Lsn, ) -> anyhow::Result { - trace!( + debug!( "checking ttid {}, should contain WAL [{}-{}]", - ttid, - timeline_start_lsn, - backup_lsn + ttid, timeline_start_lsn, backup_lsn ); // calculate expected segfiles let expected_first_segno = timeline_start_lsn.segment_number(WAL_SEGSIZE); @@ -177,7 +163,7 @@ async fn check_timeline( .map(|segno| XLogFileName(PG_TLI, segno, WAL_SEGSIZE)), ); let expected_files_num = expected_segfiles.len(); - trace!("expecting {} files", expected_segfiles.len(),); + debug!("expecting {} files", expected_segfiles.len(),); // now list s3 and check if it misses something let ttshid = @@ -252,3 +238,65 @@ fn load_certs() -> Result, std::io::Error> { Ok(Arc::new(store)) } static TLS_ROOTS: OnceCell> = OnceCell::new(); + +async fn load_timelines_from_db( + tenant_ids: Vec, + dump_db_connstr: String, + dump_db_table: String, +) -> anyhow::Result> { + info!("loading from table {dump_db_table}"); + + // Use rustls (Neon requires TLS) + let root_store = TLS_ROOTS.get_or_try_init(load_certs)?.clone(); + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + let tls_connector = tokio_postgres_rustls::MakeRustlsConnect::new(client_config); + let (client, connection) = tokio_postgres::connect(&dump_db_connstr, tls_connector).await?; + // The connection object performs the actual communication with the database, + // so spawn it off to run on its own. + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + let tenant_filter_clause = if !tenant_ids.is_empty() { + format!( + "and tenant_id in ({})", + tenant_ids + .iter() + .map(|t| format!("'{}'", t)) + .collect::>() + .join(", ") + ) + } else { + "".to_owned() + }; + let query = format!( + "select tenant_id, timeline_id, min(timeline_start_lsn), max(backup_lsn) \ + from \"{dump_db_table}\" \ + where not is_cancelled {tenant_filter_clause} \ + group by tenant_id, timeline_id;" + ); + info!("query is {}", query); + let timelines = client.query(&query, &[]).await?; + + let timelines = timelines + .into_iter() + .map(|row| { + let tenant_id = row.get(0); + let timeline_id = row.get(1); + let timeline_start_lsn_pg: PgLsn = row.get(2); + let backup_lsn_pg: PgLsn = row.get(3); + + TimelineLsnData { + tenant_id, + timeline_id, + timeline_start_lsn: Lsn(u64::from(timeline_start_lsn_pg)), + backup_lsn: Lsn(u64::from(backup_lsn_pg)), + } + }) + .collect::>(); + Ok(timelines) +} diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 5fe544b3bd..98a9dd7184 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -102,7 +102,6 @@ class NeonCompare(PgCompare): zenbenchmark: NeonBenchmarker, neon_simple_env: NeonEnv, pg_bin: PgBin, - branch_name: str, ): self.env = neon_simple_env self._zenbenchmark = zenbenchmark @@ -110,16 +109,11 @@ class NeonCompare(PgCompare): self.pageserver_http_client = self.env.pageserver.http_client() # note that neon_simple_env now uses LOCAL_FS remote storage - - # Create tenant - tenant_conf: Dict[str, str] = {} - self.tenant, _ = self.env.neon_cli.create_tenant(conf=tenant_conf) - - # Create timeline - self.timeline = self.env.neon_cli.create_timeline(branch_name, tenant_id=self.tenant) + self.tenant = self.env.initial_tenant + self.timeline = self.env.initial_timeline # Start pg - self._pg = self.env.endpoints.create_start(branch_name, "main", self.tenant) + self._pg = self.env.endpoints.create_start("main", "main", self.tenant) @property def pg(self) -> PgProtocol: @@ -297,13 +291,11 @@ class RemoteCompare(PgCompare): @pytest.fixture(scope="function") def neon_compare( - request: FixtureRequest, zenbenchmark: NeonBenchmarker, pg_bin: PgBin, neon_simple_env: NeonEnv, ) -> NeonCompare: - branch_name = request.node.name - return NeonCompare(zenbenchmark, neon_simple_env, pg_bin, branch_name) + return NeonCompare(zenbenchmark, neon_simple_env, pg_bin) @pytest.fixture(scope="function") diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 92febfec9b..800ae03d13 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -496,6 +496,7 @@ class NeonEnvBuilder: pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]] = None, safekeeper_extra_opts: Optional[list[str]] = None, storage_controller_port_override: Optional[int] = None, + pageserver_io_buffer_alignment: Optional[int] = None, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -550,6 +551,8 @@ class NeonEnvBuilder: self.storage_controller_port_override = storage_controller_port_override + self.pageserver_io_buffer_alignment = pageserver_io_buffer_alignment + assert test_name.startswith( "test_" ), "Unexpectedly instantiated from outside a test function" @@ -1123,6 +1126,7 @@ class NeonEnv: self.pageserver_virtual_file_io_engine = config.pageserver_virtual_file_io_engine self.pageserver_aux_file_policy = config.pageserver_aux_file_policy + self.pageserver_io_buffer_alignment = config.pageserver_io_buffer_alignment # Create the neon_local's `NeonLocalInitConf` cfg: Dict[str, Any] = { @@ -1184,6 +1188,8 @@ class NeonEnv: for key, value in override.items(): ps_cfg[key] = value + ps_cfg["io_buffer_alignment"] = self.pageserver_io_buffer_alignment + # Create a corresponding NeonPageserver object self.pageservers.append( NeonPageserver( @@ -1425,6 +1431,7 @@ def _shared_simple_env( pageserver_virtual_file_io_engine: str, pageserver_aux_file_policy: Optional[AuxFileStore], pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]], + pageserver_io_buffer_alignment: Optional[int], ) -> Iterator[NeonEnv]: """ # Internal fixture backing the `neon_simple_env` fixture. If TEST_SHARED_FIXTURES @@ -1457,6 +1464,7 @@ def _shared_simple_env( pageserver_virtual_file_io_engine=pageserver_virtual_file_io_engine, pageserver_aux_file_policy=pageserver_aux_file_policy, pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm, + pageserver_io_buffer_alignment=pageserver_io_buffer_alignment, ) as builder: env = builder.init_start() @@ -1499,6 +1507,7 @@ def neon_env_builder( pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]], pageserver_aux_file_policy: Optional[AuxFileStore], record_property: Callable[[str, object], None], + pageserver_io_buffer_alignment: Optional[int], ) -> Iterator[NeonEnvBuilder]: """ Fixture to create a Neon environment for test. @@ -1534,6 +1543,7 @@ def neon_env_builder( test_overlay_dir=test_overlay_dir, pageserver_aux_file_policy=pageserver_aux_file_policy, pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm, + pageserver_io_buffer_alignment=pageserver_io_buffer_alignment, ) as builder: yield builder # Propogate `preserve_database_files` to make it possible to use in other fixtures, @@ -4615,12 +4625,20 @@ class Safekeeper(LogUtils): wait_until(20, 0.5, paused) +# TODO: Replace with `StrEnum` when we upgrade to python 3.11 +class NodeKind(str, Enum): + PAGESERVER = "pageserver" + SAFEKEEPER = "safekeeper" + + class StorageScrubber: def __init__(self, env: NeonEnv, log_dir: Path): self.env = env self.log_dir = log_dir - def scrubber_cli(self, args: list[str], timeout) -> str: + def scrubber_cli( + self, args: list[str], timeout, extra_env: Optional[Dict[str, str]] = None + ) -> str: assert isinstance(self.env.pageserver_remote_storage, S3Storage) s3_storage = self.env.pageserver_remote_storage @@ -4635,6 +4653,9 @@ class StorageScrubber: if s3_storage.endpoint is not None: env.update({"AWS_ENDPOINT_URL": s3_storage.endpoint}) + if extra_env is not None: + env.update(extra_env) + base_args = [ str(self.env.neon_binpath / "storage_scrubber"), f"--controller-api={self.env.storage_controller.api_root()}", @@ -4662,18 +4683,43 @@ class StorageScrubber: assert stdout is not None return stdout - def scan_metadata(self, post_to_storage_controller: bool = False) -> Tuple[bool, Any]: + def scan_metadata_safekeeper( + self, + timeline_lsns: List[Dict[str, Any]], + cloud_admin_api_url: str, + cloud_admin_api_token: str, + ) -> Tuple[bool, Any]: + extra_env = { + "CLOUD_ADMIN_API_URL": cloud_admin_api_url, + "CLOUD_ADMIN_API_TOKEN": cloud_admin_api_token, + } + return self.scan_metadata( + node_kind=NodeKind.SAFEKEEPER, timeline_lsns=timeline_lsns, extra_env=extra_env + ) + + def scan_metadata( + self, + post_to_storage_controller: bool = False, + node_kind: NodeKind = NodeKind.PAGESERVER, + timeline_lsns: Optional[List[Dict[str, Any]]] = None, + extra_env: Optional[Dict[str, str]] = None, + ) -> Tuple[bool, Any]: """ Returns the health status and the metadata summary. """ - args = ["scan-metadata", "--node-kind", "pageserver", "--json"] + args = ["scan-metadata", "--node-kind", node_kind.value, "--json"] if post_to_storage_controller: args.append("--post") - stdout = self.scrubber_cli(args, timeout=30) + if timeline_lsns is not None: + args.append("--timeline-lsns") + args.append(json.dumps(timeline_lsns)) + stdout = self.scrubber_cli(args, timeout=30, extra_env=extra_env) try: summary = json.loads(stdout) - healthy = not summary["with_errors"] and not summary["with_warnings"] + # summary does not contain "with_warnings" if node_kind is the safekeeper + no_warnings = "with_warnings" not in summary or not summary["with_warnings"] + healthy = not summary["with_errors"] and no_warnings return healthy, summary except: log.error("Failed to decode JSON output from `scan-metadata`. Dumping stdout:") diff --git a/test_runner/fixtures/pageserver/allowed_errors.py b/test_runner/fixtures/pageserver/allowed_errors.py index dff002bd4b..70f2676245 100755 --- a/test_runner/fixtures/pageserver/allowed_errors.py +++ b/test_runner/fixtures/pageserver/allowed_errors.py @@ -52,9 +52,6 @@ DEFAULT_PAGESERVER_ALLOWED_ERRORS = ( ".*Error processing HTTP request: Forbidden", # intentional failpoints ".*failpoint ", - # FIXME: These need investigation - ".*manual_gc.*is_shutdown_requested\\(\\) called in an unexpected task or thread.*", - ".*tenant_list: timeline is not found in remote index while it is present in the tenants registry.*", # Tenant::delete_timeline() can cause any of the four following errors. # FIXME: we shouldn't be considering it an error: https://github.com/neondatabase/neon/issues/2946 ".*could not flush frozen layer.*queue is in state Stopped", # when schedule layer upload fails because queued got closed before compaction got killed @@ -112,6 +109,9 @@ DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS = [ # controller's attempts to notify the endpoint). ".*reconciler.*neon_local notification hook failed.*", ".*reconciler.*neon_local error.*", + # Neon local does not provide pageserver with an AZ + # TODO: remove this once neon local does so + ".*registering without specific availability zone id.*", ] diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index 92c98763e3..e2dd51802c 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -34,6 +34,11 @@ def pageserver_virtual_file_io_engine() -> Optional[str]: return os.getenv("PAGESERVER_VIRTUAL_FILE_IO_ENGINE") +@pytest.fixture(scope="function", autouse=True) +def pageserver_io_buffer_alignment() -> Optional[int]: + return None + + @pytest.fixture(scope="function", autouse=True) def pageserver_aux_file_policy() -> Optional[AuxFileStore]: return None diff --git a/test_runner/fixtures/safekeeper/http.py b/test_runner/fixtures/safekeeper/http.py index dd3a0a3d54..05b43cfb72 100644 --- a/test_runner/fixtures/safekeeper/http.py +++ b/test_runner/fixtures/safekeeper/http.py @@ -65,6 +65,16 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): def check_status(self): self.get(f"http://localhost:{self.port}/v1/status").raise_for_status() + def get_metrics_str(self) -> str: + """You probably want to use get_metrics() instead.""" + request_result = self.get(f"http://localhost:{self.port}/metrics") + request_result.raise_for_status() + return request_result.text + + def get_metrics(self) -> SafekeeperMetrics: + res = self.get_metrics_str() + return SafekeeperMetrics(parse_metrics(res)) + def is_testing_enabled_or_skip(self): if not self.is_testing_enabled: pytest.skip("safekeeper was built without 'testing' feature") @@ -89,56 +99,8 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): assert res_json is None return res_json - def debug_dump(self, params: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - params = params or {} - res = self.get(f"http://localhost:{self.port}/v1/debug_dump", params=params) - res.raise_for_status() - res_json = json.loads(res.text) - assert isinstance(res_json, dict) - return res_json - - def patch_control_file( - self, - tenant_id: TenantId, - timeline_id: TimelineId, - patch: Dict[str, Any], - ) -> Dict[str, Any]: - res = self.patch( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/control_file", - json={ - "updates": patch, - "apply_fields": list(patch.keys()), - }, - ) - res.raise_for_status() - res_json = res.json() - assert isinstance(res_json, dict) - return res_json - - def pull_timeline(self, body: Dict[str, Any]) -> Dict[str, Any]: - res = self.post(f"http://localhost:{self.port}/v1/pull_timeline", json=body) - res.raise_for_status() - res_json = res.json() - assert isinstance(res_json, dict) - return res_json - - def copy_timeline(self, tenant_id: TenantId, timeline_id: TimelineId, body: Dict[str, Any]): - res = self.post( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/copy", - json=body, - ) - res.raise_for_status() - - def timeline_digest( - self, tenant_id: TenantId, timeline_id: TimelineId, from_lsn: Lsn, until_lsn: Lsn - ) -> Dict[str, Any]: - res = self.get( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/digest", - params={ - "from_lsn": str(from_lsn), - "until_lsn": str(until_lsn), - }, - ) + def tenant_delete_force(self, tenant_id: TenantId) -> Dict[Any, Any]: + res = self.delete(f"http://localhost:{self.port}/v1/tenant/{tenant_id}") res.raise_for_status() res_json = res.json() assert isinstance(res_json, dict) @@ -189,20 +151,6 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): def get_commit_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn: return self.timeline_status(tenant_id, timeline_id).commit_lsn - def record_safekeeper_info(self, tenant_id: TenantId, timeline_id: TimelineId, body): - res = self.post( - f"http://localhost:{self.port}/v1/record_safekeeper_info/{tenant_id}/{timeline_id}", - json=body, - ) - res.raise_for_status() - - def checkpoint(self, tenant_id: TenantId, timeline_id: TimelineId): - res = self.post( - f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/checkpoint", - json={}, - ) - res.raise_for_status() - # only_local doesn't remove segments in the remote storage. def timeline_delete( self, tenant_id: TenantId, timeline_id: TimelineId, only_local: bool = False @@ -218,19 +166,71 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): assert isinstance(res_json, dict) return res_json - def tenant_delete_force(self, tenant_id: TenantId) -> Dict[Any, Any]: - res = self.delete(f"http://localhost:{self.port}/v1/tenant/{tenant_id}") + def debug_dump(self, params: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + params = params or {} + res = self.get(f"http://localhost:{self.port}/v1/debug_dump", params=params) + res.raise_for_status() + res_json = json.loads(res.text) + assert isinstance(res_json, dict) + return res_json + + def pull_timeline(self, body: Dict[str, Any]) -> Dict[str, Any]: + res = self.post(f"http://localhost:{self.port}/v1/pull_timeline", json=body) res.raise_for_status() res_json = res.json() assert isinstance(res_json, dict) return res_json - def get_metrics_str(self) -> str: - """You probably want to use get_metrics() instead.""" - request_result = self.get(f"http://localhost:{self.port}/metrics") - request_result.raise_for_status() - return request_result.text + def copy_timeline(self, tenant_id: TenantId, timeline_id: TimelineId, body: Dict[str, Any]): + res = self.post( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/copy", + json=body, + ) + res.raise_for_status() - def get_metrics(self) -> SafekeeperMetrics: - res = self.get_metrics_str() - return SafekeeperMetrics(parse_metrics(res)) + def patch_control_file( + self, + tenant_id: TenantId, + timeline_id: TimelineId, + patch: Dict[str, Any], + ) -> Dict[str, Any]: + res = self.patch( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/control_file", + json={ + "updates": patch, + "apply_fields": list(patch.keys()), + }, + ) + res.raise_for_status() + res_json = res.json() + assert isinstance(res_json, dict) + return res_json + + def checkpoint(self, tenant_id: TenantId, timeline_id: TimelineId): + res = self.post( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/checkpoint", + json={}, + ) + res.raise_for_status() + + def timeline_digest( + self, tenant_id: TenantId, timeline_id: TimelineId, from_lsn: Lsn, until_lsn: Lsn + ) -> Dict[str, Any]: + res = self.get( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/digest", + params={ + "from_lsn": str(from_lsn), + "until_lsn": str(until_lsn), + }, + ) + res.raise_for_status() + res_json = res.json() + assert isinstance(res_json, dict) + return res_json + + def record_safekeeper_info(self, tenant_id: TenantId, timeline_id: TimelineId, body): + res = self.post( + f"http://localhost:{self.port}/v1/record_safekeeper_info/{tenant_id}/{timeline_id}", + json=body, + ) + res.raise_for_status() diff --git a/test_runner/performance/test_wal_backpressure.py b/test_runner/performance/test_wal_backpressure.py index 513ebc74c3..c824e60c29 100644 --- a/test_runner/performance/test_wal_backpressure.py +++ b/test_runner/performance/test_wal_backpressure.py @@ -2,14 +2,14 @@ import statistics import threading import time import timeit -from typing import Any, Callable, List +from typing import Any, Callable, Generator, List import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.common_types import Lsn from fixtures.compare_fixtures import NeonCompare, PgCompare, VanillaCompare from fixtures.log_helper import log -from fixtures.neon_fixtures import DEFAULT_BRANCH_NAME, NeonEnvBuilder, PgBin +from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, flush_ep_to_pageserver from performance.test_perf_pgbench import get_durations_matrix, get_scales_matrix @@ -20,7 +20,7 @@ from performance.test_perf_pgbench import get_durations_matrix, get_scales_matri # For example, to build a `NeonCompare` interface, the corresponding fixture's param should have # a format of `neon_{safekeepers_enable_fsync}`. # Note that, here "_" is used to separate builder parameters. -def pg_compare(request) -> PgCompare: +def pg_compare(request) -> Generator[PgCompare, None, None]: x = request.param.split("_") if x[0] == "vanilla": @@ -28,7 +28,7 @@ def pg_compare(request) -> PgCompare: fixture = request.getfixturevalue("vanilla_compare") assert isinstance(fixture, VanillaCompare) - return fixture + yield fixture else: assert ( len(x) == 2 @@ -47,10 +47,15 @@ def pg_compare(request) -> PgCompare: neon_env_builder.safekeepers_enable_fsync = x[1] == "on" env = neon_env_builder.init_start() - env.neon_cli.create_branch("empty", ancestor_branch_name=DEFAULT_BRANCH_NAME) - branch_name = request.node.name - return NeonCompare(zenbenchmark, env, pg_bin, branch_name) + cmp = NeonCompare(zenbenchmark, env, pg_bin) + + yield cmp + + flush_ep_to_pageserver(env, cmp._pg, cmp.tenant, cmp.timeline) + env.pageserver.http_client().timeline_checkpoint( + cmp.tenant, cmp.timeline, compact=False, wait_until_uploaded=True + ) def start_heavy_write_workload(env: PgCompare, n_tables: int, scale: int, num_iters: int): diff --git a/test_runner/regress/test_attach_tenant_config.py b/test_runner/regress/test_attach_tenant_config.py index a7eda73d4c..bb337d9cc1 100644 --- a/test_runner/regress/test_attach_tenant_config.py +++ b/test_runner/regress/test_attach_tenant_config.py @@ -162,7 +162,6 @@ def test_fully_custom_config(positive_env: NeonEnv): "min_resident_size_override": 23, "timeline_get_throttle": { "task_kinds": ["PageRequestHandler"], - "fair": True, "initial": 0, "refill_interval": "1s", "refill_amount": 1000, diff --git a/test_runner/regress/test_auth.py b/test_runner/regress/test_auth.py index 7cb85e3dd1..780c0e1602 100644 --- a/test_runner/regress/test_auth.py +++ b/test_runner/regress/test_auth.py @@ -211,7 +211,7 @@ def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): def check_pageserver(expect_success: bool, **conn_kwargs): check_connection( env.pageserver, - f"pagestream {env.initial_tenant} {env.initial_timeline}", + f"pagestream_v2 {env.initial_tenant} {env.initial_timeline}", expect_success, **conn_kwargs, ) diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index c361efe90a..cd3f405a86 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -173,6 +173,11 @@ def test_backward_compatibility( try: neon_env_builder.num_safekeepers = 3 env = neon_env_builder.from_repo_dir(compatibility_snapshot_dir / "repo") + # check_neon_works does recovery from WAL => the compatibility snapshot's WAL is old => will log this warning + ingest_lag_log_line = ( + ".*ingesting record with timestamp lagging more than wait_lsn_timeout.*" + ) + env.pageserver.allowed_errors.append(ingest_lag_log_line) neon_env_builder.start() check_neon_works( @@ -181,6 +186,9 @@ def test_backward_compatibility( sql_dump_path=compatibility_snapshot_dir / "dump.sql", repo_dir=env.repo_dir, ) + + env.pageserver.assert_log_contains(ingest_lag_log_line) + except Exception: if breaking_changes_allowed: pytest.xfail( diff --git a/test_runner/regress/test_config.py b/test_runner/regress/test_config.py index 4bb7df1e6a..2ef28eb94b 100644 --- a/test_runner/regress/test_config.py +++ b/test_runner/regress/test_config.py @@ -1,6 +1,7 @@ +import os from contextlib import closing -from fixtures.neon_fixtures import NeonEnv +from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder # @@ -28,3 +29,45 @@ def test_config(neon_simple_env: NeonEnv): # check that config change was applied assert cur.fetchone() == ("debug1",) + + +# +# Test that reordering of safekeepers does not restart walproposer +# +def test_safekeepers_reconfigure_reorder( + neon_env_builder: NeonEnvBuilder, +): + neon_env_builder.num_safekeepers = 3 + env = neon_env_builder.init_start() + env.neon_cli.create_branch("test_safekeepers_reconfigure_reorder") + + endpoint = env.endpoints.create_start("test_safekeepers_reconfigure_reorder") + + old_sks = "" + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("SHOW neon.safekeepers") + res = cur.fetchone() + assert res is not None, "neon.safekeepers GUC is set" + old_sks = res[0] + + # Reorder safekeepers + safekeepers = endpoint.active_safekeepers + safekeepers = safekeepers[1:] + safekeepers[:1] + + endpoint.reconfigure(safekeepers=safekeepers) + + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("SHOW neon.safekeepers") + res = cur.fetchone() + assert res is not None, "neon.safekeepers GUC is set" + new_sks = res[0] + + assert new_sks != old_sks, "GUC changes were applied" + + log_path = os.path.join(endpoint.endpoint_path(), "compute.log") + with open(log_path, "r") as log_file: + logs = log_file.read() + # Check that walproposer was not restarted + assert "restarting walproposer" not in logs diff --git a/test_runner/regress/test_pageserver_getpage_throttle.py b/test_runner/regress/test_pageserver_getpage_throttle.py index 111285b40c..4c9eac5cd7 100644 --- a/test_runner/regress/test_pageserver_getpage_throttle.py +++ b/test_runner/regress/test_pageserver_getpage_throttle.py @@ -1,3 +1,4 @@ +import copy import json import uuid @@ -116,3 +117,58 @@ def test_pageserver_getpage_throttle(neon_env_builder: NeonEnvBuilder, pg_bin: P assert ( duration_secs >= 10 * actual_smgr_query_seconds ), "smgr metrics should not include throttle wait time" + + +throttle_config_with_field_fair_set = { + "task_kinds": ["PageRequestHandler"], + "fair": True, + "initial": 27, + "refill_interval": "43s", + "refill_amount": 23, + "max": 42, +} + + +def assert_throttle_config_with_field_fair_set(conf): + """ + Field `fair` is ignored, so, responses don't contain it + """ + without_fair = copy.deepcopy(throttle_config_with_field_fair_set) + without_fair.pop("fair") + + assert conf == without_fair + + +def test_throttle_fair_config_is_settable_but_ignored_in_mgmt_api(neon_env_builder: NeonEnvBuilder): + """ + To be removed after https://github.com/neondatabase/neon/pull/8539 is rolled out. + """ + env = neon_env_builder.init_start() + ps_http = env.pageserver.http_client() + # with_fair config should still be settable + ps_http.set_tenant_config( + env.initial_tenant, + {"timeline_get_throttle": throttle_config_with_field_fair_set}, + ) + conf = ps_http.tenant_config(env.initial_tenant) + assert_throttle_config_with_field_fair_set(conf.effective_config["timeline_get_throttle"]) + assert_throttle_config_with_field_fair_set( + conf.tenant_specific_overrides["timeline_get_throttle"] + ) + + +def test_throttle_fair_config_is_settable_but_ignored_in_config_toml( + neon_env_builder: NeonEnvBuilder, +): + """ + To be removed after https://github.com/neondatabase/neon/pull/8539 is rolled out. + """ + + def set_tenant_config(ps_cfg): + ps_cfg["tenant_config"] = {"timeline_get_throttle": throttle_config_with_field_fair_set} + + neon_env_builder.pageserver_config_override = set_tenant_config + env = neon_env_builder.init_start() + ps_http = env.pageserver.http_client() + conf = ps_http.tenant_config(env.initial_tenant) + assert_throttle_config_with_field_fair_set(conf.effective_config["timeline_get_throttle"]) diff --git a/test_runner/regress/test_pageserver_layer_rolling.py b/test_runner/regress/test_pageserver_layer_rolling.py index 66b6185aaa..f6404d68ac 100644 --- a/test_runner/regress/test_pageserver_layer_rolling.py +++ b/test_runner/regress/test_pageserver_layer_rolling.py @@ -247,9 +247,10 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder): compaction_period_s = 10 + checkpoint_distance = 1024**3 tenant_conf = { # Large space + time thresholds: effectively disable these limits - "checkpoint_distance": f"{1024 ** 4}", + "checkpoint_distance": f"{checkpoint_distance}", "checkpoint_timeout": "3600s", "compaction_period": f"{compaction_period_s}s", } @@ -269,7 +270,11 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder): for tenant, timeline, last_flush_lsn in last_flush_lsns: http_client = env.pageserver.http_client() initdb_lsn = Lsn(http_client.timeline_detail(tenant, timeline)["initdb_lsn"]) - total_bytes_ingested += last_flush_lsn - initdb_lsn + this_timeline_ingested = last_flush_lsn - initdb_lsn + assert ( + this_timeline_ingested < checkpoint_distance * 0.8 + ), "this test is supposed to fill InMemoryLayer" + total_bytes_ingested += this_timeline_ingested log.info(f"Ingested {total_bytes_ingested} bytes since initdb (vs max dirty {max_dirty_data})") assert total_bytes_ingested > max_dirty_data diff --git a/test_runner/regress/test_read_validation.py b/test_runner/regress/test_read_validation.py index d128c60a99..1ac881553f 100644 --- a/test_runner/regress/test_read_validation.py +++ b/test_runner/regress/test_read_validation.py @@ -19,11 +19,6 @@ def test_read_validation(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "test_read_validation", - # Use protocol version 2, because the code that constructs the V1 messages - # assumes that a primary always wants to read the latest version of a page, - # and therefore doesn't work with the test functions below to read an older - # page version. - config_lines=["neon.protocol_version=2"], ) with closing(endpoint.connect()) as con: @@ -142,11 +137,6 @@ def test_read_validation_neg(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "test_read_validation_neg", - # Use protocol version 2, because the code that constructs the V1 messages - # assumes that a primary always wants to read the latest version of a page, - # and therefore doesn't work with the test functions below to read an older - # page version. - config_lines=["neon.protocol_version=2"], ) with closing(endpoint.connect()) as con: diff --git a/test_runner/regress/test_readonly_node.py b/test_runner/regress/test_readonly_node.py index ba8b91e84d..368f60127e 100644 --- a/test_runner/regress/test_readonly_node.py +++ b/test_runner/regress/test_readonly_node.py @@ -1,7 +1,15 @@ +import time + import pytest from fixtures.common_types import Lsn from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnv +from fixtures.neon_fixtures import ( + Endpoint, + NeonEnv, + NeonEnvBuilder, + last_flush_lsn_upload, + tenant_get_shards, +) from fixtures.pageserver.utils import wait_for_last_record_lsn from fixtures.utils import query_scalar @@ -17,7 +25,12 @@ def test_readonly_node(neon_simple_env: NeonEnv): env.neon_cli.create_branch("test_readonly_node", "empty") endpoint_main = env.endpoints.create_start("test_readonly_node") - env.pageserver.allowed_errors.append(".*basebackup .* failed: invalid basebackup lsn.*") + env.pageserver.allowed_errors.extend( + [ + ".*basebackup .* failed: invalid basebackup lsn.*", + ".*page_service.*handle_make_lsn_lease.*.*tried to request a page version that was garbage collected", + ] + ) main_pg_conn = endpoint_main.connect() main_cur = main_pg_conn.cursor() @@ -105,6 +118,103 @@ def test_readonly_node(neon_simple_env: NeonEnv): ) +def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): + """ + Test static endpoint is protected from GC by acquiring and renewing lsn leases. + """ + + neon_env_builder.num_pageservers = 2 + # GC is manual triggered. + env = neon_env_builder.init_start( + initial_tenant_conf={ + # small checkpointing and compaction targets to ensure we generate many upload operations + "checkpoint_distance": f"{128 * 1024}", + "compaction_threshold": "1", + "compaction_target_size": f"{128 * 1024}", + # no PITR horizon, we specify the horizon when we request on-demand GC + "pitr_interval": "0s", + # disable background compaction and GC. We invoke it manually when we want it to happen. + "gc_period": "0s", + "compaction_period": "0s", + # create image layers eagerly, so that GC can remove some layers + "image_creation_threshold": "1", + "image_layer_creation_check_threshold": "0", + # Short lease length to fit test. + "lsn_lease_length": "3s", + }, + initial_tenant_shard_count=2, + ) + + ROW_COUNT = 500 + + def generate_updates_on_main( + env: NeonEnv, + ep_main: Endpoint, + data: int, + start=1, + end=ROW_COUNT, + ) -> Lsn: + """ + Generates some load on main branch that results in some uploads. + """ + with ep_main.cursor() as cur: + cur.execute( + f"INSERT INTO t0 (v0, v1) SELECT g, '{data}' FROM generate_series({start}, {end}) g ON CONFLICT (v0) DO UPDATE SET v1 = EXCLUDED.v1" + ) + cur.execute("VACUUM t0") + last_flush_lsn = last_flush_lsn_upload( + env, ep_main, env.initial_tenant, env.initial_timeline + ) + return last_flush_lsn + + # Insert some records on main branch + with env.endpoints.create_start("main") as ep_main: + with ep_main.cursor() as cur: + cur.execute("CREATE TABLE t0(v0 int primary key, v1 text)") + lsn = None + for i in range(2): + lsn = generate_updates_on_main(env, ep_main, i) + + with env.endpoints.create_start( + branch_name="main", + endpoint_id="static", + lsn=lsn, + ) as ep_static: + with ep_static.cursor() as cur: + cur.execute("SELECT count(*) FROM t0") + assert cur.fetchone() == (ROW_COUNT,) + + time.sleep(3) + + generate_updates_on_main(env, ep_main, i, end=100) + + # Trigger GC + for shard, ps in tenant_get_shards(env, env.initial_tenant): + client = ps.http_client() + gc_result = client.timeline_gc(shard, env.initial_timeline, 0) + log.info(f"{gc_result=}") + + assert ( + gc_result["layers_removed"] == 0 + ), "No layers should be removed, old layers are guarded by leases." + + with ep_static.cursor() as cur: + cur.execute("SELECT count(*) FROM t0") + assert cur.fetchone() == (ROW_COUNT,) + + # Do some update so we can increment latest_gc_cutoff + generate_updates_on_main(env, ep_main, i, end=100) + + # Now trigger GC again, layers should be removed. + time.sleep(4) + for shard, ps in tenant_get_shards(env, env.initial_tenant): + client = ps.http_client() + gc_result = client.timeline_gc(shard, env.initial_timeline, 0) + log.info(f"{gc_result=}") + + assert gc_result["layers_removed"] > 0, "Old layers should be removed after leases expired." + + # Similar test, but with more data, and we force checkpoints def test_timetravel(neon_simple_env: NeonEnv): env = neon_simple_env diff --git a/test_runner/regress/test_tenant_delete.py b/test_runner/regress/test_tenant_delete.py index 448a28dc31..7ee949e8d3 100644 --- a/test_runner/regress/test_tenant_delete.py +++ b/test_runner/regress/test_tenant_delete.py @@ -1,7 +1,9 @@ +import json from threading import Thread import pytest from fixtures.common_types import Lsn, TenantId, TimelineId +from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, PgBin, @@ -17,6 +19,8 @@ from fixtures.pageserver.utils import ( from fixtures.remote_storage import RemoteStorageKind, s3_storage from fixtures.utils import run_pg_bench_small, wait_until from requests.exceptions import ReadTimeout +from werkzeug.wrappers.request import Request +from werkzeug.wrappers.response import Response def error_tolerant_delete(ps_http, tenant_id): @@ -322,7 +326,7 @@ def test_tenant_delete_races_timeline_creation(neon_env_builder: NeonEnvBuilder) env.pageserver.stop() -def test_tenant_delete_scrubber(pg_bin: PgBin, neon_env_builder: NeonEnvBuilder): +def test_tenant_delete_scrubber(pg_bin: PgBin, make_httpserver, neon_env_builder: NeonEnvBuilder): """ Validate that creating and then deleting the tenant both survives the scrubber, and that one can run the scrubber without problems. @@ -347,6 +351,45 @@ def test_tenant_delete_scrubber(pg_bin: PgBin, neon_env_builder: NeonEnvBuilder) healthy, _ = env.storage_scrubber.scan_metadata() assert healthy + timeline_lsns = { + "tenant_id": f"{tenant_id}", + "timeline_id": f"{timeline_id}", + "timeline_start_lsn": f"{last_flush_lsn}", + "backup_lsn": f"{last_flush_lsn}", + } + + cloud_admin_url = f"http://{make_httpserver.host}:{make_httpserver.port}/" + cloud_admin_token = "" + + def get_branches(request: Request): + # Compare definition with `BranchData` struct + dummy_data = { + "id": "test-branch-id", + "created_at": "", # TODO + "updated_at": "", # TODO + "name": "testbranchname", + "project_id": "test-project-id", + "timeline_id": f"{timeline_id}", + "default": False, + "deleted": False, + "logical_size": 42000, + "physical_size": 42000, + "written_size": 42000, + } + # This test does all its own compute configuration (by passing explicit pageserver ID to Workload functions), + # so we send controller notifications to /dev/null to prevent it fighting the test for control of the compute. + log.info(f"got get_branches request: {request.json}") + return Response(json.dumps(dummy_data), content_type="application/json", status=200) + + make_httpserver.expect_request("/branches", method="GET").respond_with_handler(get_branches) + + healthy, _ = env.storage_scrubber.scan_metadata_safekeeper( + timeline_lsns=[timeline_lsns], + cloud_admin_api_url=cloud_admin_url, + cloud_admin_api_token=cloud_admin_token, + ) + assert healthy + env.start() ps_http = env.pageserver.http_client() ps_http.tenant_delete(tenant_id) @@ -354,3 +397,10 @@ def test_tenant_delete_scrubber(pg_bin: PgBin, neon_env_builder: NeonEnvBuilder) healthy, _ = env.storage_scrubber.scan_metadata() assert healthy + + healthy, _ = env.storage_scrubber.scan_metadata_safekeeper( + timeline_lsns=[timeline_lsns], + cloud_admin_api_url=cloud_admin_url, + cloud_admin_api_token=cloud_admin_token, + ) + assert healthy diff --git a/test_runner/regress/test_timeline_archive.py b/test_runner/regress/test_timeline_archive.py index b774c7c9fe..7f158ad251 100644 --- a/test_runner/regress/test_timeline_archive.py +++ b/test_runner/regress/test_timeline_archive.py @@ -94,3 +94,29 @@ def test_timeline_archive(neon_simple_env: NeonEnv): timeline_id=parent_timeline_id, state=TimelineArchivalState.ARCHIVED, ) + + # Test that the leaf can't be unarchived + with pytest.raises( + PageserverApiException, + match="ancestor is archived", + ) as exc: + assert timeline_path.exists() + + ps_http.timeline_archival_config( + tenant_id=env.initial_tenant, + timeline_id=leaf_timeline_id, + state=TimelineArchivalState.UNARCHIVED, + ) + + # Unarchive works for the leaf if the parent gets unarchived first + ps_http.timeline_archival_config( + tenant_id=env.initial_tenant, + timeline_id=parent_timeline_id, + state=TimelineArchivalState.UNARCHIVED, + ) + + ps_http.timeline_archival_config( + tenant_id=env.initial_tenant, + timeline_id=leaf_timeline_id, + state=TimelineArchivalState.UNARCHIVED, + ) diff --git a/test_runner/regress/test_wal_receiver.py b/test_runner/regress/test_wal_receiver.py index 6582b34218..229d3efd8e 100644 --- a/test_runner/regress/test_wal_receiver.py +++ b/test_runner/regress/test_wal_receiver.py @@ -62,6 +62,12 @@ def test_pageserver_lsn_wait_error_safekeeper_stop(neon_env_builder: NeonEnvBuil elements_to_insert = 1_000_000 expected_timeout_error = f"Timed out while waiting for WAL record at LSN {future_lsn} to arrive" env.pageserver.allowed_errors.append(f".*{expected_timeout_error}.*") + # we configure wait_lsn_timeout to a shorter value than the lagging_wal_timeout / walreceiver_connect_timeout + # => after we run into a timeout and reconnect to a different SK, more time than wait_lsn_timeout has passed + # ==> we log this error + env.pageserver.allowed_errors.append( + ".*ingesting record with timestamp lagging more than wait_lsn_timeout.*" + ) insert_test_elements(env, tenant_id, start=0, count=elements_to_insert) diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index b6910406e2..48388a5b59 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit b6910406e2d05a2c94baa2e530ec882733047759 +Subproject commit 48388a5b597c81c09e28c016650a7156b48717a1 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index 76063bff63..8aa1ded772 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit 76063bff638ccce7afa99fc9037ac51338b9823d +Subproject commit 8aa1ded7726d416ac8e02600aad387a353478fc7 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 8efa089aa7..95132feffe 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 8efa089aa7786381543a4f9efc69b92d43eab8c0 +Subproject commit 95132feffe277ce84309d93a42e9aadfd2cb0437 diff --git a/vendor/revisions.json b/vendor/revisions.json index 50cc99c2f1..319e648488 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,14 +1,14 @@ { "v16": [ "16.4", - "8efa089aa7786381543a4f9efc69b92d43eab8c0" + "95132feffe277ce84309d93a42e9aadfd2cb0437" ], "v15": [ "15.8", - "76063bff638ccce7afa99fc9037ac51338b9823d" + "8aa1ded7726d416ac8e02600aad387a353478fc7" ], "v14": [ "14.13", - "b6910406e2d05a2c94baa2e530ec882733047759" + "48388a5b597c81c09e28c016650a7156b48717a1" ] }