diff --git a/.dockerignore b/.dockerignore index 3c4a748cf7..9e2d2e7108 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,9 +5,7 @@ !Cargo.toml !Makefile !rust-toolchain.toml -!scripts/combine_control_files.py !scripts/ninstall.sh -!vm-cgconfig.conf !docker-compose/run-tests.sh # Directories @@ -17,15 +15,12 @@ !compute_tools/ !control_plane/ !libs/ -!neon_local/ !pageserver/ -!patches/ !pgxn/ !proxy/ !storage_scrubber/ !safekeeper/ !storage_broker/ !storage_controller/ -!trace/ !vendor/postgres-*/ !workspace_hack/ diff --git a/.github/actions/run-python-test-set/action.yml b/.github/actions/run-python-test-set/action.yml index 4008cd0d36..330e875d56 100644 --- a/.github/actions/run-python-test-set/action.yml +++ b/.github/actions/run-python-test-set/action.yml @@ -218,6 +218,9 @@ runs: name: compatibility-snapshot-${{ runner.arch }}-${{ inputs.build_type }}-pg${{ inputs.pg_version }} # Directory is created by test_compatibility.py::test_create_snapshot, keep the path in sync with the test path: /tmp/test_output/compatibility_snapshot_pg${{ inputs.pg_version }}/ + # The lack of compatibility snapshot shouldn't fail the job + # (for example if we didn't run the test for non build-and-test workflow) + skip-if-does-not-exist: true - name: Upload test results if: ${{ !cancelled() }} diff --git a/.github/actions/upload/action.yml b/.github/actions/upload/action.yml index edcece7d2b..8a4cfe2eff 100644 --- a/.github/actions/upload/action.yml +++ b/.github/actions/upload/action.yml @@ -7,6 +7,10 @@ inputs: path: description: "A directory or file to upload" required: true + skip-if-does-not-exist: + description: "Allow to skip if path doesn't exist, fail otherwise" + default: false + required: false prefix: description: "S3 prefix. Default is '${GITHUB_SHA}/${GITHUB_RUN_ID}/${GITHUB_RUN_ATTEMPT}'" required: false @@ -15,10 +19,12 @@ runs: using: "composite" steps: - name: Prepare artifact + id: prepare-artifact shell: bash -euxo pipefail {0} env: SOURCE: ${{ inputs.path }} ARCHIVE: /tmp/uploads/${{ inputs.name }}.tar.zst + SKIP_IF_DOES_NOT_EXIST: ${{ inputs.skip-if-does-not-exist }} run: | mkdir -p $(dirname $ARCHIVE) @@ -33,14 +39,22 @@ runs: elif [ -f ${SOURCE} ]; then time tar -cf ${ARCHIVE} --zstd ${SOURCE} elif ! ls ${SOURCE} > /dev/null 2>&1; then - echo >&2 "${SOURCE} does not exist" - exit 2 + if [ "${SKIP_IF_DOES_NOT_EXIST}" = "true" ]; then + echo 'SKIPPED=true' >> $GITHUB_OUTPUT + exit 0 + else + echo >&2 "${SOURCE} does not exist" + exit 2 + fi else echo >&2 "${SOURCE} is neither a directory nor a file, do not know how to handle it" exit 3 fi + echo 'SKIPPED=false' >> $GITHUB_OUTPUT + - name: Upload artifact + if: ${{ steps.prepare-artifact.outputs.SKIPPED == 'false' }} shell: bash -euxo pipefail {0} env: SOURCE: ${{ inputs.path }} diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index a759efb56c..e7193cfe19 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -193,16 +193,15 @@ jobs: with: submodules: true -# Disabled for now -# - name: Restore cargo deps cache -# id: cache_cargo -# uses: actions/cache@v4 -# with: -# path: | -# !~/.cargo/registry/src -# ~/.cargo/git/ -# target/ -# key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-clippy-${{ hashFiles('rust-toolchain.toml') }}-${{ hashFiles('Cargo.lock') }} + - name: Cache cargo deps + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + !~/.cargo/registry/src + ~/.cargo/git + target + key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('./Cargo.lock') }}-${{ hashFiles('./rust-toolchain.toml') }}-rust # Some of our rust modules use FFI and need those to be checked - name: Get postgres headers diff --git a/.github/workflows/report-workflow-stats.yml b/.github/workflows/report-workflow-stats.yml new file mode 100644 index 0000000000..6abeff7695 --- /dev/null +++ b/.github/workflows/report-workflow-stats.yml @@ -0,0 +1,41 @@ +name: Report Workflow Stats + +on: + workflow_run: + workflows: + - Add `external` label to issues and PRs created by external users + - Benchmarking + - Build and Test + - Build and Test Locally + - Build build-tools image + - Check Permissions + - Check build-tools image + - Check neon with extra platform builds + - Cloud Regression Test + - Create Release Branch + - Handle `approved-for-ci-run` label + - Lint GitHub Workflows + - Notify Slack channel about upcoming release + - Periodic pagebench performance test on dedicated EC2 machine in eu-central-1 region + - Pin build-tools image + - Prepare benchmarking databases by restoring dumps + - Push images to ACR + - Test Postgres client libraries + - Trigger E2E Tests + - cleanup caches by a branch + types: [completed] + +jobs: + gh-workflow-stats: + name: Github Workflow Stats + runs-on: ubuntu-22.04 + permissions: + actions: read + steps: + - name: Export GH Workflow Stats + uses: neondatabase/gh-workflow-stats-action@v0.1.4 + with: + DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }} + DB_TABLE: "gh_workflow_stats_neon" + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GH_RUN_ID: ${{ github.event.workflow_run.id }} diff --git a/CODEOWNERS b/CODEOWNERS index 606dbb4e22..f8ed4be816 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,5 +1,6 @@ /compute_tools/ @neondatabase/control-plane @neondatabase/compute /storage_controller @neondatabase/storage +/storage_scrubber @neondatabase/storage /libs/pageserver_api/ @neondatabase/storage /libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage /libs/remote_storage/ @neondatabase/storage diff --git a/Cargo.lock b/Cargo.lock index 6ae5aac127..5edf5cf7b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -666,34 +666,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum" -version = "0.6.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" -dependencies = [ - "async-trait", - "axum-core 0.3.4", - "bitflags 1.3.2", - "bytes", - "futures-util", - "http 0.2.9", - "http-body 0.4.5", - "hyper 0.14.30", - "itoa", - "matchit 0.7.0", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "sync_wrapper 0.1.2", - "tower", - "tower-layer", - "tower-service", -] - [[package]] name = "axum" version = "0.7.5" @@ -701,7 +673,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" dependencies = [ "async-trait", - "axum-core 0.4.5", + "axum-core", "base64 0.21.1", "bytes", "futures-util", @@ -731,23 +703,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum-core" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" -dependencies = [ - "async-trait", - "bytes", - "futures-util", - "http 0.2.9", - "http-body 0.4.5", - "mime", - "rustversion", - "tower-layer", - "tower-service", -] - [[package]] name = "axum-core" version = "0.4.5" @@ -971,7 +926,7 @@ dependencies = [ "clang-sys", "itertools 0.12.1", "log", - "prettyplease 0.2.17", + "prettyplease", "proc-macro2", "quote", "regex", @@ -1865,6 +1820,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" dependencies = [ "base16ct 0.2.0", + "base64ct", "crypto-bigint 0.5.5", "digest", "ff 0.13.0", @@ -1874,6 +1830,8 @@ dependencies = [ "pkcs8 0.10.2", "rand_core 0.6.4", "sec1 0.7.3", + "serde_json", + "serdect", "subtle", "zeroize", ] @@ -2454,15 +2412,6 @@ dependencies = [ "digest", ] -[[package]] -name = "home" -version = "0.5.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" -dependencies = [ - "windows-sys 0.52.0", -] - [[package]] name = "hostname" version = "0.4.0" @@ -2657,14 +2606,15 @@ dependencies = [ [[package]] name = "hyper-timeout" -version = "0.4.1" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793" dependencies = [ - "hyper 0.14.30", + "hyper 1.4.1", + "hyper-util", "pin-project-lite", "tokio", - "tokio-io-timeout", + "tower-service", ] [[package]] @@ -3470,7 +3420,7 @@ dependencies = [ "opentelemetry-http", "opentelemetry-proto", "opentelemetry_sdk", - "prost 0.13.3", + "prost", "reqwest 0.12.4", "thiserror", ] @@ -3483,8 +3433,8 @@ checksum = "30ee9f20bff9c984511a02f082dc8ede839e4a9bf15cc2487c8d6fea5ad850d9" dependencies = [ "opentelemetry", "opentelemetry_sdk", - "prost 0.13.3", - "tonic 0.12.3", + "prost", + "tonic", ] [[package]] @@ -4090,6 +4040,8 @@ dependencies = [ "bytes", "fallible-iterator", "postgres-protocol", + "serde", + "serde_json", ] [[package]] @@ -4178,16 +4130,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "prettyplease" -version = "0.1.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86" -dependencies = [ - "proc-macro2", - "syn 1.0.109", -] - [[package]] name = "prettyplease" version = "0.2.17" @@ -4258,16 +4200,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "prost" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" -dependencies = [ - "bytes", - "prost-derive 0.11.9", -] - [[package]] name = "prost" version = "0.13.3" @@ -4275,42 +4207,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f" dependencies = [ "bytes", - "prost-derive 0.13.3", + "prost-derive", ] [[package]] name = "prost-build" -version = "0.11.9" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270" +checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", - "heck 0.4.1", - "itertools 0.10.5", - "lazy_static", + "heck 0.5.0", + "itertools 0.12.1", "log", "multimap", + "once_cell", "petgraph", - "prettyplease 0.1.25", - "prost 0.11.9", + "prettyplease", + "prost", "prost-types", "regex", - "syn 1.0.109", + "syn 2.0.52", "tempfile", - "which", -] - -[[package]] -name = "prost-derive" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" -dependencies = [ - "anyhow", - "itertools 0.10.5", - "proc-macro2", - "quote", - "syn 1.0.109", ] [[package]] @@ -4328,11 +4246,11 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.11.9" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13" +checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670" dependencies = [ - "prost 0.11.9", + "prost", ] [[package]] @@ -5094,6 +5012,21 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls" +version = "0.23.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebbbdb961df0ad3f2652da8f3fdc4b36122f568f968f45ad3316f26c025c677b" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.2", + "subtle", + "zeroize", +] + [[package]] name = "rustls-native-certs" version = "0.6.2" @@ -5119,6 +5052,19 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.1", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.2" @@ -5194,6 +5140,7 @@ dependencies = [ "fail", "futures", "hex", + "http 1.1.0", "humantime", "hyper 0.14.30", "metrics", @@ -5314,6 +5261,7 @@ dependencies = [ "der 0.7.8", "generic-array", "pkcs8 0.10.2", + "serdect", "subtle", "zeroize", ] @@ -5568,6 +5516,16 @@ dependencies = [ "syn 2.0.52", ] +[[package]] +name = "serdect" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a84f14a19e9a014bb9f4512488d9829a68e04ecabffb0f9904cd1ace94598177" +dependencies = [ + "base16ct 0.2.0", + "serde", +] + [[package]] name = "sha1" version = "0.10.5" @@ -5750,19 +5708,22 @@ version = "0.1.0" dependencies = [ "anyhow", "async-stream", + "bytes", "clap", "const_format", "futures", "futures-core", "futures-util", + "http-body-util", "humantime", - "hyper 0.14.30", + "hyper 1.4.1", + "hyper-util", "metrics", "once_cell", "parking_lot 0.12.1", - "prost 0.11.9", + "prost", "tokio", - "tonic 0.9.2", + "tonic", "tonic-build", "tracing", "utils", @@ -6306,6 +6267,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls 0.23.7", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.16" @@ -6397,29 +6369,30 @@ dependencies = [ [[package]] name = "tonic" -version = "0.9.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" +checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" dependencies = [ "async-stream", "async-trait", - "axum 0.6.20", - "base64 0.21.1", + "axum", + "base64 0.22.1", "bytes", - "futures-core", - "futures-util", - "h2 0.3.26", - "http 0.2.9", - "http-body 0.4.5", - "hyper 0.14.30", + "h2 0.4.4", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "hyper 1.4.1", "hyper-timeout", + "hyper-util", "percent-encoding", "pin-project", - "prost 0.11.9", - "rustls-native-certs 0.6.2", - "rustls-pemfile 1.0.2", + "prost", + "rustls-native-certs 0.8.0", + "rustls-pemfile 2.1.1", + "socket2", "tokio", - "tokio-rustls 0.24.0", + "tokio-rustls 0.26.0", "tokio-stream", "tower", "tower-layer", @@ -6428,37 +6401,17 @@ dependencies = [ ] [[package]] -name = "tonic" +name = "tonic-build" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11" dependencies = [ - "async-trait", - "base64 0.22.1", - "bytes", - "http 1.1.0", - "http-body 1.0.0", - "http-body-util", - "percent-encoding", - "pin-project", - "prost 0.13.3", - "tokio-stream", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tonic-build" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6fdaae4c2c638bb70fe42803a26fbd6fc6ac8c72f5c59f67ecc2a2dcabf4b07" -dependencies = [ - "prettyplease 0.1.25", + "prettyplease", "proc-macro2", "prost-build", + "prost-types", "quote", - "syn 1.0.109", + "syn 2.0.52", ] [[package]] @@ -6864,7 +6817,7 @@ name = "vm_monitor" version = "0.1.0" dependencies = [ "anyhow", - "axum 0.7.5", + "axum", "cgroups-rs", "clap", "futures", @@ -7095,18 +7048,6 @@ dependencies = [ "rustls-pki-types", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix", -] - [[package]] name = "whoami" version = "1.5.1" @@ -7335,9 +7276,10 @@ version = "0.1.0" dependencies = [ "ahash", "anyhow", + "axum", + "axum-core", "base64 0.21.1", "base64ct", - "bitflags 2.4.1", "bytes", "camino", "cc", @@ -7365,7 +7307,6 @@ dependencies = [ "hyper 1.4.1", "hyper-util", "indexmap 1.9.3", - "itertools 0.10.5", "itertools 0.12.1", "lazy_static", "libc", @@ -7377,15 +7318,16 @@ dependencies = [ "num-traits", "once_cell", "parquet", + "postgres-types", + "prettyplease", "proc-macro2", - "prost 0.11.9", + "prost", "quote", "rand 0.8.5", "regex", "regex-automata 0.4.3", "regex-syntax 0.8.2", "reqwest 0.12.4", - "rustls 0.21.11", "scopeguard", "serde", "serde_json", @@ -7401,9 +7343,11 @@ dependencies = [ "time", "time-macros", "tokio", - "tokio-rustls 0.24.0", + "tokio-postgres", + "tokio-stream", "tokio-util", "toml_edit", + "tonic", "tower", "tracing", "tracing-core", diff --git a/Cargo.toml b/Cargo.toml index ed7dda235a..dde80f5020 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -130,7 +130,7 @@ pbkdf2 = { version = "0.12.1", features = ["simple", "std"] } pin-project-lite = "0.2" procfs = "0.16" prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency -prost = "0.11" +prost = "0.13" rand = "0.8" redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" @@ -178,7 +178,7 @@ tokio-tar = "0.3" tokio-util = { version = "0.7.10", features = ["io", "rt"] } toml = "0.8" toml_edit = "0.22" -tonic = {version = "0.9", features = ["tls", "tls-roots"]} +tonic = {version = "0.12.3", features = ["tls", "tls-roots"]} tower-service = "0.3.2" tracing = "0.1" tracing-error = "0.2" @@ -246,7 +246,7 @@ criterion = "0.5.1" rcgen = "0.12" rstest = "0.18" camino-tempfile = "1.0.2" -tonic-build = "0.9" +tonic-build = "0.12" [patch.crates-io] diff --git a/Makefile b/Makefile index b9bb1c147d..5e227ed3f5 100644 --- a/Makefile +++ b/Makefile @@ -168,27 +168,27 @@ postgres-check-%: postgres-% neon-pg-ext-%: postgres-% +@echo "Compiling neon $*" mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ -C $(POSTGRES_INSTALL_DIR)/build/neon-$* \ -f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install +@echo "Compiling neon_walredo $*" mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ -C $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$* \ -f $(ROOT_PROJECT_DIR)/pgxn/neon_walredo/Makefile install +@echo "Compiling neon_rmgr $*" mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-rmgr-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ -C $(POSTGRES_INSTALL_DIR)/build/neon-rmgr-$* \ -f $(ROOT_PROJECT_DIR)/pgxn/neon_rmgr/Makefile install +@echo "Compiling neon_test_utils $*" mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ -C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \ -f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install +@echo "Compiling neon_utils $*" mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ -C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \ -f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install @@ -220,7 +220,7 @@ neon-pg-clean-ext-%: walproposer-lib: neon-pg-ext-v17 +@echo "Compiling walproposer-lib" mkdir -p $(POSTGRES_INSTALL_DIR)/build/walproposer-lib - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config COPT='$(COPT)' \ -C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \ -f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile walproposer-lib cp $(POSTGRES_INSTALL_DIR)/v17/lib/libpgport.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib @@ -333,7 +333,7 @@ postgres-%-pgindent: postgres-%-pg-bsd-indent postgres-%-typedefs.list # Indent pxgn/neon. .PHONY: neon-pgindent neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17 - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \ + $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config COPT='$(COPT)' \ FIND_TYPEDEF=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/find_typedef \ INDENT=$(POSTGRES_INSTALL_DIR)/build/v17/src/tools/pg_bsd_indent/pg_bsd_indent \ PGINDENT_SCRIPT=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/pgindent/pgindent \ diff --git a/compute/Dockerfile.compute-node b/compute/Dockerfile.compute-node index 5332b9ca1f..15afb9897f 100644 --- a/compute/Dockerfile.compute-node +++ b/compute/Dockerfile.compute-node @@ -109,13 +109,30 @@ RUN apt update && \ libcgal-dev libgdal-dev libgmp-dev libmpfr-dev libopenscenegraph-dev libprotobuf-c-dev \ protobuf-c-compiler xsltproc + +# Postgis 3.5.0 requires SFCGAL 1.4+ +# +# It would be nice to update all versions together, but we must solve the SFCGAL dependency first. # SFCGAL > 1.3 requires CGAL > 5.2, Bullseye's libcgal-dev is 5.2 -RUN case "${PG_VERSION}" in "v17") \ - mkdir -p /sfcgal && \ - echo "Postgis doensn't yet support PG17 (needs 3.4.3, if not higher)" && exit 0;; \ +# and also we must check backward compatibility with older versions of PostGIS. +# +# Use new version only for v17 +RUN case "${PG_VERSION}" in \ + "v17") \ + export SFCGAL_VERSION=1.4.1 \ + export SFCGAL_CHECKSUM=1800c8a26241588f11cddcf433049e9b9aea902e923414d2ecef33a3295626c3 \ + ;; \ + "v14" | "v15" | "v16") \ + export SFCGAL_VERSION=1.3.10 \ + export SFCGAL_CHECKSUM=4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 \ + ;; \ + *) \ + echo "unexpected PostgreSQL version" && exit 1 \ + ;; \ esac && \ - wget https://gitlab.com/Oslandia/SFCGAL/-/archive/v1.3.10/SFCGAL-v1.3.10.tar.gz -O SFCGAL.tar.gz && \ - echo "4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 SFCGAL.tar.gz" | sha256sum --check && \ + mkdir -p /sfcgal && \ + wget https://gitlab.com/sfcgal/SFCGAL/-/archive/v${SFCGAL_VERSION}/SFCGAL-v${SFCGAL_VERSION}.tar.gz -O SFCGAL.tar.gz && \ + echo "${SFCGAL_CHECKSUM} SFCGAL.tar.gz" | sha256sum --check && \ mkdir sfcgal-src && cd sfcgal-src && tar xzf ../SFCGAL.tar.gz --strip-components=1 -C . && \ cmake -DCMAKE_BUILD_TYPE=Release . && make -j $(getconf _NPROCESSORS_ONLN) && \ DESTDIR=/sfcgal make install -j $(getconf _NPROCESSORS_ONLN) && \ @@ -123,15 +140,27 @@ RUN case "${PG_VERSION}" in "v17") \ ENV PATH="/usr/local/pgsql/bin:$PATH" -RUN case "${PG_VERSION}" in "v17") \ - echo "Postgis doensn't yet support PG17 (needs 3.4.3, if not higher)" && exit 0;; \ +# Postgis 3.5.0 supports v17 +RUN case "${PG_VERSION}" in \ + "v17") \ + export POSTGIS_VERSION=3.5.0 \ + export POSTGIS_CHECKSUM=ca698a22cc2b2b3467ac4e063b43a28413f3004ddd505bdccdd74c56a647f510 \ + ;; \ + "v14" | "v15" | "v16") \ + export POSTGIS_VERSION=3.3.3 \ + export POSTGIS_CHECKSUM=74eb356e3f85f14233791013360881b6748f78081cc688ff9d6f0f673a762d13 \ + ;; \ + *) \ + echo "unexpected PostgreSQL version" && exit 1 \ + ;; \ esac && \ - wget https://download.osgeo.org/postgis/source/postgis-3.3.3.tar.gz -O postgis.tar.gz && \ - echo "74eb356e3f85f14233791013360881b6748f78081cc688ff9d6f0f673a762d13 postgis.tar.gz" | sha256sum --check && \ + wget https://download.osgeo.org/postgis/source/postgis-${POSTGIS_VERSION}.tar.gz -O postgis.tar.gz && \ + echo "${POSTGIS_CHECKSUM} postgis.tar.gz" | sha256sum --check && \ mkdir postgis-src && cd postgis-src && tar xzf ../postgis.tar.gz --strip-components=1 -C . && \ find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt &&\ ./autogen.sh && \ ./configure --with-sfcgal=/usr/local/bin/sfcgal-config && \ + make -j $(getconf _NPROCESSORS_ONLN) && \ make -j $(getconf _NPROCESSORS_ONLN) install && \ cd extensions/postgis && \ make clean && \ @@ -152,11 +181,27 @@ RUN case "${PG_VERSION}" in "v17") \ cp /usr/local/pgsql/share/extension/address_standardizer.control /extensions/postgis && \ cp /usr/local/pgsql/share/extension/address_standardizer_data_us.control /extensions/postgis -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ +# Uses versioned libraries, i.e. libpgrouting-3.4 +# and may introduce function signature changes between releases +# i.e. release 3.5.0 has new signature for pg_dijkstra function +# +# Use new version only for v17 +# last release v3.6.2 - Mar 30, 2024 +RUN case "${PG_VERSION}" in \ + "v17") \ + export PGROUTING_VERSION=3.6.2 \ + export PGROUTING_CHECKSUM=f4a1ed79d6f714e52548eca3bb8e5593c6745f1bde92eb5fb858efd8984dffa2 \ + ;; \ + "v14" | "v15" | "v16") \ + export PGROUTING_VERSION=3.4.2 \ + export PGROUTING_CHECKSUM=cac297c07d34460887c4f3b522b35c470138760fe358e351ad1db4edb6ee306e \ + ;; \ + *) \ + echo "unexpected PostgreSQL version" && exit 1 \ + ;; \ esac && \ - wget https://github.com/pgRouting/pgrouting/archive/v3.4.2.tar.gz -O pgrouting.tar.gz && \ - echo "cac297c07d34460887c4f3b522b35c470138760fe358e351ad1db4edb6ee306e pgrouting.tar.gz" | sha256sum --check && \ + wget https://github.com/pgRouting/pgrouting/archive/v${PGROUTING_VERSION}.tar.gz -O pgrouting.tar.gz && \ + echo "${PGROUTING_CHECKSUM} pgrouting.tar.gz" | sha256sum --check && \ mkdir pgrouting-src && cd pgrouting-src && tar xzf ../pgrouting.tar.gz --strip-components=1 -C . && \ mkdir build && cd build && \ cmake -DCMAKE_BUILD_TYPE=Release .. && \ @@ -215,10 +260,9 @@ FROM build-deps AS h3-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ -RUN case "${PG_VERSION}" in "v17") \ - mkdir -p /h3/usr/ && \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ - esac && \ +# not version-specific +# last release v4.1.0 - Jan 18, 2023 +RUN mkdir -p /h3/usr/ && \ wget https://github.com/uber/h3/archive/refs/tags/v4.1.0.tar.gz -O h3.tar.gz && \ echo "ec99f1f5974846bde64f4513cf8d2ea1b8d172d2218ab41803bf6a63532272bc h3.tar.gz" | sha256sum --check && \ mkdir h3-src && cd h3-src && tar xzf ../h3.tar.gz --strip-components=1 -C . && \ @@ -229,10 +273,9 @@ RUN case "${PG_VERSION}" in "v17") \ cp -R /h3/usr / && \ rm -rf build -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ - esac && \ - wget https://github.com/zachasme/h3-pg/archive/refs/tags/v4.1.3.tar.gz -O h3-pg.tar.gz && \ +# not version-specific +# last release v4.1.3 - Jul 26, 2023 +RUN wget https://github.com/zachasme/h3-pg/archive/refs/tags/v4.1.3.tar.gz -O h3-pg.tar.gz && \ echo "5c17f09a820859ffe949f847bebf1be98511fb8f1bd86f94932512c00479e324 h3-pg.tar.gz" | sha256sum --check && \ mkdir h3-pg-src && cd h3-pg-src && tar xzf ../h3-pg.tar.gz --strip-components=1 -C . && \ export PATH="/usr/local/pgsql/bin:$PATH" && \ @@ -251,11 +294,10 @@ FROM build-deps AS unit-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ - esac && \ - wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.7.tar.gz -O postgresql-unit.tar.gz && \ - echo "411d05beeb97e5a4abf17572bfcfbb5a68d98d1018918feff995f6ee3bb03e79 postgresql-unit.tar.gz" | sha256sum --check && \ +# not version-specific +# last release 7.9 - Sep 15, 2024 +RUN wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.9.tar.gz -O postgresql-unit.tar.gz && \ + echo "e46de6245dcc8b2c2ecf29873dbd43b2b346773f31dd5ce4b8315895a052b456 postgresql-unit.tar.gz" | sha256sum --check && \ mkdir postgresql-unit-src && cd postgresql-unit-src && tar xzf ../postgresql-unit.tar.gz --strip-components=1 -C . && \ make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \ make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \ @@ -302,12 +344,10 @@ FROM build-deps AS pgjwt-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ -# 9742dab1b2f297ad3811120db7b21451bca2d3c9 made on 13/11/2021 -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ - esac && \ - wget https://github.com/michelp/pgjwt/archive/9742dab1b2f297ad3811120db7b21451bca2d3c9.tar.gz -O pgjwt.tar.gz && \ - echo "cfdefb15007286f67d3d45510f04a6a7a495004be5b3aecb12cda667e774203f pgjwt.tar.gz" | sha256sum --check && \ +# not version-specific +# doesn't use releases, last commit f3d82fd - Mar 2, 2023 +RUN wget https://github.com/michelp/pgjwt/archive/f3d82fd30151e754e19ce5d6a06c71c20689ce3d.tar.gz -O pgjwt.tar.gz && \ + echo "dae8ed99eebb7593b43013f6532d772b12dfecd55548d2673f2dfd0163f6d2b9 pgjwt.tar.gz" | sha256sum --check && \ mkdir pgjwt-src && cd pgjwt-src && tar xzf ../pgjwt.tar.gz --strip-components=1 -C . && \ make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \ echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgjwt.control @@ -342,10 +382,9 @@ FROM build-deps AS pg-hashids-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ - esac && \ - wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \ +# not version-specific +# last release v1.2.1 -Jan 12, 2018 +RUN wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \ echo "74576b992d9277c92196dd8d816baa2cc2d8046fe102f3dcd7f3c3febed6822a pg_hashids.tar.gz" | sha256sum --check && \ mkdir pg_hashids-src && cd pg_hashids-src && tar xzf ../pg_hashids.tar.gz --strip-components=1 -C . && \ make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \ @@ -405,10 +444,9 @@ FROM build-deps AS ip4r-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ - esac && \ - wget https://github.com/RhodiumToad/ip4r/archive/refs/tags/2.4.2.tar.gz -O ip4r.tar.gz && \ +# not version-specific +# last release v2.4.2 - Jul 29, 2023 +RUN wget https://github.com/RhodiumToad/ip4r/archive/refs/tags/2.4.2.tar.gz -O ip4r.tar.gz && \ echo "0f7b1f159974f49a47842a8ab6751aecca1ed1142b6d5e38d81b064b2ead1b4b ip4r.tar.gz" | sha256sum --check && \ mkdir ip4r-src && cd ip4r-src && tar xzf ../ip4r.tar.gz --strip-components=1 -C . && \ make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \ @@ -425,10 +463,9 @@ FROM build-deps AS prefix-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ - esac && \ - wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.10.tar.gz -O prefix.tar.gz && \ +# not version-specific +# last release v1.2.10 - Jul 5, 2023 +RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.10.tar.gz -O prefix.tar.gz && \ echo "4342f251432a5f6fb05b8597139d3ccde8dcf87e8ca1498e7ee931ca057a8575 prefix.tar.gz" | sha256sum --check && \ mkdir prefix-src && cd prefix-src && tar xzf ../prefix.tar.gz --strip-components=1 -C . && \ make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \ @@ -445,10 +482,9 @@ FROM build-deps AS hll-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions are not supported yet. Quit" && exit 0;; \ - esac && \ - wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar.gz -O hll.tar.gz && \ +# not version-specific +# last release v2.18 - Aug 29, 2023 +RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar.gz -O hll.tar.gz && \ echo "e2f55a6f4c4ab95ee4f1b4a2b73280258c5136b161fe9d059559556079694f0e hll.tar.gz" | sha256sum --check && \ mkdir hll-src && cd hll-src && tar xzf ../hll.tar.gz --strip-components=1 -C . && \ make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \ @@ -659,11 +695,10 @@ FROM build-deps AS pg-roaringbitmap-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ +# not version-specific +# last release v0.5.4 - Jun 28, 2022 ENV PATH="/usr/local/pgsql/bin/:$PATH" -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 extensions is not supported yet by pg_roaringbitmap. Quit" && exit 0;; \ - esac && \ - wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4.tar.gz -O pg_roaringbitmap.tar.gz && \ +RUN wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4.tar.gz -O pg_roaringbitmap.tar.gz && \ echo "b75201efcb1c2d1b014ec4ae6a22769cc7a224e6e406a587f5784a37b6b5a2aa pg_roaringbitmap.tar.gz" | sha256sum --check && \ mkdir pg_roaringbitmap-src && cd pg_roaringbitmap-src && tar xzf ../pg_roaringbitmap.tar.gz --strip-components=1 -C . && \ make -j $(getconf _NPROCESSORS_ONLN) && \ @@ -680,12 +715,27 @@ FROM build-deps AS pg-semver-pg-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ +# Release 0.40.0 breaks backward compatibility with previous versions +# see release note https://github.com/theory/pg-semver/releases/tag/v0.40.0 +# Use new version only for v17 +# +# last release v0.40.0 - Jul 22, 2024 ENV PATH="/usr/local/pgsql/bin/:$PATH" -RUN case "${PG_VERSION}" in "v17") \ - echo "v17 is not supported yet by pg_semver. Quit" && exit 0;; \ +RUN case "${PG_VERSION}" in \ + "v17") \ + export SEMVER_VERSION=0.40.0 \ + export SEMVER_CHECKSUM=3e50bcc29a0e2e481e7b6d2bc937cadc5f5869f55d983b5a1aafeb49f5425cfc \ + ;; \ + "v14" | "v15" | "v16") \ + export SEMVER_VERSION=0.32.1 \ + export SEMVER_CHECKSUM=fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 \ + ;; \ + *) \ + echo "unexpected PostgreSQL version" && exit 1 \ + ;; \ esac && \ - wget https://github.com/theory/pg-semver/archive/refs/tags/v0.32.1.tar.gz -O pg_semver.tar.gz && \ - echo "fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 pg_semver.tar.gz" | sha256sum --check && \ + wget https://github.com/theory/pg-semver/archive/refs/tags/v${SEMVER_VERSION}.tar.gz -O pg_semver.tar.gz && \ + echo "${SEMVER_CHECKSUM} pg_semver.tar.gz" | sha256sum --check && \ mkdir pg_semver-src && cd pg_semver-src && tar xzf ../pg_semver.tar.gz --strip-components=1 -C . && \ make -j $(getconf _NPROCESSORS_ONLN) && \ make -j $(getconf _NPROCESSORS_ONLN) install && \ diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 109d315d67..284db005c8 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -402,8 +402,7 @@ fn start_postgres( ) -> Result<(Option, StartPostgresResult)> { // We got all we need, update the state. let mut state = compute.state.lock().unwrap(); - state.status = ComputeStatus::Init; - compute.state_changed.notify_all(); + state.set_status(ComputeStatus::Init, &compute.state_changed); info!( "running compute with features: {:?}", diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index ba7b4f37df..285be56264 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -109,6 +109,18 @@ impl ComputeState { metrics: ComputeMetrics::default(), } } + + pub fn set_status(&mut self, status: ComputeStatus, state_changed: &Condvar) { + let prev = self.status; + info!("Changing compute status from {} to {}", prev, status); + self.status = status; + state_changed.notify_all(); + } + + pub fn set_failed_status(&mut self, err: anyhow::Error, state_changed: &Condvar) { + self.error = Some(format!("{err:?}")); + self.set_status(ComputeStatus::Failed, state_changed); + } } impl Default for ComputeState { @@ -303,15 +315,12 @@ impl ComputeNode { pub fn set_status(&self, status: ComputeStatus) { let mut state = self.state.lock().unwrap(); - state.status = status; - self.state_changed.notify_all(); + state.set_status(status, &self.state_changed); } pub fn set_failed_status(&self, err: anyhow::Error) { let mut state = self.state.lock().unwrap(); - state.error = Some(format!("{err:?}")); - state.status = ComputeStatus::Failed; - self.state_changed.notify_all(); + state.set_failed_status(err, &self.state_changed); } pub fn get_status(&self) -> ComputeStatus { @@ -1475,6 +1484,28 @@ LIMIT 100", info!("Pageserver config changed"); } } + + // Gather info about installed extensions + pub fn get_installed_extensions(&self) -> Result<()> { + let connstr = self.connstr.clone(); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to create runtime"); + let result = rt + .block_on(crate::installed_extensions::get_installed_extensions( + connstr, + )) + .expect("failed to get installed extensions"); + + info!( + "{}", + serde_json::to_string(&result).expect("failed to serialize extensions list") + ); + + Ok(()) + } } pub fn forward_termination_signal() { diff --git a/compute_tools/src/configurator.rs b/compute_tools/src/configurator.rs index 7bd0e4938d..a2043529a1 100644 --- a/compute_tools/src/configurator.rs +++ b/compute_tools/src/configurator.rs @@ -24,8 +24,7 @@ fn configurator_main_loop(compute: &Arc) { // Re-check the status after waking up if state.status == ComputeStatus::ConfigurationPending { info!("got configuration request"); - state.status = ComputeStatus::Configuration; - compute.state_changed.notify_all(); + state.set_status(ComputeStatus::Configuration, &compute.state_changed); drop(state); let mut new_status = ComputeStatus::Failed; diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index 43d29402bc..79e6158081 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -165,6 +165,32 @@ async fn routes(req: Request, compute: &Arc) -> Response { + info!("serving /installed_extensions GET request"); + let status = compute.get_status(); + if status != ComputeStatus::Running { + let msg = format!( + "invalid compute status for extensions request: {:?}", + status + ); + error!(msg); + return Response::new(Body::from(msg)); + } + + let connstr = compute.connstr.clone(); + let res = crate::installed_extensions::get_installed_extensions(connstr).await; + match res { + Ok(res) => render_json(Body::from(serde_json::to_string(&res).unwrap())), + Err(e) => render_json_error( + &format!("could not get list of installed extensions: {}", e), + StatusCode::INTERNAL_SERVER_ERROR, + ), + } + } + // download extension files from remote extension storage on demand (&Method::POST, route) if route.starts_with("/extension_server/") => { info!("serving {:?} POST request", route); @@ -288,8 +314,7 @@ async fn handle_configure_request( return Err((msg, StatusCode::PRECONDITION_FAILED)); } state.pspec = Some(parsed_spec); - state.status = ComputeStatus::ConfigurationPending; - compute.state_changed.notify_all(); + state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed); drop(state); info!("set new spec and notified waiters"); } @@ -362,15 +387,15 @@ async fn handle_terminate_request(compute: &Arc) -> Result<(), (Str } if state.status != ComputeStatus::Empty && state.status != ComputeStatus::Running { let msg = format!( - "invalid compute status for termination request: {:?}", - state.status.clone() + "invalid compute status for termination request: {}", + state.status ); return Err((msg, StatusCode::PRECONDITION_FAILED)); } - state.status = ComputeStatus::TerminationPending; - compute.state_changed.notify_all(); + state.set_status(ComputeStatus::TerminationPending, &compute.state_changed); drop(state); } + forward_termination_signal(); info!("sent signal and notified waiters"); @@ -384,7 +409,8 @@ async fn handle_terminate_request(compute: &Arc) -> Result<(), (Str while state.status != ComputeStatus::Terminated { state = c.state_changed.wait(state).unwrap(); info!( - "waiting for compute to become Terminated, current status: {:?}", + "waiting for compute to become {}, current status: {:?}", + ComputeStatus::Terminated, state.status ); } diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index b0ddaeae2b..e9fa66b323 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -53,6 +53,20 @@ paths: schema: $ref: "#/components/schemas/ComputeInsights" + /installed_extensions: + get: + tags: + - Info + summary: Get installed extensions. + description: "" + operationId: getInstalledExtensions + responses: + 200: + description: List of installed extensions + content: + application/json: + schema: + $ref: "#/components/schemas/InstalledExtensions" /info: get: tags: @@ -395,6 +409,24 @@ components: - configuration example: running + InstalledExtensions: + type: object + properties: + extensions: + description: Contains list of installed extensions. + type: array + items: + type: object + properties: + extname: + type: string + versions: + type: array + items: + type: string + n_databases: + type: integer + # # Errors # diff --git a/compute_tools/src/installed_extensions.rs b/compute_tools/src/installed_extensions.rs new file mode 100644 index 0000000000..3d8b22a8a3 --- /dev/null +++ b/compute_tools/src/installed_extensions.rs @@ -0,0 +1,80 @@ +use compute_api::responses::{InstalledExtension, InstalledExtensions}; +use std::collections::HashMap; +use std::collections::HashSet; +use url::Url; + +use anyhow::Result; +use postgres::{Client, NoTls}; +use tokio::task; + +/// We don't reuse get_existing_dbs() just for code clarity +/// and to make database listing query here more explicit. +/// +/// Limit the number of databases to 500 to avoid excessive load. +fn list_dbs(client: &mut Client) -> Result> { + // `pg_database.datconnlimit = -2` means that the database is in the + // invalid state + let databases = client + .query( + "SELECT datname FROM pg_catalog.pg_database + WHERE datallowconn + AND datconnlimit <> - 2 + LIMIT 500", + &[], + )? + .iter() + .map(|row| { + let db: String = row.get("datname"); + db + }) + .collect(); + + Ok(databases) +} + +/// Connect to every database (see list_dbs above) and get the list of installed extensions. +/// Same extension can be installed in multiple databases with different versions, +/// we only keep the highest and lowest version across all databases. +pub async fn get_installed_extensions(connstr: Url) -> Result { + let mut connstr = connstr.clone(); + + task::spawn_blocking(move || { + let mut client = Client::connect(connstr.as_str(), NoTls)?; + let databases: Vec = list_dbs(&mut client)?; + + let mut extensions_map: HashMap = HashMap::new(); + for db in databases.iter() { + connstr.set_path(db); + let mut db_client = Client::connect(connstr.as_str(), NoTls)?; + let extensions: Vec<(String, String)> = db_client + .query( + "SELECT extname, extversion FROM pg_catalog.pg_extension;", + &[], + )? + .iter() + .map(|row| (row.get("extname"), row.get("extversion"))) + .collect(); + + for (extname, v) in extensions.iter() { + let version = v.to_string(); + extensions_map + .entry(extname.to_string()) + .and_modify(|e| { + e.versions.insert(version.clone()); + // count the number of databases where the extension is installed + e.n_databases += 1; + }) + .or_insert(InstalledExtension { + extname: extname.to_string(), + versions: HashSet::from([version.clone()]), + n_databases: 1, + }); + } + } + + Ok(InstalledExtensions { + extensions: extensions_map.values().cloned().collect(), + }) + }) + .await? +} diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs index 477f423aa2..d27ae58fa2 100644 --- a/compute_tools/src/lib.rs +++ b/compute_tools/src/lib.rs @@ -15,6 +15,7 @@ pub mod catalog; pub mod compute; pub mod disk_quota; pub mod extension_server; +pub mod installed_extensions; pub mod local_proxy; pub mod lsn_lease; mod migration; diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index aa9405d28d..73f3d1006a 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::fs::File; use std::path::Path; use std::str::FromStr; @@ -189,6 +190,15 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> { let mut xact = client.transaction()?; let existing_roles: Vec = get_existing_roles(&mut xact)?; + let mut jwks_roles = HashSet::new(); + if let Some(local_proxy) = &spec.local_proxy_config { + for jwks_setting in local_proxy.jwks.iter().flatten() { + for role_name in &jwks_setting.role_names { + jwks_roles.insert(role_name.clone()); + } + } + } + // Print a list of existing Postgres roles (only in debug mode) if span_enabled!(Level::INFO) { let mut vec = Vec::new(); @@ -308,6 +318,9 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> { "CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser", name.pg_quote() ); + if jwks_roles.contains(name.as_str()) { + query = format!("CREATE ROLE {}", name.pg_quote()); + } info!("running role create query: '{}'", &query); query.push_str(&role.to_pg_options()); xact.execute(query.as_str(), &[])?; diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index d05d625b0a..5023fce003 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -1,5 +1,8 @@ //! Structs representing the JSON formats used in the compute_ctl's HTTP API. +use std::collections::HashSet; +use std::fmt::Display; + use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize, Serializer}; @@ -58,6 +61,21 @@ pub enum ComputeStatus { Terminated, } +impl Display for ComputeStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ComputeStatus::Empty => f.write_str("empty"), + ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"), + ComputeStatus::Init => f.write_str("init"), + ComputeStatus::Running => f.write_str("running"), + ComputeStatus::Configuration => f.write_str("configuration"), + ComputeStatus::Failed => f.write_str("failed"), + ComputeStatus::TerminationPending => f.write_str("termination-pending"), + ComputeStatus::Terminated => f.write_str("terminated"), + } + } +} + fn rfc3339_serialize(x: &Option>, s: S) -> Result where S: Serializer, @@ -138,3 +156,15 @@ pub enum ControlPlaneComputeStatus { // should be able to start with provided spec. Attached, } + +#[derive(Clone, Debug, Default, Serialize)] +pub struct InstalledExtension { + pub extname: String, + pub versions: HashSet, + pub n_databases: u32, // Number of databases using this extension +} + +#[derive(Clone, Debug, Default, Serialize)] +pub struct InstalledExtensions { + pub extensions: Vec, +} diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 105c8a50d3..24474d4840 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -104,8 +104,7 @@ pub struct ConfigToml { pub image_compression: ImageCompressionAlgorithm, pub ephemeral_bytes_per_memory_kb: usize, pub l0_flush: Option, - pub virtual_file_direct_io: crate::models::virtual_file::DirectIoMode, - pub io_buffer_alignment: usize, + pub virtual_file_io_mode: Option, } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -388,10 +387,7 @@ impl Default for ConfigToml { image_compression: (DEFAULT_IMAGE_COMPRESSION), ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB), l0_flush: None, - virtual_file_direct_io: crate::models::virtual_file::DirectIoMode::default(), - - io_buffer_alignment: DEFAULT_IO_BUFFER_ALIGNMENT, - + virtual_file_io_mode: None, tenant_config: TenantConfigToml::default(), } } diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 45abda0ad8..3ec9cac2c3 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -972,8 +972,6 @@ pub struct TopTenantShardsResponse { } pub mod virtual_file { - use std::path::PathBuf; - #[derive( Copy, Clone, @@ -994,50 +992,45 @@ pub mod virtual_file { } /// Direct IO modes for a pageserver. - #[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)] - #[serde(tag = "mode", rename_all = "kebab-case", deny_unknown_fields)] - pub enum DirectIoMode { - /// Direct IO disabled (uses usual buffered IO). - #[default] - Disabled, - /// Direct IO disabled (performs checks and perf simulations). - Evaluate { - /// Alignment check level - alignment_check: DirectIoAlignmentCheckLevel, - /// Latency padded for performance simulation. - latency_padding: DirectIoLatencyPadding, - }, - /// Direct IO enabled. - Enabled { - /// Actions to perform on alignment error. - on_alignment_error: DirectIoOnAlignmentErrorAction, - }, + #[derive( + Copy, + Clone, + PartialEq, + Eq, + Hash, + strum_macros::EnumString, + strum_macros::Display, + serde_with::DeserializeFromStr, + serde_with::SerializeDisplay, + Debug, + )] + #[strum(serialize_all = "kebab-case")] + #[repr(u8)] + pub enum IoMode { + /// Uses buffered IO. + Buffered, + /// Uses direct IO, error out if the operation fails. + #[cfg(target_os = "linux")] + Direct, } - #[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)] - #[serde(rename_all = "kebab-case")] - pub enum DirectIoAlignmentCheckLevel { - #[default] - Error, - Log, - None, + impl IoMode { + pub const fn preferred() -> Self { + Self::Buffered + } } - #[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)] - #[serde(rename_all = "kebab-case")] - pub enum DirectIoOnAlignmentErrorAction { - Error, - #[default] - FallbackToBuffered, - } + impl TryFrom for IoMode { + type Error = u8; - #[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)] - #[serde(tag = "type", rename_all = "kebab-case")] - pub enum DirectIoLatencyPadding { - /// Pad virtual file operations with IO to a fake file. - FakeFileRW { path: PathBuf }, - #[default] - None, + fn try_from(value: u8) -> Result { + Ok(match value { + v if v == (IoMode::Buffered as u8) => IoMode::Buffered, + #[cfg(target_os = "linux")] + v if v == (IoMode::Direct as u8) => IoMode::Direct, + x => return Err(x), + }) + } } } diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs index e113a987a5..f98d16789c 100644 --- a/libs/remote_storage/src/azure_blob.rs +++ b/libs/remote_storage/src/azure_blob.rs @@ -496,26 +496,12 @@ impl RemoteStorage for AzureBlobStorage { builder = builder.if_match(IfMatchCondition::NotMatch(etag.to_string())) } - self.download_for_builder(builder, cancel).await - } - - async fn download_byte_range( - &self, - from: &RemotePath, - start_inclusive: u64, - end_exclusive: Option, - cancel: &CancellationToken, - ) -> Result { - let blob_client = self.client.blob_client(self.relative_path_to_name(from)); - - let mut builder = blob_client.get(); - - let range: Range = if let Some(end_exclusive) = end_exclusive { - (start_inclusive..end_exclusive).into() - } else { - (start_inclusive..).into() - }; - builder = builder.range(range); + if let Some((start, end)) = opts.byte_range() { + builder = builder.range(match end { + Some(end) => Range::Range(start..end), + None => Range::RangeFrom(start..), + }); + } self.download_for_builder(builder, cancel).await } diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 0ff0f1c878..c6466237bf 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -19,7 +19,8 @@ mod simulate_failures; mod support; use std::{ - collections::HashMap, fmt::Debug, num::NonZeroU32, pin::Pin, sync::Arc, time::SystemTime, + collections::HashMap, fmt::Debug, num::NonZeroU32, ops::Bound, pin::Pin, sync::Arc, + time::SystemTime, }; use anyhow::Context; @@ -162,11 +163,60 @@ pub struct Listing { } /// Options for downloads. The default value is a plain GET. -#[derive(Default)] pub struct DownloadOpts { /// If given, returns [`DownloadError::Unmodified`] if the object still has /// the same ETag (using If-None-Match). pub etag: Option, + /// The start of the byte range to download, or unbounded. + pub byte_start: Bound, + /// The end of the byte range to download, or unbounded. Must be after the + /// start bound. + pub byte_end: Bound, +} + +impl Default for DownloadOpts { + fn default() -> Self { + Self { + etag: Default::default(), + byte_start: Bound::Unbounded, + byte_end: Bound::Unbounded, + } + } +} + +impl DownloadOpts { + /// Returns the byte range with inclusive start and exclusive end, or None + /// if unbounded. + pub fn byte_range(&self) -> Option<(u64, Option)> { + if self.byte_start == Bound::Unbounded && self.byte_end == Bound::Unbounded { + return None; + } + let start = match self.byte_start { + Bound::Excluded(i) => i + 1, + Bound::Included(i) => i, + Bound::Unbounded => 0, + }; + let end = match self.byte_end { + Bound::Excluded(i) => Some(i), + Bound::Included(i) => Some(i + 1), + Bound::Unbounded => None, + }; + if let Some(end) = end { + assert!(start < end, "range end {end} at or before start {start}"); + } + Some((start, end)) + } + + /// Returns the byte range as an RFC 2616 Range header value with inclusive + /// bounds, or None if unbounded. + pub fn byte_range_header(&self) -> Option { + self.byte_range() + .map(|(start, end)| (start, end.map(|end| end - 1))) // make end inclusive + .map(|(start, end)| match end { + Some(end) => format!("bytes={start}-{end}"), + None => format!("bytes={start}-"), + }) + } } /// Storage (potentially remote) API to manage its state. @@ -257,21 +307,6 @@ pub trait RemoteStorage: Send + Sync + 'static { cancel: &CancellationToken, ) -> Result; - /// Streams a given byte range of the remote storage entry contents. - /// - /// The returned download stream will obey initial timeout and cancellation signal by erroring - /// on whichever happens first. Only one of the reasons will fail the stream, which is usually - /// enough for `tokio::io::copy_buf` usage. If needed the error can be filtered out. - /// - /// Returns the metadata, if any was stored with the file previously. - async fn download_byte_range( - &self, - from: &RemotePath, - start_inclusive: u64, - end_exclusive: Option, - cancel: &CancellationToken, - ) -> Result; - /// Delete a single path from remote storage. /// /// If the operation fails because of timeout or cancellation, the root cause of the error will be @@ -425,33 +460,6 @@ impl GenericRemoteStorage> { } } - pub async fn download_byte_range( - &self, - from: &RemotePath, - start_inclusive: u64, - end_exclusive: Option, - cancel: &CancellationToken, - ) -> Result { - match self { - Self::LocalFs(s) => { - s.download_byte_range(from, start_inclusive, end_exclusive, cancel) - .await - } - Self::AwsS3(s) => { - s.download_byte_range(from, start_inclusive, end_exclusive, cancel) - .await - } - Self::AzureBlob(s) => { - s.download_byte_range(from, start_inclusive, end_exclusive, cancel) - .await - } - Self::Unreliable(s) => { - s.download_byte_range(from, start_inclusive, end_exclusive, cancel) - .await - } - } - } - /// See [`RemoteStorage::delete`] pub async fn delete( &self, @@ -573,20 +581,6 @@ impl GenericRemoteStorage { }) } - /// Downloads the storage object into the `to_path` provided. - /// `byte_range` could be specified to dowload only a part of the file, if needed. - pub async fn download_storage_object( - &self, - byte_range: Option<(u64, Option)>, - from: &RemotePath, - cancel: &CancellationToken, - ) -> Result { - match byte_range { - Some((start, end)) => self.download_byte_range(from, start, end, cancel).await, - None => self.download(from, &DownloadOpts::default(), cancel).await, - } - } - /// The name of the bucket/container/etc. pub fn bucket_name(&self) -> Option<&str> { match self { @@ -660,6 +654,76 @@ impl ConcurrencyLimiter { mod tests { use super::*; + /// DownloadOpts::byte_range() should generate (inclusive, exclusive) ranges + /// with optional end bound, or None when unbounded. + #[test] + fn download_opts_byte_range() { + // Consider using test_case or a similar table-driven test framework. + let cases = [ + // (byte_start, byte_end, expected) + (Bound::Unbounded, Bound::Unbounded, None), + (Bound::Unbounded, Bound::Included(7), Some((0, Some(8)))), + (Bound::Unbounded, Bound::Excluded(7), Some((0, Some(7)))), + (Bound::Included(3), Bound::Unbounded, Some((3, None))), + (Bound::Included(3), Bound::Included(7), Some((3, Some(8)))), + (Bound::Included(3), Bound::Excluded(7), Some((3, Some(7)))), + (Bound::Excluded(3), Bound::Unbounded, Some((4, None))), + (Bound::Excluded(3), Bound::Included(7), Some((4, Some(8)))), + (Bound::Excluded(3), Bound::Excluded(7), Some((4, Some(7)))), + // 1-sized ranges are fine, 0 aren't and will panic (separate test). + (Bound::Included(3), Bound::Included(3), Some((3, Some(4)))), + (Bound::Included(3), Bound::Excluded(4), Some((3, Some(4)))), + ]; + + for (byte_start, byte_end, expect) in cases { + let opts = DownloadOpts { + byte_start, + byte_end, + ..Default::default() + }; + let result = opts.byte_range(); + assert_eq!( + result, expect, + "byte_start={byte_start:?} byte_end={byte_end:?}" + ); + + // Check generated HTTP header, which uses an inclusive range. + let expect_header = expect.map(|(start, end)| match end { + Some(end) => format!("bytes={start}-{}", end - 1), // inclusive end + None => format!("bytes={start}-"), + }); + assert_eq!( + opts.byte_range_header(), + expect_header, + "byte_start={byte_start:?} byte_end={byte_end:?}" + ); + } + } + + /// DownloadOpts::byte_range() zero-sized byte range should panic. + #[test] + #[should_panic] + fn download_opts_byte_range_zero() { + DownloadOpts { + byte_start: Bound::Included(3), + byte_end: Bound::Excluded(3), + ..Default::default() + } + .byte_range(); + } + + /// DownloadOpts::byte_range() negative byte range should panic. + #[test] + #[should_panic] + fn download_opts_byte_range_negative() { + DownloadOpts { + byte_start: Bound::Included(3), + byte_end: Bound::Included(2), + ..Default::default() + } + .byte_range(); + } + #[test] fn test_object_name() { let k = RemotePath::new(Utf8Path::new("a/b/c")).unwrap(); diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index d912b94c74..93a052139b 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -506,54 +506,7 @@ impl RemoteStorage for LocalFs { return Err(DownloadError::Unmodified); } - let source = ReaderStream::new( - fs::OpenOptions::new() - .read(true) - .open(&target_path) - .await - .with_context(|| { - format!("Failed to open source file {target_path:?} to use in the download") - }) - .map_err(DownloadError::Other)?, - ); - - let metadata = self - .read_storage_metadata(&target_path) - .await - .map_err(DownloadError::Other)?; - - let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone()); - let source = crate::support::DownloadStream::new(cancel_or_timeout, source); - - Ok(Download { - metadata, - last_modified: file_metadata - .modified() - .map_err(|e| DownloadError::Other(anyhow::anyhow!(e).context("Reading mtime")))?, - etag, - download_stream: Box::pin(source), - }) - } - - async fn download_byte_range( - &self, - from: &RemotePath, - start_inclusive: u64, - end_exclusive: Option, - cancel: &CancellationToken, - ) -> Result { - if let Some(end_exclusive) = end_exclusive { - if end_exclusive <= start_inclusive { - return Err(DownloadError::Other(anyhow::anyhow!("Invalid range, start ({start_inclusive}) is not less than end_exclusive ({end_exclusive:?})"))); - }; - if start_inclusive == end_exclusive.saturating_sub(1) { - return Err(DownloadError::Other(anyhow::anyhow!("Invalid range, start ({start_inclusive}) and end_exclusive ({end_exclusive:?}) difference is zero bytes"))); - } - } - - let target_path = from.with_base(&self.storage_root); - let file_metadata = file_metadata(&target_path).await?; - let mut source = tokio::fs::OpenOptions::new() + let mut file = fs::OpenOptions::new() .read(true) .open(&target_path) .await @@ -562,31 +515,29 @@ impl RemoteStorage for LocalFs { }) .map_err(DownloadError::Other)?; - let len = source - .metadata() - .await - .context("query file length") - .map_err(DownloadError::Other)? - .len(); + let mut take = file_metadata.len(); + if let Some((start, end)) = opts.byte_range() { + if start > 0 { + file.seek(io::SeekFrom::Start(start)) + .await + .context("Failed to seek to the range start in a local storage file") + .map_err(DownloadError::Other)?; + } + if let Some(end) = end { + take = end - start; + } + } - source - .seek(io::SeekFrom::Start(start_inclusive)) - .await - .context("Failed to seek to the range start in a local storage file") - .map_err(DownloadError::Other)?; + let source = ReaderStream::new(file.take(take)); let metadata = self .read_storage_metadata(&target_path) .await .map_err(DownloadError::Other)?; - let source = source.take(end_exclusive.unwrap_or(len) - start_inclusive); - let source = ReaderStream::new(source); - let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone()); let source = crate::support::DownloadStream::new(cancel_or_timeout, source); - let etag = mock_etag(&file_metadata); Ok(Download { metadata, last_modified: file_metadata @@ -688,7 +639,7 @@ mod fs_tests { use super::*; use camino_tempfile::tempdir; - use std::{collections::HashMap, io::Write}; + use std::{collections::HashMap, io::Write, ops::Bound}; async fn read_and_check_metadata( storage: &LocalFs, @@ -804,10 +755,12 @@ mod fs_tests { let (first_part_local, second_part_local) = uploaded_bytes.split_at(3); let first_part_download = storage - .download_byte_range( + .download( &upload_target, - 0, - Some(first_part_local.len() as u64), + &DownloadOpts { + byte_end: Bound::Excluded(first_part_local.len() as u64), + ..Default::default() + }, &cancel, ) .await?; @@ -823,10 +776,15 @@ mod fs_tests { ); let second_part_download = storage - .download_byte_range( + .download( &upload_target, - first_part_local.len() as u64, - Some((first_part_local.len() + second_part_local.len()) as u64), + &DownloadOpts { + byte_start: Bound::Included(first_part_local.len() as u64), + byte_end: Bound::Excluded( + (first_part_local.len() + second_part_local.len()) as u64, + ), + ..Default::default() + }, &cancel, ) .await?; @@ -842,7 +800,14 @@ mod fs_tests { ); let suffix_bytes = storage - .download_byte_range(&upload_target, 13, None, &cancel) + .download( + &upload_target, + &DownloadOpts { + byte_start: Bound::Included(13), + ..Default::default() + }, + &cancel, + ) .await? .download_stream; let suffix_bytes = aggregate(suffix_bytes).await?; @@ -850,7 +815,7 @@ mod fs_tests { assert_eq!(upload_name, suffix); let all_bytes = storage - .download_byte_range(&upload_target, 0, None, &cancel) + .download(&upload_target, &DownloadOpts::default(), &cancel) .await? .download_stream; let all_bytes = aggregate(all_bytes).await?; @@ -861,48 +826,26 @@ mod fs_tests { } #[tokio::test] - async fn download_file_range_negative() -> anyhow::Result<()> { - let (storage, cancel) = create_storage()?; + #[should_panic(expected = "at or before start")] + async fn download_file_range_negative() { + let (storage, cancel) = create_storage().unwrap(); let upload_name = "upload_1"; - let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel).await?; + let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel) + .await + .unwrap(); - let start = 1_000_000_000; - let end = start + 1; - match storage - .download_byte_range( + storage + .download( &upload_target, - start, - Some(end), // exclusive end + &DownloadOpts { + byte_start: Bound::Included(10), + byte_end: Bound::Excluded(10), + ..Default::default() + }, &cancel, ) .await - { - Ok(_) => panic!("Should not allow downloading wrong ranges"), - Err(e) => { - let error_string = e.to_string(); - assert!(error_string.contains("zero bytes")); - assert!(error_string.contains(&start.to_string())); - assert!(error_string.contains(&end.to_string())); - } - } - - let start = 10000; - let end = 234; - assert!(start > end, "Should test an incorrect range"); - match storage - .download_byte_range(&upload_target, start, Some(end), &cancel) - .await - { - Ok(_) => panic!("Should not allow downloading wrong ranges"), - Err(e) => { - let error_string = e.to_string(); - assert!(error_string.contains("Invalid range")); - assert!(error_string.contains(&start.to_string())); - assert!(error_string.contains(&end.to_string())); - } - } - - Ok(()) + .unwrap(); } #[tokio::test] @@ -945,10 +888,12 @@ mod fs_tests { let (first_part_local, _) = uploaded_bytes.split_at(3); let partial_download_with_metadata = storage - .download_byte_range( + .download( &upload_target, - 0, - Some(first_part_local.len() as u64), + &DownloadOpts { + byte_end: Bound::Excluded(first_part_local.len() as u64), + ..Default::default() + }, &cancel, ) .await?; diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs index ec7c047565..f950f2886c 100644 --- a/libs/remote_storage/src/s3_bucket.rs +++ b/libs/remote_storage/src/s3_bucket.rs @@ -804,34 +804,7 @@ impl RemoteStorage for S3Bucket { bucket: self.bucket_name.clone(), key: self.relative_path_to_s3_object(from), etag: opts.etag.as_ref().map(|e| e.to_string()), - range: None, - }, - cancel, - ) - .await - } - - async fn download_byte_range( - &self, - from: &RemotePath, - start_inclusive: u64, - end_exclusive: Option, - cancel: &CancellationToken, - ) -> Result { - // S3 accepts ranges as https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 - // and needs both ends to be exclusive - let end_inclusive = end_exclusive.map(|end| end.saturating_sub(1)); - let range = Some(match end_inclusive { - Some(end_inclusive) => format!("bytes={start_inclusive}-{end_inclusive}"), - None => format!("bytes={start_inclusive}-"), - }); - - self.download_object( - GetObjectRequest { - bucket: self.bucket_name.clone(), - key: self.relative_path_to_s3_object(from), - etag: None, - range, + range: opts.byte_range_header(), }, cancel, ) diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index 05f82b5a5a..10db53971c 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -170,28 +170,13 @@ impl RemoteStorage for UnreliableWrapper { opts: &DownloadOpts, cancel: &CancellationToken, ) -> Result { + // Note: We treat any byte range as an "attempt" of the same operation. + // We don't pay attention to the ranges. That's good enough for now. self.attempt(RemoteOp::Download(from.clone())) .map_err(DownloadError::Other)?; self.inner.download(from, opts, cancel).await } - async fn download_byte_range( - &self, - from: &RemotePath, - start_inclusive: u64, - end_exclusive: Option, - cancel: &CancellationToken, - ) -> Result { - // Note: We treat any download_byte_range as an "attempt" of the same - // operation. We don't pay attention to the ranges. That's good enough - // for now. - self.attempt(RemoteOp::Download(from.clone())) - .map_err(DownloadError::Other)?; - self.inner - .download_byte_range(from, start_inclusive, end_exclusive, cancel) - .await - } - async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> { self.delete_inner(path, true, cancel).await } diff --git a/libs/remote_storage/tests/common/tests.rs b/libs/remote_storage/tests/common/tests.rs index e38cfb3ef0..e6f33fc3f8 100644 --- a/libs/remote_storage/tests/common/tests.rs +++ b/libs/remote_storage/tests/common/tests.rs @@ -2,6 +2,7 @@ use anyhow::Context; use camino::Utf8Path; use futures::StreamExt; use remote_storage::{DownloadError, DownloadOpts, ListingMode, ListingObject, RemotePath}; +use std::ops::Bound; use std::sync::Arc; use std::{collections::HashSet, num::NonZeroU32}; use test_context::test_context; @@ -293,7 +294,15 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result< // Full range (end specified) let dl = ctx .client - .download_byte_range(&path, 0, Some(len as u64), &cancel) + .download( + &path, + &DownloadOpts { + byte_start: Bound::Included(0), + byte_end: Bound::Excluded(len as u64), + ..Default::default() + }, + &cancel, + ) .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig); @@ -301,7 +310,15 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result< // partial range (end specified) let dl = ctx .client - .download_byte_range(&path, 4, Some(10), &cancel) + .download( + &path, + &DownloadOpts { + byte_start: Bound::Included(4), + byte_end: Bound::Excluded(10), + ..Default::default() + }, + &cancel, + ) .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig[4..10]); @@ -309,7 +326,15 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result< // partial range (end beyond real end) let dl = ctx .client - .download_byte_range(&path, 8, Some(len as u64 * 100), &cancel) + .download( + &path, + &DownloadOpts { + byte_start: Bound::Included(8), + byte_end: Bound::Excluded(len as u64 * 100), + ..Default::default() + }, + &cancel, + ) .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig[8..]); @@ -317,7 +342,14 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result< // Partial range (end unspecified) let dl = ctx .client - .download_byte_range(&path, 4, None, &cancel) + .download( + &path, + &DownloadOpts { + byte_start: Bound::Included(4), + ..Default::default() + }, + &cancel, + ) .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig[4..]); @@ -325,7 +357,14 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result< // Full range (end unspecified) let dl = ctx .client - .download_byte_range(&path, 0, None, &cancel) + .download( + &path, + &DownloadOpts { + byte_start: Bound::Included(0), + ..Default::default() + }, + &cancel, + ) .await?; let buf = download_to_vec(dl).await?; assert_eq!(&buf, &orig); diff --git a/libs/utils/src/auth.rs b/libs/utils/src/auth.rs index 7b735875b7..5bd6f4bedc 100644 --- a/libs/utils/src/auth.rs +++ b/libs/utils/src/auth.rs @@ -31,9 +31,12 @@ pub enum Scope { /// The scope used by pageservers in upcalls to storage controller and cloud control plane #[serde(rename = "generations_api")] GenerationsApi, - /// Allows access to control plane managment API and some storage controller endpoints. + /// Allows access to control plane managment API and all storage controller endpoints. Admin, + /// Allows access to control plane & storage controller endpoints used in infrastructure automation (e.g. node registration) + Infra, + /// Allows access to storage controller APIs used by the scrubber, to interrogate the state /// of a tenant & post scrub results. Scrubber, diff --git a/libs/vm_monitor/src/runner.rs b/libs/vm_monitor/src/runner.rs index 36f8573a38..8605314ba9 100644 --- a/libs/vm_monitor/src/runner.rs +++ b/libs/vm_monitor/src/runner.rs @@ -79,8 +79,7 @@ pub struct Config { /// memory. /// /// The default value of `0.15` means that we *guarantee* sending upscale requests if the - /// cgroup is using more than 85% of total memory (even if we're *not* separately reserving - /// memory for the file cache). + /// cgroup is using more than 85% of total memory. cgroup_min_overhead_fraction: f64, cgroup_downscale_threshold_buffer_bytes: u64, @@ -97,24 +96,12 @@ impl Default for Config { } impl Config { - fn cgroup_threshold(&self, total_mem: u64, file_cache_disk_size: u64) -> u64 { - // If the file cache is in tmpfs, then it will count towards shmem usage of the cgroup, - // and thus be non-reclaimable, so we should allow for additional memory usage. - // - // If the file cache sits on disk, our desired stable system state is for it to be fully - // page cached (its contents should only be paged to/from disk in situations where we can't - // upscale fast enough). Page-cached memory is reclaimable, so we need to lower the - // threshold for non-reclaimable memory so we scale up *before* the kernel starts paging - // out the file cache. - let memory_remaining_for_cgroup = total_mem.saturating_sub(file_cache_disk_size); - - // Even if we're not separately making room for the file cache (if it's in tmpfs), we still - // want our threshold to be met gracefully instead of letting postgres get OOM-killed. + fn cgroup_threshold(&self, total_mem: u64) -> u64 { + // We want our threshold to be met gracefully instead of letting postgres get OOM-killed + // (or if there's room, spilling to swap). // So we guarantee that there's at least `cgroup_min_overhead_fraction` of total memory // remaining above the threshold. - let max_threshold = (total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64; - - memory_remaining_for_cgroup.min(max_threshold) + (total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64 } } @@ -149,11 +136,6 @@ impl Runner { let mem = get_total_system_memory(); - let mut file_cache_disk_size = 0; - - // We need to process file cache initialization before cgroup initialization, so that the memory - // allocated to the file cache is appropriately taken into account when we decide the cgroup's - // memory limits. if let Some(connstr) = &args.pgconnstr { info!("initializing file cache"); let config = FileCacheConfig::default(); @@ -184,7 +166,6 @@ impl Runner { info!("file cache size actually got set to {actual_size}") } - file_cache_disk_size = actual_size; state.filecache = Some(file_cache); } @@ -207,7 +188,7 @@ impl Runner { cgroup.watch(hist_tx).await }); - let threshold = state.config.cgroup_threshold(mem, file_cache_disk_size); + let threshold = state.config.cgroup_threshold(mem); info!(threshold, "set initial cgroup threshold",); state.cgroup = Some(CgroupState { @@ -259,9 +240,7 @@ impl Runner { return Ok((false, status.to_owned())); } - let new_threshold = self - .config - .cgroup_threshold(usable_system_memory, expected_file_cache_size); + let new_threshold = self.config.cgroup_threshold(usable_system_memory); let current = last_history.avg_non_reclaimable; @@ -282,13 +261,11 @@ impl Runner { // The downscaling has been approved. Downscale the file cache, then the cgroup. let mut status = vec![]; - let mut file_cache_disk_size = 0; if let Some(file_cache) = &mut self.filecache { let actual_usage = file_cache .set_file_cache_size(expected_file_cache_size) .await .context("failed to set file cache size")?; - file_cache_disk_size = actual_usage; let message = format!( "set file cache size to {} MiB", bytes_to_mebibytes(actual_usage), @@ -298,9 +275,7 @@ impl Runner { } if let Some(cgroup) = &mut self.cgroup { - let new_threshold = self - .config - .cgroup_threshold(usable_system_memory, file_cache_disk_size); + let new_threshold = self.config.cgroup_threshold(usable_system_memory); let message = format!( "set cgroup memory threshold from {} MiB to {} MiB, of new total {} MiB", @@ -329,7 +304,6 @@ impl Runner { let new_mem = resources.mem; let usable_system_memory = new_mem.saturating_sub(self.config.sys_buffer_bytes); - let mut file_cache_disk_size = 0; if let Some(file_cache) = &mut self.filecache { let expected_usage = file_cache.config.calculate_cache_size(usable_system_memory); info!( @@ -342,7 +316,6 @@ impl Runner { .set_file_cache_size(expected_usage) .await .context("failed to set file cache size")?; - file_cache_disk_size = actual_usage; if actual_usage != expected_usage { warn!( @@ -354,9 +327,7 @@ impl Runner { } if let Some(cgroup) = &mut self.cgroup { - let new_threshold = self - .config - .cgroup_threshold(usable_system_memory, file_cache_disk_size); + let new_threshold = self.config.cgroup_threshold(usable_system_memory); info!( "set cgroup memory threshold from {} MiB to {} MiB of new total {} MiB", diff --git a/pageserver/benches/bench_ingest.rs b/pageserver/benches/bench_ingest.rs index 72cbb6beab..821c8008a9 100644 --- a/pageserver/benches/bench_ingest.rs +++ b/pageserver/benches/bench_ingest.rs @@ -164,11 +164,7 @@ fn criterion_benchmark(c: &mut Criterion) { let conf: &'static PageServerConf = Box::leak(Box::new( pageserver::config::PageServerConf::dummy_conf(temp_dir.path().to_path_buf()), )); - virtual_file::init( - 16384, - virtual_file::io_engine_for_bench(), - pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT, - ); + virtual_file::init(16384, virtual_file::io_engine_for_bench()); page_cache::init(conf.page_cache_size); { diff --git a/pageserver/client/src/mgmt_api.rs b/pageserver/client/src/mgmt_api.rs index 592f1ded0d..4d76c66905 100644 --- a/pageserver/client/src/mgmt_api.rs +++ b/pageserver/client/src/mgmt_api.rs @@ -540,10 +540,13 @@ impl Client { .map_err(Error::ReceiveBody) } - /// Configs io buffer alignment at runtime. - pub async fn put_io_alignment(&self, align: usize) -> Result<()> { - let uri = format!("{}/v1/io_alignment", self.mgmt_api_endpoint); - self.request(Method::PUT, uri, align) + /// Configs io mode at runtime. + pub async fn put_io_mode( + &self, + mode: &pageserver_api::models::virtual_file::IoMode, + ) -> Result<()> { + let uri = format!("{}/v1/io_mode", self.mgmt_api_endpoint); + self.request(Method::PUT, uri, mode) .await? .json() .await diff --git a/pageserver/ctl/src/layer_map_analyzer.rs b/pageserver/ctl/src/layer_map_analyzer.rs index adc090823d..151b94cf62 100644 --- a/pageserver/ctl/src/layer_map_analyzer.rs +++ b/pageserver/ctl/src/layer_map_analyzer.rs @@ -152,11 +152,7 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> { let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); // Initialize virtual_file (file desriptor cache) and page cache which are needed to access layer persistent B-Tree. - pageserver::virtual_file::init( - 10, - virtual_file::api::IoEngineKind::StdFs, - pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT, - ); + pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); pageserver::page_cache::init(100); let mut total_delta_layers = 0usize; diff --git a/pageserver/ctl/src/layers.rs b/pageserver/ctl/src/layers.rs index dd753398e2..fd948bf2ef 100644 --- a/pageserver/ctl/src/layers.rs +++ b/pageserver/ctl/src/layers.rs @@ -59,7 +59,7 @@ pub(crate) enum LayerCmd { async fn read_delta_file(path: impl AsRef, ctx: &RequestContext) -> Result<()> { let path = Utf8Path::from_path(path.as_ref()).expect("non-Unicode path"); - virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, 1); + virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); page_cache::init(100); let file = VirtualFile::open(path, ctx).await?; let file_id = page_cache::next_file_id(); @@ -190,11 +190,7 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> { new_tenant_id, new_timeline_id, } => { - pageserver::virtual_file::init( - 10, - virtual_file::api::IoEngineKind::StdFs, - pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT, - ); + pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); pageserver::page_cache::init(100); let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); diff --git a/pageserver/ctl/src/main.rs b/pageserver/ctl/src/main.rs index cf001ef0d5..c96664d346 100644 --- a/pageserver/ctl/src/main.rs +++ b/pageserver/ctl/src/main.rs @@ -26,7 +26,7 @@ use pageserver::{ tenant::{dump_layerfile_from_path, metadata::TimelineMetadata}, virtual_file, }; -use pageserver_api::{config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT, shard::TenantShardId}; +use pageserver_api::shard::TenantShardId; use postgres_ffi::ControlFileData; use remote_storage::{RemotePath, RemoteStorageConfig}; use tokio_util::sync::CancellationToken; @@ -205,11 +205,7 @@ fn read_pg_control_file(control_file_path: &Utf8Path) -> anyhow::Result<()> { async fn print_layerfile(path: &Utf8Path) -> anyhow::Result<()> { // Basic initialization of things that don't change after startup - virtual_file::init( - 10, - virtual_file::api::IoEngineKind::StdFs, - DEFAULT_IO_BUFFER_ALIGNMENT, - ); + virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs); page_cache::init(100); let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error); dump_layerfile_from_path(path, true, &ctx).await diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index ac4a732377..b2df01714d 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -59,9 +59,9 @@ pub(crate) struct Args { #[clap(long)] set_io_engine: Option, - /// Before starting the benchmark, live-reconfigure the pageserver to use specified alignment for io buffers. + /// Before starting the benchmark, live-reconfigure the pageserver to use specified io mode (buffered vs. direct). #[clap(long)] - set_io_alignment: Option, + set_io_mode: Option, targets: Option>, } @@ -129,8 +129,8 @@ async fn main_impl( mgmt_api_client.put_io_engine(engine_str).await?; } - if let Some(align) = args.set_io_alignment { - mgmt_api_client.put_io_alignment(align).await?; + if let Some(mode) = &args.set_io_mode { + mgmt_api_client.put_io_mode(mode).await?; } // discover targets diff --git a/pageserver/src/auth.rs b/pageserver/src/auth.rs index 9e3dedb75a..5c931fcfdb 100644 --- a/pageserver/src/auth.rs +++ b/pageserver/src/auth.rs @@ -14,14 +14,19 @@ pub fn check_permission(claims: &Claims, tenant_id: Option) -> Result< } (Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope (Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope - (Scope::Admin | Scope::SafekeeperData | Scope::GenerationsApi | Scope::Scrubber, _) => { - Err(AuthError( - format!( - "JWT scope '{:?}' is ineligible for Pageserver auth", - claims.scope - ) - .into(), - )) - } + ( + Scope::Admin + | Scope::SafekeeperData + | Scope::GenerationsApi + | Scope::Infra + | Scope::Scrubber, + _, + ) => Err(AuthError( + format!( + "JWT scope '{:?}' is ineligible for Pageserver auth", + claims.scope + ) + .into(), + )), } } diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 593ca6db2d..f71a3d2653 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -125,8 +125,7 @@ fn main() -> anyhow::Result<()> { // after setting up logging, log the effective IO engine choice and read path implementations info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine"); - info!(?conf.virtual_file_direct_io, "starting with virtual_file Direct IO settings"); - info!(?conf.io_buffer_alignment, "starting with setting for IO buffer alignment"); + info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode"); // The tenants directory contains all the pageserver local disk state. // Create if not exists and make sure all the contents are durable before proceeding. @@ -168,11 +167,7 @@ fn main() -> anyhow::Result<()> { let scenario = failpoint_support::init(); // Basic initialization of things that don't change after startup - virtual_file::init( - conf.max_file_descriptors, - conf.virtual_file_io_engine, - conf.io_buffer_alignment, - ); + virtual_file::init(conf.max_file_descriptors, conf.virtual_file_io_engine); page_cache::init(conf.page_cache_size); start_pageserver(launch_ts, conf).context("Failed to start pageserver")?; diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index e15f1c791b..8db78285e4 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -174,9 +174,7 @@ pub struct PageServerConf { pub l0_flush: crate::l0_flush::L0FlushConfig, /// Direct IO settings - pub virtual_file_direct_io: virtual_file::DirectIoMode, - - pub io_buffer_alignment: usize, + pub virtual_file_io_mode: virtual_file::IoMode, } /// Token for authentication to safekeepers @@ -325,11 +323,10 @@ impl PageServerConf { image_compression, ephemeral_bytes_per_memory_kb, l0_flush, - virtual_file_direct_io, + virtual_file_io_mode, concurrent_tenant_warmup, concurrent_tenant_size_logical_size_queries, virtual_file_io_engine, - io_buffer_alignment, tenant_config, } = config_toml; @@ -368,8 +365,6 @@ impl PageServerConf { max_vectored_read_bytes, image_compression, ephemeral_bytes_per_memory_kb, - virtual_file_direct_io, - io_buffer_alignment, // ------------------------------------------------------------ // fields that require additional validation or custom handling @@ -408,6 +403,7 @@ impl PageServerConf { l0_flush: l0_flush .map(crate::l0_flush::L0FlushConfig::from) .unwrap_or_default(), + virtual_file_io_mode: virtual_file_io_mode.unwrap_or(virtual_file::IoMode::preferred()), }; // ------------------------------------------------------------ diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 94375e62b6..2985ab1efb 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -17,6 +17,7 @@ use hyper::header; use hyper::StatusCode; use hyper::{Body, Request, Response, Uri}; use metrics::launch_timestamp::LaunchTimestamp; +use pageserver_api::models::virtual_file::IoMode; use pageserver_api::models::AuxFilePolicy; use pageserver_api::models::DownloadRemoteLayersTaskSpawnRequest; use pageserver_api::models::IngestAuxFilesRequest; @@ -703,6 +704,8 @@ async fn timeline_archival_config_handler( let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?; let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?; + let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Warn); + let request_data: TimelineArchivalConfigRequest = json_request(&mut request).await?; check_permission(&request, Some(tenant_shard_id.tenant_id))?; let state = get_state(&request); @@ -713,7 +716,7 @@ async fn timeline_archival_config_handler( .get_attached_tenant_shard(tenant_shard_id)?; tenant - .apply_timeline_archival_config(timeline_id, request_data.state) + .apply_timeline_archival_config(timeline_id, request_data.state, ctx) .await?; Ok::<_, ApiError>(()) } @@ -2379,17 +2382,13 @@ async fn put_io_engine_handler( json_response(StatusCode::OK, ()) } -async fn put_io_alignment_handler( +async fn put_io_mode_handler( mut r: Request, _cancel: CancellationToken, ) -> Result, ApiError> { check_permission(&r, None)?; - let align: usize = json_request(&mut r).await?; - crate::virtual_file::set_io_buffer_alignment(align).map_err(|align| { - ApiError::PreconditionFailed( - format!("Requested io alignment ({align}) is not a power of two").into(), - ) - })?; + let mode: IoMode = json_request(&mut r).await?; + crate::virtual_file::set_io_mode(mode); json_response(StatusCode::OK, ()) } @@ -3080,9 +3079,7 @@ pub fn make_router( |r| api_handler(r, timeline_collect_keyspace), ) .put("/v1/io_engine", |r| api_handler(r, put_io_engine_handler)) - .put("/v1/io_alignment", |r| { - api_handler(r, put_io_alignment_handler) - }) + .put("/v1/io_mode", |r| api_handler(r, put_io_mode_handler)) .put( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/force_aux_policy_switch", |r| api_handler(r, force_aux_policy_switch_handler), diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 29f682c62a..d2818d04dc 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -38,6 +38,7 @@ use std::future::Future; use std::sync::Weak; use std::time::SystemTime; use storage_broker::BrokerClientChannel; +use timeline::offload::offload_timeline; use tokio::io::BufReader; use tokio::sync::watch; use tokio::task::JoinSet; @@ -287,9 +288,13 @@ pub struct Tenant { /// During timeline creation, we first insert the TimelineId to the /// creating map, then `timelines`, then remove it from the creating map. - /// **Lock order**: if acquring both, acquire`timelines` before `timelines_creating` + /// **Lock order**: if acquiring both, acquire`timelines` before `timelines_creating` timelines_creating: std::sync::Mutex>, + /// Possibly offloaded and archived timelines + /// **Lock order**: if acquiring both, acquire`timelines` before `timelines_offloaded` + timelines_offloaded: Mutex>>, + // This mutex prevents creation of new timelines during GC. // Adding yet another mutex (in addition to `timelines`) is needed because holding // `timelines` mutex during all GC iteration @@ -484,6 +489,65 @@ impl WalRedoManager { } } +pub struct OffloadedTimeline { + pub tenant_shard_id: TenantShardId, + pub timeline_id: TimelineId, + pub ancestor_timeline_id: Option, + + // TODO: once we persist offloaded state, make this lazily constructed + pub remote_client: Arc, + + /// Prevent two tasks from deleting the timeline at the same time. If held, the + /// timeline is being deleted. If 'true', the timeline has already been deleted. + pub delete_progress: Arc>, +} + +impl OffloadedTimeline { + fn from_timeline(timeline: &Timeline) -> Self { + Self { + tenant_shard_id: timeline.tenant_shard_id, + timeline_id: timeline.timeline_id, + ancestor_timeline_id: timeline.get_ancestor_timeline_id(), + + remote_client: timeline.remote_client.clone(), + delete_progress: timeline.delete_progress.clone(), + } + } +} + +#[derive(Clone)] +pub enum TimelineOrOffloaded { + Timeline(Arc), + Offloaded(Arc), +} + +impl TimelineOrOffloaded { + pub fn tenant_shard_id(&self) -> TenantShardId { + match self { + TimelineOrOffloaded::Timeline(timeline) => timeline.tenant_shard_id, + TimelineOrOffloaded::Offloaded(offloaded) => offloaded.tenant_shard_id, + } + } + pub fn timeline_id(&self) -> TimelineId { + match self { + TimelineOrOffloaded::Timeline(timeline) => timeline.timeline_id, + TimelineOrOffloaded::Offloaded(offloaded) => offloaded.timeline_id, + } + } + pub fn delete_progress(&self) -> &Arc> { + match self { + TimelineOrOffloaded::Timeline(timeline) => &timeline.delete_progress, + TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.delete_progress, + } + } + pub fn remote_client(&self) -> &Arc { + match self { + TimelineOrOffloaded::Timeline(timeline) => &timeline.remote_client, + TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.remote_client, + } + } +} + #[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum GetTimelineError { #[error("Timeline is shutting down")] @@ -1406,52 +1470,192 @@ impl Tenant { } } - pub(crate) async fn apply_timeline_archival_config( - &self, + fn check_to_be_archived_has_no_unarchived_children( timeline_id: TimelineId, - state: TimelineArchivalState, + timelines: &std::sync::MutexGuard<'_, HashMap>>, + ) -> Result<(), TimelineArchivalError> { + let children: Vec = timelines + .iter() + .filter_map(|(id, entry)| { + if entry.get_ancestor_timeline_id() != Some(timeline_id) { + return None; + } + if entry.is_archived() == Some(true) { + return None; + } + Some(*id) + }) + .collect(); + + if !children.is_empty() { + return Err(TimelineArchivalError::HasUnarchivedChildren(children)); + } + Ok(()) + } + + fn check_ancestor_of_to_be_unarchived_is_not_archived( + ancestor_timeline_id: TimelineId, + timelines: &std::sync::MutexGuard<'_, HashMap>>, + offloaded_timelines: &std::sync::MutexGuard< + '_, + HashMap>, + >, + ) -> Result<(), TimelineArchivalError> { + let has_archived_parent = + if let Some(ancestor_timeline) = timelines.get(&ancestor_timeline_id) { + ancestor_timeline.is_archived() == Some(true) + } else if offloaded_timelines.contains_key(&ancestor_timeline_id) { + true + } else { + error!("ancestor timeline {ancestor_timeline_id} not found"); + if cfg!(debug_assertions) { + panic!("ancestor timeline {ancestor_timeline_id} not found"); + } + return Err(TimelineArchivalError::NotFound); + }; + if has_archived_parent { + return Err(TimelineArchivalError::HasArchivedParent( + ancestor_timeline_id, + )); + } + Ok(()) + } + + fn check_to_be_unarchived_timeline_has_no_archived_parent( + timeline: &Arc, + ) -> Result<(), TimelineArchivalError> { + if let Some(ancestor_timeline) = timeline.ancestor_timeline() { + if ancestor_timeline.is_archived() == Some(true) { + return Err(TimelineArchivalError::HasArchivedParent( + ancestor_timeline.timeline_id, + )); + } + } + Ok(()) + } + + /// Loads the specified (offloaded) timeline from S3 and attaches it as a loaded timeline + async fn unoffload_timeline( + self: &Arc, + timeline_id: TimelineId, + ctx: RequestContext, + ) -> Result, TimelineArchivalError> { + let cancel = self.cancel.clone(); + let timeline_preload = self + .load_timeline_metadata(timeline_id, self.remote_storage.clone(), cancel) + .await; + + let index_part = match timeline_preload.index_part { + Ok(index_part) => { + debug!("remote index part exists for timeline {timeline_id}"); + index_part + } + Err(DownloadError::NotFound) => { + error!(%timeline_id, "index_part not found on remote"); + return Err(TimelineArchivalError::NotFound); + } + Err(e) => { + // Some (possibly ephemeral) error happened during index_part download. + warn!(%timeline_id, "Failed to load index_part from remote storage, failed creation? ({e})"); + return Err(TimelineArchivalError::Other( + anyhow::Error::new(e).context("downloading index_part from remote storage"), + )); + } + }; + let index_part = match index_part { + MaybeDeletedIndexPart::IndexPart(index_part) => index_part, + MaybeDeletedIndexPart::Deleted(_index_part) => { + info!("timeline is deleted according to index_part.json"); + return Err(TimelineArchivalError::NotFound); + } + }; + let remote_metadata = index_part.metadata.clone(); + let timeline_resources = self.build_timeline_resources(timeline_id); + self.load_remote_timeline( + timeline_id, + index_part, + remote_metadata, + timeline_resources, + &ctx, + ) + .await + .with_context(|| { + format!( + "failed to load remote timeline {} for tenant {}", + timeline_id, self.tenant_shard_id + ) + })?; + let timelines = self.timelines.lock().unwrap(); + if let Some(timeline) = timelines.get(&timeline_id) { + let mut offloaded_timelines = self.timelines_offloaded.lock().unwrap(); + if offloaded_timelines.remove(&timeline_id).is_none() { + warn!("timeline already removed from offloaded timelines"); + } + Ok(Arc::clone(timeline)) + } else { + warn!("timeline not available directly after attach"); + Err(TimelineArchivalError::Other(anyhow::anyhow!( + "timeline not available directly after attach" + ))) + } + } + + pub(crate) async fn apply_timeline_archival_config( + self: &Arc, + timeline_id: TimelineId, + new_state: TimelineArchivalState, + ctx: RequestContext, ) -> Result<(), TimelineArchivalError> { info!("setting timeline archival config"); - let timeline = { + // First part: figure out what is needed to do, and do validation + let timeline_or_unarchive_offloaded = 'outer: { let timelines = self.timelines.lock().unwrap(); let Some(timeline) = timelines.get(&timeline_id) else { - return Err(TimelineArchivalError::NotFound); + let offloaded_timelines = self.timelines_offloaded.lock().unwrap(); + let Some(offloaded) = offloaded_timelines.get(&timeline_id) else { + return Err(TimelineArchivalError::NotFound); + }; + if new_state == TimelineArchivalState::Archived { + // It's offloaded already, so nothing to do + return Ok(()); + } + if let Some(ancestor_timeline_id) = offloaded.ancestor_timeline_id { + Self::check_ancestor_of_to_be_unarchived_is_not_archived( + ancestor_timeline_id, + &timelines, + &offloaded_timelines, + )?; + } + break 'outer None; }; - if state == TimelineArchivalState::Unarchived { - if let Some(ancestor_timeline) = timeline.ancestor_timeline() { - if ancestor_timeline.is_archived() == Some(true) { - return Err(TimelineArchivalError::HasArchivedParent( - ancestor_timeline.timeline_id, - )); - } + // Do some validation. We release the timelines lock below, so there is potential + // for race conditions: these checks are more present to prevent misunderstandings of + // the API's capabilities, instead of serving as the sole way to defend their invariants. + match new_state { + TimelineArchivalState::Unarchived => { + Self::check_to_be_unarchived_timeline_has_no_archived_parent(timeline)? + } + TimelineArchivalState::Archived => { + Self::check_to_be_archived_has_no_unarchived_children(timeline_id, &timelines)? } } - - // Ensure that there are no non-archived child timelines - let children: Vec = timelines - .iter() - .filter_map(|(id, entry)| { - if entry.get_ancestor_timeline_id() != Some(timeline_id) { - return None; - } - if entry.is_archived() == Some(true) { - return None; - } - Some(*id) - }) - .collect(); - - if !children.is_empty() && state == TimelineArchivalState::Archived { - return Err(TimelineArchivalError::HasUnarchivedChildren(children)); - } - Arc::clone(timeline) + Some(Arc::clone(timeline)) }; + // Second part: unarchive timeline (if needed) + let timeline = if let Some(timeline) = timeline_or_unarchive_offloaded { + timeline + } else { + // Turn offloaded timeline into a non-offloaded one + self.unoffload_timeline(timeline_id, ctx).await? + }; + + // Third part: upload new timeline archival state and block until it is present in S3 let upload_needed = timeline .remote_client - .schedule_index_upload_for_timeline_archival_state(state)?; + .schedule_index_upload_for_timeline_archival_state(new_state)?; if upload_needed { info!("Uploading new state"); @@ -1884,7 +2088,7 @@ impl Tenant { /// /// Returns whether we have pending compaction task. async fn compaction_iteration( - &self, + self: &Arc, cancel: &CancellationToken, ctx: &RequestContext, ) -> Result { @@ -1905,21 +2109,28 @@ impl Tenant { // while holding the lock. Then drop the lock and actually perform the // compactions. We don't want to block everything else while the // compaction runs. - let timelines_to_compact = { + let timelines_to_compact_or_offload; + { let timelines = self.timelines.lock().unwrap(); - let timelines_to_compact = timelines + timelines_to_compact_or_offload = timelines .iter() .filter_map(|(timeline_id, timeline)| { - if timeline.is_active() { - Some((*timeline_id, timeline.clone())) - } else { + let (is_active, can_offload) = (timeline.is_active(), timeline.can_offload()); + let has_no_unoffloaded_children = { + !timelines + .iter() + .any(|(_id, tl)| tl.get_ancestor_timeline_id() == Some(*timeline_id)) + }; + let can_offload = can_offload && has_no_unoffloaded_children; + if (is_active, can_offload) == (false, false) { None + } else { + Some((*timeline_id, timeline.clone(), (is_active, can_offload))) } }) .collect::>(); drop(timelines); - timelines_to_compact - }; + } // Before doing any I/O work, check our circuit breaker if self.compaction_circuit_breaker.lock().unwrap().is_broken() { @@ -1929,20 +2140,34 @@ impl Tenant { let mut has_pending_task = false; - for (timeline_id, timeline) in &timelines_to_compact { - has_pending_task |= timeline - .compact(cancel, EnumSet::empty(), ctx) - .instrument(info_span!("compact_timeline", %timeline_id)) - .await - .inspect_err(|e| match e { - timeline::CompactionError::ShuttingDown => (), - timeline::CompactionError::Other(e) => { - self.compaction_circuit_breaker - .lock() - .unwrap() - .fail(&CIRCUIT_BREAKERS_BROKEN, e); - } - })?; + for (timeline_id, timeline, (can_compact, can_offload)) in &timelines_to_compact_or_offload + { + let pending_task_left = if *can_compact { + Some( + timeline + .compact(cancel, EnumSet::empty(), ctx) + .instrument(info_span!("compact_timeline", %timeline_id)) + .await + .inspect_err(|e| match e { + timeline::CompactionError::ShuttingDown => (), + timeline::CompactionError::Other(e) => { + self.compaction_circuit_breaker + .lock() + .unwrap() + .fail(&CIRCUIT_BREAKERS_BROKEN, e); + } + })?, + ) + } else { + None + }; + has_pending_task |= pending_task_left.unwrap_or(false); + if pending_task_left == Some(false) && *can_offload { + offload_timeline(self, timeline) + .instrument(info_span!("offload_timeline", %timeline_id)) + .await + .map_err(timeline::CompactionError::Other)?; + } } self.compaction_circuit_breaker @@ -2852,6 +3077,7 @@ impl Tenant { constructed_at: Instant::now(), timelines: Mutex::new(HashMap::new()), timelines_creating: Mutex::new(HashSet::new()), + timelines_offloaded: Mutex::new(HashMap::new()), gc_cs: tokio::sync::Mutex::new(()), walredo_mgr, remote_storage, diff --git a/pageserver/src/tenant/ephemeral_file.rs b/pageserver/src/tenant/ephemeral_file.rs index 5324e1807d..a62a47f9a7 100644 --- a/pageserver/src/tenant/ephemeral_file.rs +++ b/pageserver/src/tenant/ephemeral_file.rs @@ -84,7 +84,7 @@ impl Drop for EphemeralFile { fn drop(&mut self) { // unlink the file // we are clear to do this, because we have entered a gate - let path = &self.buffered_writer.as_inner().as_inner().path; + let path = self.buffered_writer.as_inner().as_inner().path(); let res = std::fs::remove_file(path); if let Err(e) = res { if e.kind() != std::io::ErrorKind::NotFound { @@ -356,7 +356,7 @@ mod tests { } let file_contents = - std::fs::read(&file.buffered_writer.as_inner().as_inner().path).unwrap(); + std::fs::read(file.buffered_writer.as_inner().as_inner().path()).unwrap(); assert_eq!(file_contents, &content[0..cap]); let buffer_contents = file.buffered_writer.inspect_buffer(); @@ -392,7 +392,7 @@ mod tests { .buffered_writer .as_inner() .as_inner() - .path + .path() .metadata() .unwrap(); assert_eq!( diff --git a/pageserver/src/tenant/gc_block.rs b/pageserver/src/tenant/gc_block.rs index f7a7836a12..373779ddb8 100644 --- a/pageserver/src/tenant/gc_block.rs +++ b/pageserver/src/tenant/gc_block.rs @@ -141,14 +141,14 @@ impl GcBlock { Ok(()) } - pub(crate) fn before_delete(&self, timeline: &super::Timeline) { + pub(crate) fn before_delete(&self, timeline_id: &super::TimelineId) { let unblocked = { let mut g = self.reasons.lock().unwrap(); if g.is_empty() { return; } - g.remove(&timeline.timeline_id); + g.remove(timeline_id); BlockingReasons::clean_and_summarize(g).is_none() }; diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 9f7447a9ac..82c5702686 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -950,6 +950,7 @@ impl<'a> TenantDownloader<'a> { let cancel = &self.secondary_state.cancel; let opts = DownloadOpts { etag: prev_etag.cloned(), + ..Default::default() }; backoff::retry( diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index 2acad666b8..8be7d7876f 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -573,7 +573,7 @@ impl DeltaLayerWriterInner { ensure!( metadata.len() <= S3_UPLOAD_LIMIT, "Created delta layer file at {} of size {} above limit {S3_UPLOAD_LIMIT}!", - file.path, + file.path(), metadata.len() ); @@ -791,7 +791,7 @@ impl DeltaLayerInner { max_vectored_read_bytes: Option, ctx: &RequestContext, ) -> anyhow::Result { - let file = VirtualFile::open(path, ctx) + let file = VirtualFile::open_v2(path, ctx) .await .context("open layer file")?; @@ -1022,7 +1022,7 @@ impl DeltaLayerInner { blob_meta.key, PageReconstructError::Other(anyhow!( "Failed to read blobs from virtual file {}: {}", - self.file.path, + self.file.path(), kind )), ); @@ -1048,7 +1048,7 @@ impl DeltaLayerInner { meta.meta.key, PageReconstructError::Other(anyhow!(e).context(format!( "Failed to decompress blob from virtual file {}", - self.file.path, + self.file.path(), ))), ); @@ -1066,7 +1066,7 @@ impl DeltaLayerInner { meta.meta.key, PageReconstructError::Other(anyhow!(e).context(format!( "Failed to deserialize blob from virtual file {}", - self.file.path, + self.file.path(), ))), ); @@ -1198,7 +1198,6 @@ impl DeltaLayerInner { let mut prev: Option<(Key, Lsn, BlobRef)> = None; let mut read_builder: Option = None; - let align = virtual_file::get_io_buffer_alignment(); let max_read_size = self .max_vectored_read_bytes @@ -1247,7 +1246,6 @@ impl DeltaLayerInner { offsets.end.pos(), meta, max_read_size, - align, )) } } else { diff --git a/pageserver/src/tenant/storage_layer/image_layer.rs b/pageserver/src/tenant/storage_layer/image_layer.rs index 9b53fa9e18..de8155f455 100644 --- a/pageserver/src/tenant/storage_layer/image_layer.rs +++ b/pageserver/src/tenant/storage_layer/image_layer.rs @@ -389,7 +389,7 @@ impl ImageLayerInner { max_vectored_read_bytes: Option, ctx: &RequestContext, ) -> anyhow::Result { - let file = VirtualFile::open(path, ctx) + let file = VirtualFile::open_v2(path, ctx) .await .context("open layer file")?; let file_id = page_cache::next_file_id(); @@ -626,7 +626,7 @@ impl ImageLayerInner { meta.meta.key, PageReconstructError::Other(anyhow!(e).context(format!( "Failed to decompress blob from virtual file {}", - self.file.path, + self.file.path(), ))), ); @@ -647,7 +647,7 @@ impl ImageLayerInner { blob_meta.key, PageReconstructError::from(anyhow!( "Failed to read blobs from virtual file {}: {}", - self.file.path, + self.file.path(), kind )), ); diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 1d79b2b74b..2fd4e699cf 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -7,6 +7,7 @@ pub(crate) mod handle; mod init; pub mod layer_manager; pub(crate) mod logical_size; +pub mod offload; pub mod span; pub mod uninit; mod walreceiver; @@ -1556,6 +1557,17 @@ impl Timeline { } } + /// Checks if the internal state of the timeline is consistent with it being able to be offloaded. + /// This is neccessary but not sufficient for offloading of the timeline as it might have + /// child timelines that are not offloaded yet. + pub(crate) fn can_offload(&self) -> bool { + if self.remote_client.is_archived() != Some(true) { + return false; + } + + true + } + /// Outermost timeline compaction operation; downloads needed layers. Returns whether we have pending /// compaction tasks. pub(crate) async fn compact( @@ -1818,7 +1830,6 @@ impl Timeline { self.current_state() == TimelineState::Active } - #[allow(unused)] pub(crate) fn is_archived(&self) -> Option { self.remote_client.is_archived() } diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index 90db08ea81..305c5758cc 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -15,7 +15,7 @@ use crate::{ tenant::{ metadata::TimelineMetadata, remote_timeline_client::{PersistIndexPartWithDeletedFlagError, RemoteTimelineClient}, - CreateTimelineCause, DeleteTimelineError, Tenant, + CreateTimelineCause, DeleteTimelineError, Tenant, TimelineOrOffloaded, }, }; @@ -24,12 +24,14 @@ use super::{Timeline, TimelineResources}; /// Mark timeline as deleted in S3 so we won't pick it up next time /// during attach or pageserver restart. /// See comment in persist_index_part_with_deleted_flag. -async fn set_deleted_in_remote_index(timeline: &Timeline) -> Result<(), DeleteTimelineError> { - match timeline - .remote_client +async fn set_deleted_in_remote_index( + timeline: &TimelineOrOffloaded, +) -> Result<(), DeleteTimelineError> { + let res = timeline + .remote_client() .persist_index_part_with_deleted_flag() - .await - { + .await; + match res { // If we (now, or already) marked it successfully as deleted, we can proceed Ok(()) | Err(PersistIndexPartWithDeletedFlagError::AlreadyDeleted(_)) => (), // Bail out otherwise @@ -127,9 +129,9 @@ pub(super) async fn delete_local_timeline_directory( } /// Removes remote layers and an index file after them. -async fn delete_remote_layers_and_index(timeline: &Timeline) -> anyhow::Result<()> { +async fn delete_remote_layers_and_index(timeline: &TimelineOrOffloaded) -> anyhow::Result<()> { timeline - .remote_client + .remote_client() .delete_all() .await .context("delete_all") @@ -137,27 +139,41 @@ async fn delete_remote_layers_and_index(timeline: &Timeline) -> anyhow::Result<( /// It is important that this gets called when DeletionGuard is being held. /// For more context see comments in [`DeleteTimelineFlow::prepare`] -async fn remove_timeline_from_tenant( +async fn remove_maybe_offloaded_timeline_from_tenant( tenant: &Tenant, - timeline: &Timeline, + timeline: &TimelineOrOffloaded, _: &DeletionGuard, // using it as a witness ) -> anyhow::Result<()> { // Remove the timeline from the map. + // This observes the locking order between timelines and timelines_offloaded let mut timelines = tenant.timelines.lock().unwrap(); + let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap(); + let offloaded_children_exist = timelines_offloaded + .iter() + .any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id())); let children_exist = timelines .iter() - .any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline.timeline_id)); - // XXX this can happen because `branch_timeline` doesn't check `TimelineState::Stopping`. - // We already deleted the layer files, so it's probably best to panic. - // (Ideally, above remove_dir_all is atomic so we don't see this timeline after a restart) - if children_exist { + .any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline.timeline_id())); + // XXX this can happen because of race conditions with branch creation. + // We already deleted the remote layer files, so it's probably best to panic. + if children_exist || offloaded_children_exist { panic!("Timeline grew children while we removed layer files"); } - timelines - .remove(&timeline.timeline_id) - .expect("timeline that we were deleting was concurrently removed from 'timelines' map"); + match timeline { + TimelineOrOffloaded::Timeline(timeline) => { + timelines.remove(&timeline.timeline_id).expect( + "timeline that we were deleting was concurrently removed from 'timelines' map", + ); + } + TimelineOrOffloaded::Offloaded(timeline) => { + timelines_offloaded + .remove(&timeline.timeline_id) + .expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map"); + } + } + drop(timelines_offloaded); drop(timelines); Ok(()) @@ -207,9 +223,11 @@ impl DeleteTimelineFlow { guard.mark_in_progress()?; // Now that the Timeline is in Stopping state, request all the related tasks to shut down. - timeline.shutdown(super::ShutdownMode::Hard).await; + if let TimelineOrOffloaded::Timeline(timeline) = &timeline { + timeline.shutdown(super::ShutdownMode::Hard).await; + } - tenant.gc_block.before_delete(&timeline); + tenant.gc_block.before_delete(&timeline.timeline_id()); fail::fail_point!("timeline-delete-before-index-deleted-at", |_| { Err(anyhow::anyhow!( @@ -285,15 +303,16 @@ impl DeleteTimelineFlow { guard.mark_in_progress()?; + let timeline = TimelineOrOffloaded::Timeline(timeline); Self::schedule_background(guard, tenant.conf, tenant, timeline); Ok(()) } - fn prepare( + pub(super) fn prepare( tenant: &Tenant, timeline_id: TimelineId, - ) -> Result<(Arc, DeletionGuard), DeleteTimelineError> { + ) -> Result<(TimelineOrOffloaded, DeletionGuard), DeleteTimelineError> { // Note the interaction between this guard and deletion guard. // Here we attempt to lock deletion guard when we're holding a lock on timelines. // This is important because when you take into account `remove_timeline_from_tenant` @@ -307,8 +326,14 @@ impl DeleteTimelineFlow { let timelines = tenant.timelines.lock().unwrap(); let timeline = match timelines.get(&timeline_id) { - Some(t) => t, - None => return Err(DeleteTimelineError::NotFound), + Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)), + None => { + let offloaded_timelines = tenant.timelines_offloaded.lock().unwrap(); + match offloaded_timelines.get(&timeline_id) { + Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)), + None => return Err(DeleteTimelineError::NotFound), + } + } }; // Ensure that there are no child timelines **attached to that pageserver**, @@ -334,30 +359,32 @@ impl DeleteTimelineFlow { // to remove the timeline from it. // Always if you have two locks that are taken in different order this can result in a deadlock. - let delete_progress = Arc::clone(&timeline.delete_progress); + let delete_progress = Arc::clone(timeline.delete_progress()); let delete_lock_guard = match delete_progress.try_lock_owned() { Ok(guard) => DeletionGuard(guard), Err(_) => { // Unfortunately if lock fails arc is consumed. return Err(DeleteTimelineError::AlreadyInProgress(Arc::clone( - &timeline.delete_progress, + timeline.delete_progress(), ))); } }; - timeline.set_state(TimelineState::Stopping); + if let TimelineOrOffloaded::Timeline(timeline) = &timeline { + timeline.set_state(TimelineState::Stopping); + } - Ok((Arc::clone(timeline), delete_lock_guard)) + Ok((timeline, delete_lock_guard)) } fn schedule_background( guard: DeletionGuard, conf: &'static PageServerConf, tenant: Arc, - timeline: Arc, + timeline: TimelineOrOffloaded, ) { - let tenant_shard_id = timeline.tenant_shard_id; - let timeline_id = timeline.timeline_id; + let tenant_shard_id = timeline.tenant_shard_id(); + let timeline_id = timeline.timeline_id(); task_mgr::spawn( task_mgr::BACKGROUND_RUNTIME.handle(), @@ -368,7 +395,9 @@ impl DeleteTimelineFlow { async move { if let Err(err) = Self::background(guard, conf, &tenant, &timeline).await { error!("Error: {err:#}"); - timeline.set_broken(format!("{err:#}")) + if let TimelineOrOffloaded::Timeline(timeline) = timeline { + timeline.set_broken(format!("{err:#}")) + } }; Ok(()) } @@ -380,15 +409,19 @@ impl DeleteTimelineFlow { mut guard: DeletionGuard, conf: &PageServerConf, tenant: &Tenant, - timeline: &Timeline, + timeline: &TimelineOrOffloaded, ) -> Result<(), DeleteTimelineError> { - delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await?; + // Offloaded timelines have no local state + // TODO: once we persist offloaded information, delete the timeline from there, too + if let TimelineOrOffloaded::Timeline(timeline) = timeline { + delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await?; + } delete_remote_layers_and_index(timeline).await?; pausable_failpoint!("in_progress_delete"); - remove_timeline_from_tenant(tenant, timeline, &guard).await?; + remove_maybe_offloaded_timeline_from_tenant(tenant, timeline, &guard).await?; *guard = Self::Finished; @@ -400,7 +433,7 @@ impl DeleteTimelineFlow { } } -struct DeletionGuard(OwnedMutexGuard); +pub(super) struct DeletionGuard(OwnedMutexGuard); impl Deref for DeletionGuard { type Target = DeleteTimelineFlow; diff --git a/pageserver/src/tenant/timeline/offload.rs b/pageserver/src/tenant/timeline/offload.rs new file mode 100644 index 0000000000..fb906d906b --- /dev/null +++ b/pageserver/src/tenant/timeline/offload.rs @@ -0,0 +1,69 @@ +use std::sync::Arc; + +use crate::tenant::{OffloadedTimeline, Tenant, TimelineOrOffloaded}; + +use super::{ + delete::{delete_local_timeline_directory, DeleteTimelineFlow, DeletionGuard}, + Timeline, +}; + +pub(crate) async fn offload_timeline( + tenant: &Tenant, + timeline: &Arc, +) -> anyhow::Result<()> { + tracing::info!("offloading archived timeline"); + let (timeline, guard) = DeleteTimelineFlow::prepare(tenant, timeline.timeline_id)?; + + let TimelineOrOffloaded::Timeline(timeline) = timeline else { + tracing::error!("timeline already offloaded, but given timeline object"); + return Ok(()); + }; + + // TODO extend guard mechanism above with method + // to make deletions possible while offloading is in progress + + // TODO mark timeline as offloaded in S3 + + let conf = &tenant.conf; + delete_local_timeline_directory(conf, tenant.tenant_shard_id, &timeline).await?; + + remove_timeline_from_tenant(tenant, &timeline, &guard).await?; + + { + let mut offloaded_timelines = tenant.timelines_offloaded.lock().unwrap(); + offloaded_timelines.insert( + timeline.timeline_id, + Arc::new(OffloadedTimeline::from_timeline(&timeline)), + ); + } + + Ok(()) +} + +/// It is important that this gets called when DeletionGuard is being held. +/// For more context see comments in [`DeleteTimelineFlow::prepare`] +async fn remove_timeline_from_tenant( + tenant: &Tenant, + timeline: &Timeline, + _: &DeletionGuard, // using it as a witness +) -> anyhow::Result<()> { + // Remove the timeline from the map. + let mut timelines = tenant.timelines.lock().unwrap(); + let children_exist = timelines + .iter() + .any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline.timeline_id)); + // XXX this can happen because `branch_timeline` doesn't check `TimelineState::Stopping`. + // We already deleted the layer files, so it's probably best to panic. + // (Ideally, above remove_dir_all is atomic so we don't see this timeline after a restart) + if children_exist { + panic!("Timeline grew children while we removed layer files"); + } + + timelines + .remove(&timeline.timeline_id) + .expect("timeline that we were deleting was concurrently removed from 'timelines' map"); + + drop(timelines); + + Ok(()) +} diff --git a/pageserver/src/tenant/vectored_blob_io.rs b/pageserver/src/tenant/vectored_blob_io.rs index 1faa6bab99..792c769b4f 100644 --- a/pageserver/src/tenant/vectored_blob_io.rs +++ b/pageserver/src/tenant/vectored_blob_io.rs @@ -194,8 +194,6 @@ pub(crate) struct ChunkedVectoredReadBuilder { /// Start offset and metadata for each blob in this read blobs_at: VecMap, max_read_size: Option, - /// Chunk size reads are coalesced into. - chunk_size: usize, } /// Computes x / d rounded up. @@ -204,6 +202,7 @@ fn div_round_up(x: usize, d: usize) -> usize { } impl ChunkedVectoredReadBuilder { + const CHUNK_SIZE: usize = virtual_file::get_io_buffer_alignment(); /// Start building a new vectored read. /// /// Note that by design, this does not check against reading more than `max_read_size` to @@ -214,21 +213,19 @@ impl ChunkedVectoredReadBuilder { end_offset: u64, meta: BlobMeta, max_read_size: Option, - chunk_size: usize, ) -> Self { let mut blobs_at = VecMap::default(); blobs_at .append(start_offset, meta) .expect("First insertion always succeeds"); - let start_blk_no = start_offset as usize / chunk_size; - let end_blk_no = div_round_up(end_offset as usize, chunk_size); + let start_blk_no = start_offset as usize / Self::CHUNK_SIZE; + let end_blk_no = div_round_up(end_offset as usize, Self::CHUNK_SIZE); Self { start_blk_no, end_blk_no, blobs_at, max_read_size, - chunk_size, } } @@ -237,18 +234,12 @@ impl ChunkedVectoredReadBuilder { end_offset: u64, meta: BlobMeta, max_read_size: usize, - align: usize, ) -> Self { - Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), align) + Self::new_impl(start_offset, end_offset, meta, Some(max_read_size)) } - pub(crate) fn new_streaming( - start_offset: u64, - end_offset: u64, - meta: BlobMeta, - align: usize, - ) -> Self { - Self::new_impl(start_offset, end_offset, meta, None, align) + pub(crate) fn new_streaming(start_offset: u64, end_offset: u64, meta: BlobMeta) -> Self { + Self::new_impl(start_offset, end_offset, meta, None) } /// Attempts to extend the current read with a new blob if the new blob resides in the same or the immediate next chunk. @@ -256,12 +247,12 @@ impl ChunkedVectoredReadBuilder { /// The resulting size also must be below the max read size. pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended { tracing::trace!(start, end, "trying to extend"); - let start_blk_no = start as usize / self.chunk_size; - let end_blk_no = div_round_up(end as usize, self.chunk_size); + let start_blk_no = start as usize / Self::CHUNK_SIZE; + let end_blk_no = div_round_up(end as usize, Self::CHUNK_SIZE); let not_limited_by_max_read_size = { if let Some(max_read_size) = self.max_read_size { - let coalesced_size = (end_blk_no - self.start_blk_no) * self.chunk_size; + let coalesced_size = (end_blk_no - self.start_blk_no) * Self::CHUNK_SIZE; coalesced_size <= max_read_size } else { true @@ -292,12 +283,12 @@ impl ChunkedVectoredReadBuilder { } pub(crate) fn size(&self) -> usize { - (self.end_blk_no - self.start_blk_no) * self.chunk_size + (self.end_blk_no - self.start_blk_no) * Self::CHUNK_SIZE } pub(crate) fn build(self) -> VectoredRead { - let start = (self.start_blk_no * self.chunk_size) as u64; - let end = (self.end_blk_no * self.chunk_size) as u64; + let start = (self.start_blk_no * Self::CHUNK_SIZE) as u64; + let end = (self.end_blk_no * Self::CHUNK_SIZE) as u64; VectoredRead { start, end, @@ -328,18 +319,14 @@ pub struct VectoredReadPlanner { prev: Option<(Key, Lsn, u64, BlobFlag)>, max_read_size: usize, - - align: usize, } impl VectoredReadPlanner { pub fn new(max_read_size: usize) -> Self { - let align = virtual_file::get_io_buffer_alignment(); Self { blobs: BTreeMap::new(), prev: None, max_read_size, - align, } } @@ -418,7 +405,6 @@ impl VectoredReadPlanner { end_offset, BlobMeta { key, lsn }, self.max_read_size, - self.align, ); let prev_read_builder = current_read_builder.replace(next_read_builder); @@ -472,13 +458,13 @@ impl<'a> VectoredBlobReader<'a> { ); if cfg!(debug_assertions) { - let align = virtual_file::get_io_buffer_alignment() as u64; + const ALIGN: u64 = virtual_file::get_io_buffer_alignment() as u64; debug_assert_eq!( - read.start % align, + read.start % ALIGN, 0, "Read start at {} does not satisfy the required io buffer alignment ({} bytes)", read.start, - align + ALIGN ); } @@ -553,22 +539,18 @@ pub struct StreamingVectoredReadPlanner { max_cnt: usize, /// Size of the current batch cnt: usize, - - align: usize, } impl StreamingVectoredReadPlanner { pub fn new(max_read_size: u64, max_cnt: usize) -> Self { assert!(max_cnt > 0); assert!(max_read_size > 0); - let align = virtual_file::get_io_buffer_alignment(); Self { read_builder: None, prev: None, max_cnt, max_read_size, cnt: 0, - align, } } @@ -621,7 +603,6 @@ impl StreamingVectoredReadPlanner { start_offset, end_offset, BlobMeta { key, lsn }, - self.align, )) }; } @@ -656,9 +637,9 @@ mod tests { use super::*; fn validate_read(read: &VectoredRead, offset_range: &[(Key, Lsn, u64, BlobFlag)]) { - let align = virtual_file::get_io_buffer_alignment() as u64; - assert_eq!(read.start % align, 0); - assert_eq!(read.start / align, offset_range.first().unwrap().2 / align); + const ALIGN: u64 = virtual_file::get_io_buffer_alignment() as u64; + assert_eq!(read.start % ALIGN, 0); + assert_eq!(read.start / ALIGN, offset_range.first().unwrap().2 / ALIGN); let expected_offsets_in_read: Vec<_> = offset_range.iter().map(|o| o.2).collect(); @@ -676,32 +657,27 @@ mod tests { fn planner_chunked_coalesce_all_test() { use crate::virtual_file; - let chunk_size = virtual_file::get_io_buffer_alignment() as u64; + const CHUNK_SIZE: u64 = virtual_file::get_io_buffer_alignment() as u64; - // The test explicitly does not check chunk size < 512 - if chunk_size < 512 { - return; - } - - let max_read_size = chunk_size as usize * 8; + let max_read_size = CHUNK_SIZE as usize * 8; let key = Key::MIN; let lsn = Lsn(0); let blob_descriptions = [ - (key, lsn, chunk_size / 8, BlobFlag::None), // Read 1 BEGIN - (key, lsn, chunk_size / 4, BlobFlag::Ignore), // Gap - (key, lsn, chunk_size / 2, BlobFlag::None), - (key, lsn, chunk_size - 2, BlobFlag::Ignore), // Gap - (key, lsn, chunk_size, BlobFlag::None), - (key, lsn, chunk_size * 2 - 1, BlobFlag::None), - (key, lsn, chunk_size * 2 + 1, BlobFlag::Ignore), // Gap - (key, lsn, chunk_size * 3 + 1, BlobFlag::None), - (key, lsn, chunk_size * 5 + 1, BlobFlag::None), - (key, lsn, chunk_size * 6 + 1, BlobFlag::Ignore), // skipped chunk size, but not a chunk: should coalesce. - (key, lsn, chunk_size * 7 + 1, BlobFlag::None), - (key, lsn, chunk_size * 8, BlobFlag::None), // Read 2 BEGIN (b/c max_read_size) - (key, lsn, chunk_size * 9, BlobFlag::Ignore), // ==== skipped a chunk - (key, lsn, chunk_size * 10, BlobFlag::None), // Read 3 BEGIN (cannot coalesce) + (key, lsn, CHUNK_SIZE / 8, BlobFlag::None), // Read 1 BEGIN + (key, lsn, CHUNK_SIZE / 4, BlobFlag::Ignore), // Gap + (key, lsn, CHUNK_SIZE / 2, BlobFlag::None), + (key, lsn, CHUNK_SIZE - 2, BlobFlag::Ignore), // Gap + (key, lsn, CHUNK_SIZE, BlobFlag::None), + (key, lsn, CHUNK_SIZE * 2 - 1, BlobFlag::None), + (key, lsn, CHUNK_SIZE * 2 + 1, BlobFlag::Ignore), // Gap + (key, lsn, CHUNK_SIZE * 3 + 1, BlobFlag::None), + (key, lsn, CHUNK_SIZE * 5 + 1, BlobFlag::None), + (key, lsn, CHUNK_SIZE * 6 + 1, BlobFlag::Ignore), // skipped chunk size, but not a chunk: should coalesce. + (key, lsn, CHUNK_SIZE * 7 + 1, BlobFlag::None), + (key, lsn, CHUNK_SIZE * 8, BlobFlag::None), // Read 2 BEGIN (b/c max_read_size) + (key, lsn, CHUNK_SIZE * 9, BlobFlag::Ignore), // ==== skipped a chunk + (key, lsn, CHUNK_SIZE * 10, BlobFlag::None), // Read 3 BEGIN (cannot coalesce) ]; let ranges = [ @@ -780,19 +756,19 @@ mod tests { #[test] fn planner_replacement_test() { - let chunk_size = virtual_file::get_io_buffer_alignment() as u64; - let max_read_size = 128 * chunk_size as usize; + const CHUNK_SIZE: u64 = virtual_file::get_io_buffer_alignment() as u64; + let max_read_size = 128 * CHUNK_SIZE as usize; let first_key = Key::MIN; let second_key = first_key.next(); let lsn = Lsn(0); let blob_descriptions = vec![ (first_key, lsn, 0, BlobFlag::None), // First in read 1 - (first_key, lsn, chunk_size, BlobFlag::None), // Last in read 1 - (second_key, lsn, 2 * chunk_size, BlobFlag::ReplaceAll), - (second_key, lsn, 3 * chunk_size, BlobFlag::None), - (second_key, lsn, 4 * chunk_size, BlobFlag::ReplaceAll), // First in read 2 - (second_key, lsn, 5 * chunk_size, BlobFlag::None), // Last in read 2 + (first_key, lsn, CHUNK_SIZE, BlobFlag::None), // Last in read 1 + (second_key, lsn, 2 * CHUNK_SIZE, BlobFlag::ReplaceAll), + (second_key, lsn, 3 * CHUNK_SIZE, BlobFlag::None), + (second_key, lsn, 4 * CHUNK_SIZE, BlobFlag::ReplaceAll), // First in read 2 + (second_key, lsn, 5 * CHUNK_SIZE, BlobFlag::None), // Last in read 2 ]; let ranges = [&blob_descriptions[0..2], &blob_descriptions[4..]]; @@ -802,7 +778,7 @@ mod tests { planner.handle(key, lsn, offset, flag); } - planner.handle_range_end(6 * chunk_size); + planner.handle_range_end(6 * CHUNK_SIZE); let reads = planner.finish(); assert_eq!(reads.len(), 2); @@ -947,7 +923,6 @@ mod tests { let reserved_bytes = blobs.iter().map(|bl| bl.len()).max().unwrap() * 2 + 16; let mut buf = BytesMut::with_capacity(reserved_bytes); - let align = virtual_file::get_io_buffer_alignment(); let vectored_blob_reader = VectoredBlobReader::new(&file); let meta = BlobMeta { key: Key::MIN, @@ -959,8 +934,7 @@ mod tests { if idx + 1 == offsets.len() { continue; } - let read_builder = - ChunkedVectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, align); + let read_builder = ChunkedVectoredReadBuilder::new(*offset, *end, meta, 16 * 4096); let read = read_builder.build(); let result = vectored_blob_reader.read_blobs(&read, buf, &ctx).await?; assert_eq!(result.blobs.len(), 1); diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 5b7b279888..d260116b38 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -23,10 +23,12 @@ use pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT; use pageserver_api::shard::TenantShardId; use std::fs::File; use std::io::{Error, ErrorKind, Seek, SeekFrom}; +#[cfg(target_os = "linux")] +use std::os::unix::fs::OpenOptionsExt; use tokio_epoll_uring::{BoundedBuf, IoBuf, IoBufMut, Slice}; use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}; use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use tokio::time::Instant; @@ -38,7 +40,7 @@ pub use io_engine::FeatureTestResult as IoEngineFeatureTestResult; mod metadata; mod open_options; use self::owned_buffers_io::write::OwnedAsyncWriter; -pub(crate) use api::DirectIoMode; +pub(crate) use api::IoMode; pub(crate) use io_engine::IoEngineKind; pub(crate) use metadata::Metadata; pub(crate) use open_options::*; @@ -61,6 +63,171 @@ pub(crate) mod owned_buffers_io { } } +#[derive(Debug)] +pub struct VirtualFile { + inner: VirtualFileInner, + _mode: IoMode, +} + +impl VirtualFile { + /// Open a file in read-only mode. Like File::open. + pub async fn open>( + path: P, + ctx: &RequestContext, + ) -> Result { + let inner = VirtualFileInner::open(path, ctx).await?; + Ok(VirtualFile { + inner, + _mode: IoMode::Buffered, + }) + } + + /// Open a file in read-only mode. Like File::open. + /// + /// `O_DIRECT` will be enabled base on `virtual_file_io_mode`. + pub async fn open_v2>( + path: P, + ctx: &RequestContext, + ) -> Result { + Self::open_with_options_v2(path.as_ref(), OpenOptions::new().read(true), ctx).await + } + + pub async fn create>( + path: P, + ctx: &RequestContext, + ) -> Result { + let inner = VirtualFileInner::create(path, ctx).await?; + Ok(VirtualFile { + inner, + _mode: IoMode::Buffered, + }) + } + + pub async fn create_v2>( + path: P, + ctx: &RequestContext, + ) -> Result { + VirtualFile::open_with_options_v2( + path.as_ref(), + OpenOptions::new().write(true).create(true).truncate(true), + ctx, + ) + .await + } + + pub async fn open_with_options>( + path: P, + open_options: &OpenOptions, + ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */ + ) -> Result { + let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?; + Ok(VirtualFile { + inner, + _mode: IoMode::Buffered, + }) + } + + pub async fn open_with_options_v2>( + path: P, + open_options: &OpenOptions, + ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */ + ) -> Result { + let file = match get_io_mode() { + IoMode::Buffered => { + let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?; + VirtualFile { + inner, + _mode: IoMode::Buffered, + } + } + #[cfg(target_os = "linux")] + IoMode::Direct => { + let inner = VirtualFileInner::open_with_options( + path, + open_options.clone().custom_flags(nix::libc::O_DIRECT), + ctx, + ) + .await?; + VirtualFile { + inner, + _mode: IoMode::Direct, + } + } + }; + Ok(file) + } + + pub fn path(&self) -> &Utf8Path { + self.inner.path.as_path() + } + + pub async fn crashsafe_overwrite + Send, Buf: IoBuf + Send>( + final_path: Utf8PathBuf, + tmp_path: Utf8PathBuf, + content: B, + ) -> std::io::Result<()> { + VirtualFileInner::crashsafe_overwrite(final_path, tmp_path, content).await + } + + pub async fn sync_all(&self) -> Result<(), Error> { + self.inner.sync_all().await + } + + pub async fn sync_data(&self) -> Result<(), Error> { + self.inner.sync_data().await + } + + pub async fn metadata(&self) -> Result { + self.inner.metadata().await + } + + pub fn remove(self) { + self.inner.remove(); + } + + pub async fn seek(&mut self, pos: SeekFrom) -> Result { + self.inner.seek(pos).await + } + + pub async fn read_exact_at( + &self, + slice: Slice, + offset: u64, + ctx: &RequestContext, + ) -> Result, Error> + where + Buf: IoBufMut + Send, + { + self.inner.read_exact_at(slice, offset, ctx).await + } + + pub async fn read_exact_at_page( + &self, + page: PageWriteGuard<'static>, + offset: u64, + ctx: &RequestContext, + ) -> Result, Error> { + self.inner.read_exact_at_page(page, offset, ctx).await + } + + pub async fn write_all_at( + &self, + buf: FullSlice, + offset: u64, + ctx: &RequestContext, + ) -> (FullSlice, Result<(), Error>) { + self.inner.write_all_at(buf, offset, ctx).await + } + + pub async fn write_all( + &mut self, + buf: FullSlice, + ctx: &RequestContext, + ) -> (FullSlice, Result) { + self.inner.write_all(buf, ctx).await + } +} + /// /// A virtual file descriptor. You can use this just like std::fs::File, but internally /// the underlying file is closed if the system is low on file descriptors, @@ -77,7 +244,7 @@ pub(crate) mod owned_buffers_io { /// 'tag' field is used to detect whether the handle still is valid or not. /// #[derive(Debug)] -pub struct VirtualFile { +pub struct VirtualFileInner { /// Lazy handle to the global file descriptor cache. The slot that this points to /// might contain our File, or it may be empty, or it may contain a File that /// belongs to a different VirtualFile. @@ -350,12 +517,12 @@ macro_rules! with_file { }}; } -impl VirtualFile { +impl VirtualFileInner { /// Open a file in read-only mode. Like File::open. pub async fn open>( path: P, ctx: &RequestContext, - ) -> Result { + ) -> Result { Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await } @@ -364,7 +531,7 @@ impl VirtualFile { pub async fn create>( path: P, ctx: &RequestContext, - ) -> Result { + ) -> Result { Self::open_with_options( path.as_ref(), OpenOptions::new().write(true).create(true).truncate(true), @@ -382,7 +549,7 @@ impl VirtualFile { path: P, open_options: &OpenOptions, _ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */ - ) -> Result { + ) -> Result { let path_ref = path.as_ref(); let path_str = path_ref.to_string(); let parts = path_str.split('/').collect::>(); @@ -423,7 +590,7 @@ impl VirtualFile { reopen_options.create_new(false); reopen_options.truncate(false); - let vfile = VirtualFile { + let vfile = VirtualFileInner { handle: RwLock::new(handle), pos: 0, path: path_ref.to_path_buf(), @@ -1034,6 +1201,21 @@ impl tokio_epoll_uring::IoFd for FileGuard { #[cfg(test)] impl VirtualFile { + pub(crate) async fn read_blk( + &self, + blknum: u32, + ctx: &RequestContext, + ) -> Result, std::io::Error> { + self.inner.read_blk(blknum, ctx).await + } + + async fn read_to_end(&mut self, buf: &mut Vec, ctx: &RequestContext) -> Result<(), Error> { + self.inner.read_to_end(buf, ctx).await + } +} + +#[cfg(test)] +impl VirtualFileInner { pub(crate) async fn read_blk( &self, blknum: u32, @@ -1067,7 +1249,7 @@ impl VirtualFile { } } -impl Drop for VirtualFile { +impl Drop for VirtualFileInner { /// If a VirtualFile is dropped, close the underlying file if it was open. fn drop(&mut self) { let handle = self.handle.get_mut(); @@ -1143,15 +1325,10 @@ impl OpenFiles { /// server startup. /// #[cfg(not(test))] -pub fn init(num_slots: usize, engine: IoEngineKind, io_buffer_alignment: usize) { +pub fn init(num_slots: usize, engine: IoEngineKind) { if OPEN_FILES.set(OpenFiles::new(num_slots)).is_err() { panic!("virtual_file::init called twice"); } - if set_io_buffer_alignment(io_buffer_alignment).is_err() { - panic!( - "IO buffer alignment needs to be a power of two and greater than 512, got {io_buffer_alignment}" - ); - } io_engine::init(engine); crate::metrics::virtual_file_descriptor_cache::SIZE_MAX.set(num_slots as u64); } @@ -1175,47 +1352,20 @@ fn get_open_files() -> &'static OpenFiles { } } -static IO_BUFFER_ALIGNMENT: AtomicUsize = AtomicUsize::new(DEFAULT_IO_BUFFER_ALIGNMENT); - -/// Returns true if the alignment is a power of two and is greater or equal to 512. -fn is_valid_io_buffer_alignment(align: usize) -> bool { - align.is_power_of_two() && align >= 512 -} - -/// Sets IO buffer alignment requirement. Returns error if the alignment requirement is -/// not a power of two or less than 512 bytes. -#[allow(unused)] -pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> { - if is_valid_io_buffer_alignment(align) { - IO_BUFFER_ALIGNMENT.store(align, std::sync::atomic::Ordering::Relaxed); - Ok(()) - } else { - Err(align) - } -} - /// Gets the io buffer alignment. -/// -/// This function should be used for getting the actual alignment value to use. -pub(crate) fn get_io_buffer_alignment() -> usize { - let align = IO_BUFFER_ALIGNMENT.load(std::sync::atomic::Ordering::Relaxed); - - if cfg!(test) { - let env_var_name = "NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT"; - if let Some(test_align) = utils::env::var(env_var_name) { - if is_valid_io_buffer_alignment(test_align) { - test_align - } else { - panic!("IO buffer alignment needs to be a power of two and greater than 512, got {test_align}"); - } - } else { - align - } - } else { - align - } +pub(crate) const fn get_io_buffer_alignment() -> usize { + DEFAULT_IO_BUFFER_ALIGNMENT } +static IO_MODE: AtomicU8 = AtomicU8::new(IoMode::preferred() as u8); + +pub(crate) fn set_io_mode(mode: IoMode) { + IO_MODE.store(mode as u8, std::sync::atomic::Ordering::Relaxed); +} + +pub(crate) fn get_io_mode() -> IoMode { + IoMode::try_from(IO_MODE.load(Ordering::Relaxed)).unwrap() +} #[cfg(test)] mod tests { use crate::context::DownloadBehavior; @@ -1524,7 +1674,7 @@ mod tests { // Open the file many times. let mut files = Vec::new(); for _ in 0..VIRTUAL_FILES { - let f = VirtualFile::open_with_options( + let f = VirtualFileInner::open_with_options( &test_file_path, OpenOptions::new().read(true), &ctx, @@ -1576,7 +1726,7 @@ mod tests { let path = testdir.join("myfile"); let tmp_path = testdir.join("myfile.tmp"); - VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec()) + VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap()); @@ -1585,7 +1735,7 @@ mod tests { assert!(!tmp_path.exists()); drop(file); - VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec()) + VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap()); @@ -1608,7 +1758,7 @@ mod tests { std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap(); assert!(tmp_path.exists()); - VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec()) + VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec()) .await .unwrap(); diff --git a/pgxn/neon/control_plane_connector.c b/pgxn/neon/control_plane_connector.c index de023da5c4..0730c305cb 100644 --- a/pgxn/neon/control_plane_connector.c +++ b/pgxn/neon/control_plane_connector.c @@ -146,6 +146,8 @@ ConstructDeltaMessage() if (RootTable.role_table) { JsonbValue roles; + HASH_SEQ_STATUS status; + RoleEntry *entry; roles.type = jbvString; roles.val.string.val = "roles"; @@ -153,9 +155,6 @@ ConstructDeltaMessage() pushJsonbValue(&state, WJB_KEY, &roles); pushJsonbValue(&state, WJB_BEGIN_ARRAY, NULL); - HASH_SEQ_STATUS status; - RoleEntry *entry; - hash_seq_init(&status, RootTable.role_table); while ((entry = hash_seq_search(&status)) != NULL) { @@ -190,10 +189,12 @@ ConstructDeltaMessage() } pushJsonbValue(&state, WJB_END_ARRAY, NULL); } - JsonbValue *result = pushJsonbValue(&state, WJB_END_OBJECT, NULL); - Jsonb *jsonb = JsonbValueToJsonb(result); + { + JsonbValue *result = pushJsonbValue(&state, WJB_END_OBJECT, NULL); + Jsonb *jsonb = JsonbValueToJsonb(result); - return JsonbToCString(NULL, &jsonb->root, 0 /* estimated_len */ ); + return JsonbToCString(NULL, &jsonb->root, 0 /* estimated_len */ ); + } } #define ERROR_SIZE 1024 @@ -272,32 +273,28 @@ SendDeltasToControlPlane() curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, ErrorWriteCallback); } - char *message = ConstructDeltaMessage(); - ErrorString str; - - str.size = 0; - - curl_easy_setopt(handle, CURLOPT_POSTFIELDS, message); - curl_easy_setopt(handle, CURLOPT_WRITEDATA, &str); - - const int num_retries = 5; - CURLcode curl_status; - - for (int i = 0; i < num_retries; i++) - { - if ((curl_status = curl_easy_perform(handle)) == 0) - break; - elog(LOG, "Curl request failed on attempt %d: %s", i, CurlErrorBuf); - pg_usleep(1000 * 1000); - } - if (curl_status != CURLE_OK) - { - elog(ERROR, "Failed to perform curl request: %s", CurlErrorBuf); - } - else { + char *message = ConstructDeltaMessage(); + ErrorString str; + const int num_retries = 5; + CURLcode curl_status; long response_code; + str.size = 0; + + curl_easy_setopt(handle, CURLOPT_POSTFIELDS, message); + curl_easy_setopt(handle, CURLOPT_WRITEDATA, &str); + + for (int i = 0; i < num_retries; i++) + { + if ((curl_status = curl_easy_perform(handle)) == 0) + break; + elog(LOG, "Curl request failed on attempt %d: %s", i, CurlErrorBuf); + pg_usleep(1000 * 1000); + } + if (curl_status != CURLE_OK) + elog(ERROR, "Failed to perform curl request: %s", CurlErrorBuf); + if (curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &response_code) != CURLE_UNKNOWN_OPTION) { if (response_code != 200) @@ -376,10 +373,11 @@ MergeTable() if (old_table->db_table) { - InitDbTableIfNeeded(); DbEntry *entry; HASH_SEQ_STATUS status; + InitDbTableIfNeeded(); + hash_seq_init(&status, old_table->db_table); while ((entry = hash_seq_search(&status)) != NULL) { @@ -421,10 +419,11 @@ MergeTable() if (old_table->role_table) { - InitRoleTableIfNeeded(); RoleEntry *entry; HASH_SEQ_STATUS status; + InitRoleTableIfNeeded(); + hash_seq_init(&status, old_table->role_table); while ((entry = hash_seq_search(&status)) != NULL) { @@ -515,9 +514,12 @@ RoleIsNeonSuperuser(const char *role_name) static void HandleCreateDb(CreatedbStmt *stmt) { - InitDbTableIfNeeded(); DefElem *downer = NULL; ListCell *option; + bool found = false; + DbEntry *entry; + + InitDbTableIfNeeded(); foreach(option, stmt->options) { @@ -526,13 +528,11 @@ HandleCreateDb(CreatedbStmt *stmt) if (strcmp(defel->defname, "owner") == 0) downer = defel; } - bool found = false; - DbEntry *entry = hash_search( - CurrentDdlTable->db_table, - stmt->dbname, - HASH_ENTER, - &found); + entry = hash_search(CurrentDdlTable->db_table, + stmt->dbname, + HASH_ENTER, + &found); if (!found) memset(entry->old_name, 0, sizeof(entry->old_name)); @@ -554,21 +554,24 @@ HandleCreateDb(CreatedbStmt *stmt) static void HandleAlterOwner(AlterOwnerStmt *stmt) { + const char *name; + bool found = false; + DbEntry *entry; + const char *new_owner; + if (stmt->objectType != OBJECT_DATABASE) return; InitDbTableIfNeeded(); - const char *name = strVal(stmt->object); - bool found = false; - DbEntry *entry = hash_search( - CurrentDdlTable->db_table, - name, - HASH_ENTER, - &found); + name = strVal(stmt->object); + entry = hash_search(CurrentDdlTable->db_table, + name, + HASH_ENTER, + &found); if (!found) memset(entry->old_name, 0, sizeof(entry->old_name)); - const char *new_owner = get_rolespec_name(stmt->newowner); + new_owner = get_rolespec_name(stmt->newowner); if (RoleIsNeonSuperuser(new_owner)) elog(ERROR, "can't alter owner to neon_superuser"); entry->owner = get_role_oid(new_owner, false); @@ -578,21 +581,23 @@ HandleAlterOwner(AlterOwnerStmt *stmt) static void HandleDbRename(RenameStmt *stmt) { + bool found = false; + DbEntry *entry; + DbEntry *entry_for_new_name; + Assert(stmt->renameType == OBJECT_DATABASE); InitDbTableIfNeeded(); - bool found = false; - DbEntry *entry = hash_search( - CurrentDdlTable->db_table, - stmt->subname, - HASH_FIND, - &found); - DbEntry *entry_for_new_name = hash_search( - CurrentDdlTable->db_table, - stmt->newname, - HASH_ENTER, - NULL); + entry = hash_search(CurrentDdlTable->db_table, + stmt->subname, + HASH_FIND, + &found); + entry_for_new_name = hash_search(CurrentDdlTable->db_table, + stmt->newname, + HASH_ENTER, + NULL); entry_for_new_name->type = Op_Set; + if (found) { if (entry->old_name[0] != '\0') @@ -600,8 +605,7 @@ HandleDbRename(RenameStmt *stmt) else strlcpy(entry_for_new_name->old_name, entry->name, NAMEDATALEN); entry_for_new_name->owner = entry->owner; - hash_search( - CurrentDdlTable->db_table, + hash_search(CurrentDdlTable->db_table, stmt->subname, HASH_REMOVE, NULL); @@ -616,14 +620,15 @@ HandleDbRename(RenameStmt *stmt) static void HandleDropDb(DropdbStmt *stmt) { - InitDbTableIfNeeded(); bool found = false; - DbEntry *entry = hash_search( - CurrentDdlTable->db_table, - stmt->dbname, - HASH_ENTER, - &found); + DbEntry *entry; + InitDbTableIfNeeded(); + + entry = hash_search(CurrentDdlTable->db_table, + stmt->dbname, + HASH_ENTER, + &found); entry->type = Op_Delete; entry->owner = InvalidOid; if (!found) @@ -633,16 +638,14 @@ HandleDropDb(DropdbStmt *stmt) static void HandleCreateRole(CreateRoleStmt *stmt) { - InitRoleTableIfNeeded(); bool found = false; - RoleEntry *entry = hash_search( - CurrentDdlTable->role_table, - stmt->role, - HASH_ENTER, - &found); - DefElem *dpass = NULL; + RoleEntry *entry; + DefElem *dpass; ListCell *option; + InitRoleTableIfNeeded(); + + dpass = NULL; foreach(option, stmt->options) { DefElem *defel = lfirst(option); @@ -650,6 +653,11 @@ HandleCreateRole(CreateRoleStmt *stmt) if (strcmp(defel->defname, "password") == 0) dpass = defel; } + + entry = hash_search(CurrentDdlTable->role_table, + stmt->role, + HASH_ENTER, + &found); if (!found) memset(entry->old_name, 0, sizeof(entry->old_name)); if (dpass && dpass->arg) @@ -662,14 +670,18 @@ HandleCreateRole(CreateRoleStmt *stmt) static void HandleAlterRole(AlterRoleStmt *stmt) { - InitRoleTableIfNeeded(); - DefElem *dpass = NULL; - ListCell *option; const char *role_name = stmt->role->rolename; + DefElem *dpass; + ListCell *option; + bool found = false; + RoleEntry *entry; + + InitRoleTableIfNeeded(); if (RoleIsNeonSuperuser(role_name) && !superuser()) elog(ERROR, "can't ALTER neon_superuser"); + dpass = NULL; foreach(option, stmt->options) { DefElem *defel = lfirst(option); @@ -680,13 +692,11 @@ HandleAlterRole(AlterRoleStmt *stmt) /* We only care about updates to the password */ if (!dpass) return; - bool found = false; - RoleEntry *entry = hash_search( - CurrentDdlTable->role_table, - role_name, - HASH_ENTER, - &found); + entry = hash_search(CurrentDdlTable->role_table, + role_name, + HASH_ENTER, + &found); if (!found) memset(entry->old_name, 0, sizeof(entry->old_name)); if (dpass->arg) @@ -699,20 +709,22 @@ HandleAlterRole(AlterRoleStmt *stmt) static void HandleRoleRename(RenameStmt *stmt) { - InitRoleTableIfNeeded(); - Assert(stmt->renameType == OBJECT_ROLE); bool found = false; - RoleEntry *entry = hash_search( - CurrentDdlTable->role_table, - stmt->subname, - HASH_FIND, - &found); + RoleEntry *entry; + RoleEntry *entry_for_new_name; - RoleEntry *entry_for_new_name = hash_search( - CurrentDdlTable->role_table, - stmt->newname, - HASH_ENTER, - NULL); + Assert(stmt->renameType == OBJECT_ROLE); + InitRoleTableIfNeeded(); + + entry = hash_search(CurrentDdlTable->role_table, + stmt->subname, + HASH_FIND, + &found); + + entry_for_new_name = hash_search(CurrentDdlTable->role_table, + stmt->newname, + HASH_ENTER, + NULL); entry_for_new_name->type = Op_Set; if (found) @@ -738,9 +750,10 @@ HandleRoleRename(RenameStmt *stmt) static void HandleDropRole(DropRoleStmt *stmt) { - InitRoleTableIfNeeded(); ListCell *item; + InitRoleTableIfNeeded(); + foreach(item, stmt->roles) { RoleSpec *spec = lfirst(item); diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 892a272252..d789526050 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -170,12 +170,14 @@ lfc_disable(char const *op) if (lfc_desc > 0) { + int rc; + /* * If the reason of error is ENOSPC, then truncation of file may * help to reclaim some space */ pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_TRUNCATE); - int rc = ftruncate(lfc_desc, 0); + rc = ftruncate(lfc_desc, 0); pgstat_report_wait_end(); if (rc < 0) @@ -616,7 +618,7 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno) */ if (entry->bitmap[chunk_offs >> 5] == 0) { - bool has_remaining_pages; + bool has_remaining_pages = false; for (int i = 0; i < CHUNK_BITMAP_SIZE; i++) { @@ -666,7 +668,6 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, BufferTag tag; FileCacheEntry *entry; ssize_t rc; - bool result = true; uint32 hash; uint64 generation; uint32 entry_offset; @@ -925,10 +926,10 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, /* We can reuse a hole that was left behind when the LFC was shrunk previously */ FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->holes)); uint32 offset = hole->offset; - bool found; + bool hole_found; - hash_search_with_hash_value(lfc_hash, &hole->key, hole->hash, HASH_REMOVE, &found); - CriticalAssert(found); + hash_search_with_hash_value(lfc_hash, &hole->key, hole->hash, HASH_REMOVE, &hole_found); + CriticalAssert(hole_found); lfc_ctl->used += 1; entry->offset = offset; /* reuse the hole */ @@ -1004,7 +1005,7 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS) Datum result; HeapTuple tuple; char const *key; - uint64 value; + uint64 value = 0; Datum values[NUM_NEON_GET_STATS_COLS]; bool nulls[NUM_NEON_GET_STATS_COLS]; diff --git a/pgxn/neon/hll.c b/pgxn/neon/hll.c index f8496b3125..1f53c8fd36 100644 --- a/pgxn/neon/hll.c +++ b/pgxn/neon/hll.c @@ -116,8 +116,6 @@ addSHLL(HyperLogLogState *cState, uint32 hash) { uint8 count; uint32 index; - size_t i; - size_t j; TimestampTz now = GetCurrentTimestamp(); /* Use the first "k" (registerWidth) bits as a zero based index */ diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c index 0ca8a70d6d..b60ae41af3 100644 --- a/pgxn/neon/libpagestore.c +++ b/pgxn/neon/libpagestore.c @@ -89,7 +89,6 @@ typedef struct #if PG_VERSION_NUM >= 150000 static shmem_request_hook_type prev_shmem_request_hook = NULL; -static void walproposer_shmem_request(void); #endif static shmem_startup_hook_type prev_shmem_startup_hook; static PagestoreShmemState *pagestore_shared; @@ -441,8 +440,8 @@ pageserver_connect(shardno_t shard_no, int elevel) return false; } shard->state = PS_Connecting_Startup; - /* fallthrough */ } + /* FALLTHROUGH */ case PS_Connecting_Startup: { char *pagestream_query; @@ -453,8 +452,6 @@ pageserver_connect(shardno_t shard_no, int elevel) do { - WaitEvent event; - switch (poll_result) { default: /* unknown/unused states are handled as a failed connection */ @@ -585,8 +582,8 @@ pageserver_connect(shardno_t shard_no, int elevel) } shard->state = PS_Connecting_PageStream; - /* fallthrough */ } + /* FALLTHROUGH */ case PS_Connecting_PageStream: { neon_shard_log(shard_no, DEBUG5, "Connection state: Connecting_PageStream"); @@ -631,8 +628,8 @@ pageserver_connect(shardno_t shard_no, int elevel) } shard->state = PS_Connected; - /* fallthrough */ } + /* FALLTHROUGH */ case PS_Connected: /* * We successfully connected. Future connections to this PageServer diff --git a/pgxn/neon/neon_perf_counters.c b/pgxn/neon/neon_perf_counters.c index 9bce81bf2e..a497d387c8 100644 --- a/pgxn/neon/neon_perf_counters.c +++ b/pgxn/neon/neon_perf_counters.c @@ -94,7 +94,6 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters) metric_t *metrics = palloc((NUM_METRICS + 1) * sizeof(metric_t)); uint64 bucket_accum; int i = 0; - Datum getpage_wait_str; metrics[i].name = "getpage_wait_seconds_count"; metrics[i].is_bucket = false; @@ -224,7 +223,6 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo; Datum values[3]; bool nulls[3]; - Datum getpage_wait_str; neon_per_backend_counters totals = {0}; metric_t *metrics; diff --git a/pgxn/neon/neon_pgversioncompat.h b/pgxn/neon/neon_pgversioncompat.h index e4754ec7ea..6b4b355672 100644 --- a/pgxn/neon/neon_pgversioncompat.h +++ b/pgxn/neon/neon_pgversioncompat.h @@ -7,6 +7,7 @@ #define NEON_PGVERSIONCOMPAT_H #include "fmgr.h" +#include "storage/buf_internals.h" #if PG_MAJORVERSION_NUM < 17 #define NRelFileInfoBackendIsTemp(rinfo) (rinfo.backend != InvalidBackendId) @@ -20,11 +21,24 @@ NInfoGetRelNumber(a) == NInfoGetRelNumber(b) \ ) -/* buftag population & RelFileNode/RelFileLocator rework */ +/* These macros were turned into static inline functions in v16 */ #if PG_MAJORVERSION_NUM < 16 +static inline bool +BufferTagsEqual(const BufferTag *tag1, const BufferTag *tag2) +{ + return BUFFERTAGS_EQUAL(*tag1, *tag2); +} -#define InitBufferTag(tag, rfn, fn, bn) INIT_BUFFERTAG(*tag, *rfn, fn, bn) +static inline void +InitBufferTag(BufferTag *tag, const RelFileNode *rnode, + ForkNumber forkNum, BlockNumber blockNum) +{ + INIT_BUFFERTAG(*tag, *rnode, forkNum, blockNum); +} +#endif +/* RelFileNode -> RelFileLocator rework */ +#if PG_MAJORVERSION_NUM < 16 #define USE_RELFILENODE #define RELFILEINFO_HDR "storage/relfilenode.h" @@ -73,8 +87,6 @@ #define USE_RELFILELOCATOR -#define BUFFERTAGS_EQUAL(a, b) BufferTagsEqual(&(a), &(b)) - #define RELFILEINFO_HDR "storage/relfilelocator.h" #define NRelFileInfo RelFileLocator diff --git a/pgxn/neon/pagestore_client.h b/pgxn/neon/pagestore_client.h index 4c9e40a063..f905e3b0fa 100644 --- a/pgxn/neon/pagestore_client.h +++ b/pgxn/neon/pagestore_client.h @@ -213,32 +213,6 @@ extern const f_smgr *smgr_neon(ProcNumber backend, NRelFileInfo rinfo); extern void smgr_init_neon(void); extern void readahead_buffer_resize(int newsize, void *extra); -/* Neon storage manager functionality */ - -extern void neon_init(void); -extern void neon_open(SMgrRelation reln); -extern void neon_close(SMgrRelation reln, ForkNumber forknum); -extern void neon_create(SMgrRelation reln, ForkNumber forknum, bool isRedo); -extern bool neon_exists(SMgrRelation reln, ForkNumber forknum); -extern void neon_unlink(NRelFileInfoBackend rnode, ForkNumber forknum, bool isRedo); -#if PG_MAJORVERSION_NUM < 16 -extern void neon_extend(SMgrRelation reln, ForkNumber forknum, - BlockNumber blocknum, char *buffer, bool skipFsync); -#else -extern void neon_extend(SMgrRelation reln, ForkNumber forknum, - BlockNumber blocknum, const void *buffer, bool skipFsync); -extern void neon_zeroextend(SMgrRelation reln, ForkNumber forknum, - BlockNumber blocknum, int nbuffers, bool skipFsync); -#endif - -#if PG_MAJORVERSION_NUM >=17 -extern bool neon_prefetch(SMgrRelation reln, ForkNumber forknum, - BlockNumber blocknum, int nblocks); -#else -extern bool neon_prefetch(SMgrRelation reln, ForkNumber forknum, - BlockNumber blocknum); -#endif - /* * LSN values associated with each request to the pageserver */ @@ -278,13 +252,7 @@ extern PGDLLEXPORT void neon_read_at_lsn(NRelFileInfo rnode, ForkNumber forkNum, extern PGDLLEXPORT void neon_read_at_lsn(NRelFileInfo rnode, ForkNumber forkNum, BlockNumber blkno, neon_request_lsns request_lsns, void *buffer); #endif -extern void neon_writeback(SMgrRelation reln, ForkNumber forknum, - BlockNumber blocknum, BlockNumber nblocks); -extern BlockNumber neon_nblocks(SMgrRelation reln, ForkNumber forknum); extern int64 neon_dbsize(Oid dbNode); -extern void neon_truncate(SMgrRelation reln, ForkNumber forknum, - BlockNumber nblocks); -extern void neon_immedsync(SMgrRelation reln, ForkNumber forknum); /* utils for neon relsize cache */ extern void relsize_hash_init(void); diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index 155756f8b3..3d9d9285df 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -118,6 +118,8 @@ static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS; static bool neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id); static bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) = NULL; +static BlockNumber neon_nblocks(SMgrRelation reln, ForkNumber forknum); + /* * Prefetch implementation: * @@ -215,7 +217,7 @@ typedef struct PrfHashEntry sizeof(BufferTag) \ ) -#define SH_EQUAL(tb, a, b) (BUFFERTAGS_EQUAL((a)->buftag, (b)->buftag)) +#define SH_EQUAL(tb, a, b) (BufferTagsEqual(&(a)->buftag, &(b)->buftag)) #define SH_SCOPE static inline #define SH_DEFINE #define SH_DECLARE @@ -736,7 +738,7 @@ static void prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns) { bool found; - uint64 mySlotNo = slot->my_ring_index; + uint64 mySlotNo PG_USED_FOR_ASSERTS_ONLY = slot->my_ring_index; NeonGetPageRequest request = { .req.tag = T_NeonGetPageRequest, @@ -803,15 +805,19 @@ prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns, bool is_prefetch) { uint64 min_ring_index; - PrefetchRequest req; + PrefetchRequest hashkey; #if USE_ASSERT_CHECKING bool any_hits = false; #endif /* We will never read further ahead than our buffer can store. */ nblocks = Max(1, Min(nblocks, readahead_buffer_size)); - /* use an intermediate PrefetchRequest struct to ensure correct alignment */ - req.buftag = tag; + /* + * Use an intermediate PrefetchRequest struct as the hash key to ensure + * correct alignment and that the padding bytes are cleared. + */ + memset(&hashkey.buftag, 0, sizeof(BufferTag)); + hashkey.buftag = tag; Retry: min_ring_index = UINT64_MAX; @@ -837,8 +843,8 @@ Retry: slot = NULL; entry = NULL; - req.buftag.blockNum = tag.blockNum + i; - entry = prfh_lookup(MyPState->prf_hash, (PrefetchRequest *) &req); + hashkey.buftag.blockNum = tag.blockNum + i; + entry = prfh_lookup(MyPState->prf_hash, &hashkey); if (entry != NULL) { @@ -849,7 +855,7 @@ Retry: Assert(slot->status != PRFS_UNUSED); Assert(MyPState->ring_last <= ring_index && ring_index < MyPState->ring_unused); - Assert(BUFFERTAGS_EQUAL(slot->buftag, req.buftag)); + Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag)); /* * If the caller specified a request LSN to use, only accept @@ -886,12 +892,19 @@ Retry: { min_ring_index = Min(min_ring_index, ring_index); /* The buffered request is good enough, return that index */ - pgBufferUsage.prefetch.duplicates++; + if (is_prefetch) + pgBufferUsage.prefetch.duplicates++; + else + pgBufferUsage.prefetch.hits++; continue; } } } - + else if (!is_prefetch) + { + pgBufferUsage.prefetch.misses += 1; + MyNeonCounters->getpage_prefetch_misses_total++; + } /* * We can only leave the block above by finding that there's * no entry that can satisfy this request, either because there @@ -974,7 +987,7 @@ Retry: * We must update the slot data before insertion, because the hash * function reads the buffer tag from the slot. */ - slot->buftag = req.buftag; + slot->buftag = hashkey.buftag; slot->shard_no = get_shard_number(&tag); slot->my_ring_index = ring_index; @@ -1452,7 +1465,6 @@ log_newpages_copy(NRelFileInfo * rinfo, ForkNumber forkNum, BlockNumber blkno, BlockNumber blknos[XLR_MAX_BLOCK_ID]; Page pageptrs[XLR_MAX_BLOCK_ID]; int nregistered = 0; - XLogRecPtr result = 0; for (int i = 0; i < nblocks; i++) { @@ -1765,7 +1777,7 @@ neon_wallog_page(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, co /* * neon_init() -- Initialize private state */ -void +static void neon_init(void) { Size prfs_size; @@ -2155,7 +2167,7 @@ neon_prefetch_response_usable(neon_request_lsns *request_lsns, /* * neon_exists() -- Does the physical file exist? */ -bool +static bool neon_exists(SMgrRelation reln, ForkNumber forkNum) { bool exists; @@ -2261,7 +2273,7 @@ neon_exists(SMgrRelation reln, ForkNumber forkNum) * * If isRedo is true, it's okay for the relation to exist already. */ -void +static void neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo) { switch (reln->smgr_relpersistence) @@ -2337,7 +2349,7 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo) * Note: any failure should be reported as WARNING not ERROR, because * we are usually not in a transaction anymore when this is called. */ -void +static void neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo) { /* @@ -2361,7 +2373,7 @@ neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo) * EOF). Note that we assume writing a block beyond current EOF * causes intervening file space to become filled with zeroes. */ -void +static void #if PG_MAJORVERSION_NUM < 16 neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, char *buffer, bool skipFsync) @@ -2453,7 +2465,7 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, } #if PG_MAJORVERSION_NUM >= 16 -void +static void neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum, int nblocks, bool skipFsync) { @@ -2549,7 +2561,7 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum, /* * neon_open() -- Initialize newly-opened relation. */ -void +static void neon_open(SMgrRelation reln) { /* @@ -2567,7 +2579,7 @@ neon_open(SMgrRelation reln) /* * neon_close() -- Close the specified relation, if it isn't closed already. */ -void +static void neon_close(SMgrRelation reln, ForkNumber forknum) { /* @@ -2582,13 +2594,12 @@ neon_close(SMgrRelation reln, ForkNumber forknum) /* * neon_prefetch() -- Initiate asynchronous read of the specified block of a relation */ -bool +static bool neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, int nblocks) { uint64 ring_index PG_USED_FOR_ASSERTS_ONLY; BufferTag tag; - bool io_initiated = false; switch (reln->smgr_relpersistence) { @@ -2612,7 +2623,6 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, while (nblocks > 0) { int iterblocks = Min(nblocks, PG_IOV_MAX); - int seqlen = 0; bits8 lfc_present[PG_IOV_MAX / 8]; memset(lfc_present, 0, sizeof(lfc_present)); @@ -2624,8 +2634,6 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, continue; } - io_initiated = true; - tag.blockNum = blocknum; for (int i = 0; i < PG_IOV_MAX / 8; i++) @@ -2648,7 +2656,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, /* * neon_prefetch() -- Initiate asynchronous read of the specified block of a relation */ -bool +static bool neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum) { uint64 ring_index PG_USED_FOR_ASSERTS_ONLY; @@ -2692,7 +2700,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum) * This accepts a range of blocks because flushing several pages at once is * considerably more efficient than doing so individually. */ -void +static void neon_writeback(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, BlockNumber nblocks) { @@ -2742,14 +2750,19 @@ neon_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber base_block uint64 ring_index; PrfHashEntry *entry; PrefetchRequest *slot; - BufferTag buftag = {0}; + PrefetchRequest hashkey; Assert(PointerIsValid(request_lsns)); Assert(nblocks >= 1); - CopyNRelFileInfoToBufTag(buftag, rinfo); - buftag.forkNum = forkNum; - buftag.blockNum = base_blockno; + /* + * Use an intermediate PrefetchRequest struct as the hash key to ensure + * correct alignment and that the padding bytes are cleared. + */ + memset(&hashkey.buftag, 0, sizeof(BufferTag)); + CopyNRelFileInfoToBufTag(hashkey.buftag, rinfo); + hashkey.buftag.forkNum = forkNum; + hashkey.buftag.blockNum = base_blockno; /* * The redo process does not lock pages that it needs to replay but are @@ -2767,7 +2780,7 @@ neon_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber base_block * weren't for the behaviour of the LwLsn cache that uses the highest * value of the LwLsn cache when the entry is not found. */ - prefetch_register_bufferv(buftag, request_lsns, nblocks, mask, false); + prefetch_register_bufferv(hashkey.buftag, request_lsns, nblocks, mask, false); for (int i = 0; i < nblocks; i++) { @@ -2788,8 +2801,8 @@ neon_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber base_block * Try to find prefetched page in the list of received pages. */ Retry: - buftag.blockNum = blockno; - entry = prfh_lookup(MyPState->prf_hash, (PrefetchRequest *) &buftag); + hashkey.buftag.blockNum = blockno; + entry = prfh_lookup(MyPState->prf_hash, &hashkey); if (entry != NULL) { @@ -2797,7 +2810,6 @@ Retry: if (neon_prefetch_response_usable(reqlsns, slot)) { ring_index = slot->my_ring_index; - pgBufferUsage.prefetch.hits += 1; } else { @@ -2827,10 +2839,7 @@ Retry: { if (entry == NULL) { - pgBufferUsage.prefetch.misses += 1; - MyNeonCounters->getpage_prefetch_misses_total++; - - ring_index = prefetch_register_bufferv(buftag, reqlsns, 1, NULL, false); + ring_index = prefetch_register_bufferv(hashkey.buftag, reqlsns, 1, NULL, false); Assert(ring_index != UINT64_MAX); slot = GetPrfSlot(ring_index); } @@ -2855,8 +2864,8 @@ Retry: } while (!prefetch_wait_for(ring_index)); Assert(slot->status == PRFS_RECEIVED); - Assert(memcmp(&buftag, &slot->buftag, sizeof(BufferTag)) == 0); - Assert(buftag.blockNum == base_blockno + i); + Assert(memcmp(&hashkey.buftag, &slot->buftag, sizeof(BufferTag)) == 0); + Assert(hashkey.buftag.blockNum == base_blockno + i); resp = slot->response; @@ -2912,10 +2921,10 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, * neon_read() -- Read the specified block from a relation. */ #if PG_MAJORVERSION_NUM < 16 -void +static void neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, char *buffer) #else -void +static void neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer) #endif { @@ -3024,7 +3033,7 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer #endif /* PG_MAJORVERSION_NUM <= 16 */ #if PG_MAJORVERSION_NUM >= 17 -void +static void neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, void **buffers, BlockNumber nblocks) { @@ -3059,6 +3068,9 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, lfc_result = lfc_readv_select(InfoFromSMgrRel(reln), forknum, blocknum, buffers, nblocks, read); + if (lfc_result > 0) + MyNeonCounters->file_cache_hits_total += lfc_result; + /* Read all blocks from LFC, so we're done */ if (lfc_result == nblocks) return; @@ -3185,6 +3197,7 @@ hexdump_page(char *page) } #endif +#if PG_MAJORVERSION_NUM < 17 /* * neon_write() -- Write the supplied block at the appropriate location. * @@ -3192,7 +3205,7 @@ hexdump_page(char *page) * relation (ie, those before the current EOF). To extend a relation, * use mdextend(). */ -void +static void #if PG_MAJORVERSION_NUM < 16 neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, char *buffer, bool skipFsync) #else @@ -3258,11 +3271,12 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo #endif #endif } +#endif #if PG_MAJORVERSION_NUM >= 17 -void +static void neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno, const void **buffers, BlockNumber nblocks, bool skipFsync) { @@ -3312,7 +3326,7 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno, /* * neon_nblocks() -- Get the number of blocks stored in a relation. */ -BlockNumber +static BlockNumber neon_nblocks(SMgrRelation reln, ForkNumber forknum) { NeonResponse *resp; @@ -3449,7 +3463,7 @@ neon_dbsize(Oid dbNode) /* * neon_truncate() -- Truncate relation to specified number of blocks. */ -void +static void neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber nblocks) { XLogRecPtr lsn; @@ -3518,7 +3532,7 @@ neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber nblocks) * crash before the next checkpoint syncs the newly-inactive segment, that * segment may survive recovery, reintroducing unwanted data into the table. */ -void +static void neon_immedsync(SMgrRelation reln, ForkNumber forknum) { switch (reln->smgr_relpersistence) @@ -3548,8 +3562,8 @@ neon_immedsync(SMgrRelation reln, ForkNumber forknum) } #if PG_MAJORVERSION_NUM >= 17 -void -neon_regisersync(SMgrRelation reln, ForkNumber forknum) +static void +neon_registersync(SMgrRelation reln, ForkNumber forknum) { switch (reln->smgr_relpersistence) { @@ -3733,6 +3747,8 @@ neon_read_slru_segment(SMgrRelation reln, const char* path, int segno, void* buf SlruKind kind; int n_blocks; shardno_t shard_no = 0; /* All SLRUs are at shard 0 */ + NeonResponse *resp; + NeonGetSlruSegmentRequest request; /* * Compute a request LSN to use, similar to neon_get_request_lsns() but the @@ -3771,8 +3787,7 @@ neon_read_slru_segment(SMgrRelation reln, const char* path, int segno, void* buf else return -1; - NeonResponse *resp; - NeonGetSlruSegmentRequest request = { + request = (NeonGetSlruSegmentRequest) { .req.tag = T_NeonGetSlruSegmentRequest, .req.lsn = request_lsn, .req.not_modified_since = not_modified_since, @@ -3879,7 +3894,7 @@ static const struct f_smgr neon_smgr = .smgr_truncate = neon_truncate, .smgr_immedsync = neon_immedsync, #if PG_MAJORVERSION_NUM >= 17 - .smgr_registersync = neon_regisersync, + .smgr_registersync = neon_registersync, #endif .smgr_start_unlogged_build = neon_start_unlogged_build, .smgr_finish_unlogged_build_phase_1 = neon_finish_unlogged_build_phase_1, diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index 78402a29d5..a3f33cb261 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -252,8 +252,6 @@ WalProposerPoll(WalProposer *wp) /* timeout expired: poll state */ if (rc == 0 || TimeToReconnect(wp, now) <= 0) { - TimestampTz now; - /* * If no WAL was generated during timeout (and we have already * collected the quorum), then send empty keepalive message @@ -269,8 +267,7 @@ WalProposerPoll(WalProposer *wp) now = wp->api.get_current_timestamp(wp); for (int i = 0; i < wp->n_safekeepers; i++) { - Safekeeper *sk = &wp->safekeeper[i]; - + sk = &wp->safekeeper[i]; if (TimestampDifferenceExceeds(sk->latestMsgReceivedAt, now, wp->config->safekeeper_connection_timeout)) { @@ -1080,7 +1077,7 @@ SendProposerElected(Safekeeper *sk) ProposerElected msg; TermHistory *th; term_t lastCommonTerm; - int i; + int idx; /* Now that we are ready to send it's a good moment to create WAL reader */ wp->api.wal_reader_allocate(sk); @@ -1099,15 +1096,15 @@ SendProposerElected(Safekeeper *sk) /* We must start somewhere. */ Assert(wp->propTermHistory.n_entries >= 1); - for (i = 0; i < Min(wp->propTermHistory.n_entries, th->n_entries); i++) + for (idx = 0; idx < Min(wp->propTermHistory.n_entries, th->n_entries); idx++) { - if (wp->propTermHistory.entries[i].term != th->entries[i].term) + if (wp->propTermHistory.entries[idx].term != th->entries[idx].term) break; /* term must begin everywhere at the same point */ - Assert(wp->propTermHistory.entries[i].lsn == th->entries[i].lsn); + Assert(wp->propTermHistory.entries[idx].lsn == th->entries[idx].lsn); } - i--; /* step back to the last common term */ - if (i < 0) + idx--; /* step back to the last common term */ + if (idx < 0) { /* safekeeper is empty or no common point, start from the beginning */ sk->startStreamingAt = wp->propTermHistory.entries[0].lsn; @@ -1128,14 +1125,14 @@ SendProposerElected(Safekeeper *sk) * proposer, LSN it is currently writing, but then we just pick * safekeeper pos as it obviously can't be higher. */ - if (wp->propTermHistory.entries[i].term == wp->propTerm) + if (wp->propTermHistory.entries[idx].term == wp->propTerm) { sk->startStreamingAt = sk->voteResponse.flushLsn; } else { - XLogRecPtr propEndLsn = wp->propTermHistory.entries[i + 1].lsn; - XLogRecPtr skEndLsn = (i + 1 < th->n_entries ? th->entries[i + 1].lsn : sk->voteResponse.flushLsn); + XLogRecPtr propEndLsn = wp->propTermHistory.entries[idx + 1].lsn; + XLogRecPtr skEndLsn = (idx + 1 < th->n_entries ? th->entries[idx + 1].lsn : sk->voteResponse.flushLsn); sk->startStreamingAt = Min(propEndLsn, skEndLsn); } @@ -1149,7 +1146,7 @@ SendProposerElected(Safekeeper *sk) msg.termHistory = &wp->propTermHistory; msg.timelineStartLsn = wp->timelineStartLsn; - lastCommonTerm = i >= 0 ? wp->propTermHistory.entries[i].term : 0; + lastCommonTerm = idx >= 0 ? wp->propTermHistory.entries[idx].term : 0; wp_log(LOG, "sending elected msg to node " UINT64_FORMAT " term=" UINT64_FORMAT ", startStreamingAt=%X/%X (lastCommonTerm=" UINT64_FORMAT "), termHistory.n_entries=%u to %s:%s, timelineStartLsn=%X/%X", sk->greetResponse.nodeId, msg.term, LSN_FORMAT_ARGS(msg.startStreamingAt), lastCommonTerm, msg.termHistory->n_entries, sk->host, sk->port, LSN_FORMAT_ARGS(msg.timelineStartLsn)); @@ -1641,7 +1638,7 @@ UpdateDonorShmem(WalProposer *wp) * Process AppendResponse message from safekeeper. */ static void -HandleSafekeeperResponse(WalProposer *wp, Safekeeper *sk) +HandleSafekeeperResponse(WalProposer *wp, Safekeeper *fromsk) { XLogRecPtr candidateTruncateLsn; XLogRecPtr newCommitLsn; @@ -1660,7 +1657,7 @@ HandleSafekeeperResponse(WalProposer *wp, Safekeeper *sk) * and WAL is committed by the quorum. BroadcastAppendRequest() should be * called to notify safekeepers about the new commitLsn. */ - wp->api.process_safekeeper_feedback(wp, sk); + wp->api.process_safekeeper_feedback(wp, fromsk); /* * Try to advance truncateLsn -- the last record flushed to all diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 41daeb87b9..d8c44f8182 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -725,7 +725,7 @@ extern void WalProposerBroadcast(WalProposer *wp, XLogRecPtr startpos, XLogRecPt extern void WalProposerPoll(WalProposer *wp); extern void WalProposerFree(WalProposer *wp); -extern WalproposerShmemState *GetWalpropShmemState(); +extern WalproposerShmemState *GetWalpropShmemState(void); /* * WaitEventSet API doesn't allow to remove socket, so walproposer_pg uses it to @@ -745,7 +745,7 @@ extern TimeLineID walprop_pg_get_timeline_id(void); * catch logging. */ #ifdef WALPROPOSER_LIB -extern void WalProposerLibLog(WalProposer *wp, int elevel, char *fmt,...); +extern void WalProposerLibLog(WalProposer *wp, int elevel, char *fmt,...) pg_attribute_printf(3, 4); #define wp_log(elevel, fmt, ...) WalProposerLibLog(wp, elevel, fmt, ## __VA_ARGS__) #else #define wp_log(elevel, fmt, ...) elog(elevel, WP_LOG_PREFIX fmt, ## __VA_ARGS__) diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 01f88a5ab3..706941c3f0 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -286,6 +286,9 @@ safekeepers_cmp(char *old, char *new) static void assign_neon_safekeepers(const char *newval, void *extra) { + char *newval_copy; + char *oldval; + if (!am_walproposer) return; @@ -295,8 +298,8 @@ assign_neon_safekeepers(const char *newval, void *extra) } /* Copy values because we will modify them in split_safekeepers_list() */ - char *newval_copy = pstrdup(newval); - char *oldval = pstrdup(wal_acceptors_list); + newval_copy = pstrdup(newval); + oldval = pstrdup(wal_acceptors_list); /* * TODO: restarting through FATAL is stupid and introduces 1s delay before @@ -538,7 +541,7 @@ nwp_shmem_startup_hook(void) } WalproposerShmemState * -GetWalpropShmemState() +GetWalpropShmemState(void) { Assert(walprop_shared != NULL); return walprop_shared; diff --git a/pgxn/neon_rmgr/neon_rmgr_desc.c b/pgxn/neon_rmgr/neon_rmgr_desc.c index 8901c85ba2..e8003a1066 100644 --- a/pgxn/neon_rmgr/neon_rmgr_desc.c +++ b/pgxn/neon_rmgr/neon_rmgr_desc.c @@ -44,27 +44,6 @@ infobits_desc(StringInfo buf, uint8 infobits, const char *keyname) appendStringInfoString(buf, "]"); } -static void -truncate_flags_desc(StringInfo buf, uint8 flags) -{ - appendStringInfoString(buf, "flags: ["); - - if (flags & XLH_TRUNCATE_CASCADE) - appendStringInfoString(buf, "CASCADE, "); - if (flags & XLH_TRUNCATE_RESTART_SEQS) - appendStringInfoString(buf, "RESTART_SEQS, "); - - if (buf->data[buf->len - 1] == ' ') - { - /* Truncate-away final unneeded ", " */ - Assert(buf->data[buf->len - 2] == ','); - buf->len -= 2; - buf->data[buf->len] = '\0'; - } - - appendStringInfoString(buf, "]"); -} - void neon_rm_desc(StringInfo buf, XLogReaderState *record) { diff --git a/pgxn/neon_walredo/walredoproc.c b/pgxn/neon_walredo/walredoproc.c index f98aa1cbe7..37abb3fa03 100644 --- a/pgxn/neon_walredo/walredoproc.c +++ b/pgxn/neon_walredo/walredoproc.c @@ -136,7 +136,7 @@ static bool redo_block_filter(XLogReaderState *record, uint8 block_id); static void GetPage(StringInfo input_message); static void Ping(StringInfo input_message); static ssize_t buffered_read(void *buf, size_t count); -static void CreateFakeSharedMemoryAndSemaphores(); +static void CreateFakeSharedMemoryAndSemaphores(void); static BufferTag target_redo_tag; @@ -170,6 +170,40 @@ close_range_syscall(unsigned int start_fd, unsigned int count, unsigned int flag return syscall(__NR_close_range, start_fd, count, flags); } + +static PgSeccompRule allowed_syscalls[] = +{ + /* Hard requirements */ + PG_SCMP_ALLOW(exit_group), + PG_SCMP_ALLOW(pselect6), + PG_SCMP_ALLOW(read), + PG_SCMP_ALLOW(select), + PG_SCMP_ALLOW(write), + + /* Memory allocation */ + PG_SCMP_ALLOW(brk), +#ifndef MALLOC_NO_MMAP + /* TODO: musl doesn't have mallopt */ + PG_SCMP_ALLOW(mmap), + PG_SCMP_ALLOW(munmap), +#endif + /* + * getpid() is called on assertion failure, in ExceptionalCondition. + * It's not really needed, but seems pointless to hide it either. The + * system call unlikely to expose a kernel vulnerability, and the PID + * is stored in MyProcPid anyway. + */ + PG_SCMP_ALLOW(getpid), + + /* Enable those for a proper shutdown. */ +#if 0 + PG_SCMP_ALLOW(munmap), + PG_SCMP_ALLOW(shmctl), + PG_SCMP_ALLOW(shmdt), + PG_SCMP_ALLOW(unlink), /* shm_unlink */ +#endif +}; + static void enter_seccomp_mode(void) { @@ -183,44 +217,12 @@ enter_seccomp_mode(void) (errcode(ERRCODE_SYSTEM_ERROR), errmsg("seccomp: could not close files >= fd 3"))); - PgSeccompRule syscalls[] = - { - /* Hard requirements */ - PG_SCMP_ALLOW(exit_group), - PG_SCMP_ALLOW(pselect6), - PG_SCMP_ALLOW(read), - PG_SCMP_ALLOW(select), - PG_SCMP_ALLOW(write), - - /* Memory allocation */ - PG_SCMP_ALLOW(brk), -#ifndef MALLOC_NO_MMAP - /* TODO: musl doesn't have mallopt */ - PG_SCMP_ALLOW(mmap), - PG_SCMP_ALLOW(munmap), -#endif - /* - * getpid() is called on assertion failure, in ExceptionalCondition. - * It's not really needed, but seems pointless to hide it either. The - * system call unlikely to expose a kernel vulnerability, and the PID - * is stored in MyProcPid anyway. - */ - PG_SCMP_ALLOW(getpid), - - /* Enable those for a proper shutdown. - PG_SCMP_ALLOW(munmap), - PG_SCMP_ALLOW(shmctl), - PG_SCMP_ALLOW(shmdt), - PG_SCMP_ALLOW(unlink), // shm_unlink - */ - }; - #ifdef MALLOC_NO_MMAP /* Ask glibc not to use mmap() */ mallopt(M_MMAP_MAX, 0); #endif - seccomp_load_rules(syscalls, lengthof(syscalls)); + seccomp_load_rules(allowed_syscalls, lengthof(allowed_syscalls)); } #endif /* HAVE_LIBSECCOMP */ @@ -449,7 +451,7 @@ WalRedoMain(int argc, char *argv[]) * half-initialized postgres. */ static void -CreateFakeSharedMemoryAndSemaphores() +CreateFakeSharedMemoryAndSemaphores(void) { PGShmemHeader *shim = NULL; PGShmemHeader *hdr; @@ -992,7 +994,7 @@ redo_block_filter(XLogReaderState *record, uint8 block_id) * If this block isn't one we are currently restoring, then return 'true' * so that this gets ignored */ - return !BUFFERTAGS_EQUAL(target_tag, target_redo_tag); + return !BufferTagsEqual(&target_tag, &target_redo_tag); } /* diff --git a/poetry.lock b/poetry.lock index 07f30d10e7..00fe2505c9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2095,6 +2095,7 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -2103,6 +2104,8 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -2584,6 +2587,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2729,21 +2733,22 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "responses" -version = "0.21.0" +version = "0.25.3" description = "A utility library for mocking out the `requests` Python library." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "responses-0.21.0-py3-none-any.whl", hash = "sha256:2dcc863ba63963c0c3d9ee3fa9507cbe36b7d7b0fccb4f0bdfd9e96c539b1487"}, - {file = "responses-0.21.0.tar.gz", hash = "sha256:b82502eb5f09a0289d8e209e7bad71ef3978334f56d09b444253d5ad67bf5253"}, + {file = "responses-0.25.3-py3-none-any.whl", hash = "sha256:521efcbc82081ab8daa588e08f7e8a64ce79b91c39f6e62199b19159bea7dbcb"}, + {file = "responses-0.25.3.tar.gz", hash = "sha256:617b9247abd9ae28313d57a75880422d55ec63c29d33d629697590a034358dba"}, ] [package.dependencies] -requests = ">=2.0,<3.0" -urllib3 = ">=1.25.10" +pyyaml = "*" +requests = ">=2.30.0,<3.0" +urllib3 = ">=1.25.10,<3.0" [package.extras] -tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-localserver", "types-mock", "types-requests"] +tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-PyYAML", "types-requests"] [[package]] name = "rfc3339-validator" @@ -3137,6 +3142,16 @@ files = [ {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"}, {file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"}, {file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"}, + {file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"}, + {file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"}, + {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"}, + {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"}, + {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"}, + {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"}, + {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"}, + {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"}, + {file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"}, + {file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"}, diff --git a/pre-commit.py b/pre-commit.py index ae432e8225..c9567e0c50 100755 --- a/pre-commit.py +++ b/pre-commit.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import enum import os import subprocess import sys -from typing import List @enum.unique @@ -55,12 +56,12 @@ def mypy() -> str: return "poetry run mypy" -def get_commit_files() -> List[str]: +def get_commit_files() -> list[str]: files = subprocess.check_output("git diff --cached --name-only --diff-filter=ACM".split()) return files.decode().splitlines() -def check(name: str, suffix: str, cmd: str, changed_files: List[str], no_color: bool = False): +def check(name: str, suffix: str, cmd: str, changed_files: list[str], no_color: bool = False): print(f"Checking: {name} ", end="") applicable_files = list(filter(lambda fname: fname.strip().endswith(suffix), changed_files)) if not applicable_files: diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index ae9b2531aa..963fb94a7d 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -39,7 +39,7 @@ http.workspace = true humantime.workspace = true humantime-serde.workspace = true hyper0.workspace = true -hyper1 = { package = "hyper", version = "1.2", features = ["server"] } +hyper = { workspace = true, features = ["server", "http1", "http2"] } hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] } http-body-util = { version = "0.1" } indexmap.workspace = true @@ -77,7 +77,7 @@ subtle.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] } -tokio-postgres.workspace = true +tokio-postgres = { workspace = true, features = ["with-serde_json-1"] } tokio-postgres-rustls.workspace = true tokio-rustls.workspace = true tokio-util.workspace = true @@ -101,7 +101,7 @@ jose-jwa = "0.1.2" jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] } signature = "2" ecdsa = "0.16" -p256 = "0.13" +p256 = { version = "0.13", features = ["jwk"] } rsa = "0.9" workspace_hack.workspace = true diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 285fa29428..94b84b6f00 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -3,8 +3,8 @@ use crate::{ auth::{self, backend::ComputeCredentialKeys, AuthFlow}, compute, config::AuthenticationConfig, - console::AuthSecret, context::RequestMonitoring, + control_plane::AuthSecret, sasl, stream::{PqStream, Stream}, }; diff --git a/proxy/src/auth/backend/web.rs b/proxy/src/auth/backend/console_redirect.rs similarity index 85% rename from proxy/src/auth/backend/web.rs rename to proxy/src/auth/backend/console_redirect.rs index 45710d244d..127be545e1 100644 --- a/proxy/src/auth/backend/web.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -1,8 +1,8 @@ use crate::{ auth, compute, config::AuthenticationConfig, - console::{self, provider::NodeInfo}, context::RequestMonitoring, + control_plane::{self, provider::NodeInfo}, error::{ReportableError, UserFacingError}, stream::PqStream, waiters, @@ -25,6 +25,10 @@ pub(crate) enum WebAuthError { Io(#[from] std::io::Error), } +pub struct ConsoleRedirectBackend { + console_uri: reqwest::Url, +} + impl UserFacingError for WebAuthError { fn to_string_client(&self) -> String { "Internal error".to_string() @@ -57,7 +61,26 @@ pub(crate) fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } -pub(super) async fn authenticate( +impl ConsoleRedirectBackend { + pub fn new(console_uri: reqwest::Url) -> Self { + Self { console_uri } + } + + pub(super) fn url(&self) -> &reqwest::Url { + &self.console_uri + } + + pub(crate) async fn authenticate( + &self, + ctx: &RequestMonitoring, + auth_config: &'static AuthenticationConfig, + client: &mut PqStream, + ) -> auth::Result { + authenticate(ctx, auth_config, &self.console_uri, client).await + } +} + +async fn authenticate( ctx: &RequestMonitoring, auth_config: &'static AuthenticationConfig, link_uri: &reqwest::Url, @@ -70,7 +93,7 @@ pub(super) async fn authenticate( let (psql_session_id, waiter) = loop { let psql_session_id = new_psql_session_id(); - match console::mgmt::get_waiter(&psql_session_id) { + match control_plane::mgmt::get_waiter(&psql_session_id) { Ok(waiter) => break (psql_session_id, waiter), Err(_e) => continue, } diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 15123a2623..749218d260 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -2,8 +2,8 @@ use super::{ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint}; use crate::{ auth::{self, AuthFlow}, config::AuthenticationConfig, - console::AuthSecret, context::RequestMonitoring, + control_plane::AuthSecret, intern::EndpointIdInt, sasl, stream::{self, Stream}, diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index b62a11ccb2..17ab7eda22 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -17,6 +17,8 @@ use crate::{ RoleName, }; +use super::ComputeCredentialKeys; + // TODO(conrad): make these configurable. const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30); const MIN_RENEW: Duration = Duration::from_secs(30); @@ -241,7 +243,7 @@ impl JwkCacheEntryLock { endpoint: EndpointId, role_name: &RoleName, fetch: &F, - ) -> Result<(), anyhow::Error> { + ) -> Result { // JWT compact form is defined to be // || . || || . || // where Signature = alg( || . || ); @@ -300,9 +302,9 @@ impl JwkCacheEntryLock { key => bail!("unsupported key type {key:?}"), }; - let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD) + let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD) .context("Provided authentication token is not a valid JWT encoding")?; - let payload = serde_json::from_slice::>(&payload) + let payload = serde_json::from_slice::>(&payloadb) .context("Provided authentication token is not a valid JWT encoding")?; tracing::debug!(?payload, "JWT signature valid with claims"); @@ -327,7 +329,7 @@ impl JwkCacheEntryLock { ); } - Ok(()) + Ok(ComputeCredentialKeys::JwtPayload(payloadb)) } } @@ -339,7 +341,7 @@ impl JwkCache { role_name: &RoleName, fetch: &F, jwt: &str, - ) -> Result<(), anyhow::Error> { + ) -> Result { // try with just a read lock first let key = (endpoint.clone(), role_name.clone()); let entry = self.map.get(&key).as_deref().map(Arc::clone); @@ -571,7 +573,7 @@ mod tests { use bytes::Bytes; use http::Response; use http_body_util::Full; - use hyper1::service::service_fn; + use hyper::service::service_fn; use hyper_util::rt::TokioIo; use rand::rngs::OsRng; use rsa::pkcs8::DecodePrivateKey; @@ -736,7 +738,7 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL }); let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); - let server = hyper1::server::conn::http1::Builder::new(); + let server = hyper::server::conn::http1::Builder::new(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { loop { diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index f56b0a0a6d..12451847b1 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -5,11 +5,11 @@ use arc_swap::ArcSwapOption; use crate::{ compute::ConnCfg, - console::{ + context::RequestMonitoring, + control_plane::{ messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo}, NodeInfo, }, - context::RequestMonitoring, intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag}, EndpointId, }; diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend/mod.rs similarity index 87% rename from proxy/src/auth/backend.rs rename to proxy/src/auth/backend/mod.rs index 0eeed27fb2..27c9f1876e 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend/mod.rs @@ -1,27 +1,28 @@ mod classic; +mod console_redirect; mod hacks; pub mod jwt; pub mod local; -mod web; use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; +pub use console_redirect::ConsoleRedirectBackend; +pub(crate) use console_redirect::WebAuthError; use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_postgres::config::AuthKeys; use tracing::{info, warn}; -pub(crate) use web::WebAuthError; use crate::auth::credentials::check_peer_addr_is_in_list; use crate::auth::{validate_password_and_exchange, AuthError}; use crate::cache::Cached; -use crate::console::errors::GetAuthInfoError; -use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; -use crate::console::{AuthSecret, NodeInfo}; use crate::context::RequestMonitoring; +use crate::control_plane::errors::GetAuthInfoError; +use crate::control_plane::provider::{CachedRoleSecret, ControlPlaneBackend}; +use crate::control_plane::{AuthSecret, NodeInfo}; use crate::intern::EndpointIdInt; use crate::metrics::Metrics; use crate::proxy::connect_compute::ComputeConnectBackend; @@ -31,12 +32,12 @@ use crate::stream::Stream; use crate::{ auth::{self, ComputeUserInfoMaybeEndpoint}, config::AuthenticationConfig, - console::{ + control_plane::{ self, provider::{CachedAllowedIps, CachedNodeInfo}, Api, }, - stream, url, + stream, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; @@ -67,19 +68,19 @@ impl std::ops::Deref for MaybeOwned<'_, T> { /// backends which require them for the authentication process. pub enum Backend<'a, T, D> { /// Cloud API (V2). - Console(MaybeOwned<'a, ConsoleBackend>, T), + ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T), /// Authentication via a web browser. - Web(MaybeOwned<'a, url::ApiUrl>, D), + ConsoleRedirect(MaybeOwned<'a, ConsoleRedirectBackend>, D), /// Local proxy uses configured auth credentials and does not wake compute Local(MaybeOwned<'a, LocalBackend>), } #[cfg(test)] pub(crate) trait TestBackend: Send + Sync + 'static { - fn wake_compute(&self) -> Result; + fn wake_compute(&self) -> Result; fn get_allowed_ips_and_secret( &self, - ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError>; + ) -> Result<(CachedAllowedIps, Option), control_plane::errors::GetAuthInfoError>; fn dyn_clone(&self) -> Box; } @@ -93,18 +94,23 @@ impl Clone for Box { impl std::fmt::Display for Backend<'_, (), ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Console(api, ()) => match &**api { - ConsoleBackend::Console(endpoint) => { - fmt.debug_tuple("Console").field(&endpoint.url()).finish() - } + Self::ControlPlane(api, ()) => match &**api { + ControlPlaneBackend::Management(endpoint) => fmt + .debug_tuple("ControlPlane::Management") + .field(&endpoint.url()) + .finish(), #[cfg(any(test, feature = "testing"))] - ConsoleBackend::Postgres(endpoint) => { - fmt.debug_tuple("Postgres").field(&endpoint.url()).finish() - } + ControlPlaneBackend::PostgresMock(endpoint) => fmt + .debug_tuple("ControlPlane::PostgresMock") + .field(&endpoint.url()) + .finish(), #[cfg(test)] - ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), + ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(), }, - Self::Web(url, ()) => fmt.debug_tuple("Web").field(&url.as_str()).finish(), + Self::ConsoleRedirect(backend, ()) => fmt + .debug_tuple("ConsoleRedirect") + .field(&backend.url().as_str()) + .finish(), Self::Local(_) => fmt.debug_tuple("Local").finish(), } } @@ -115,8 +121,8 @@ impl Backend<'_, T, D> { /// This helps us pass structured config to async tasks. pub(crate) fn as_ref(&self) -> Backend<'_, &T, &D> { match self { - Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x), - Self::Web(c, x) => Backend::Web(MaybeOwned::Borrowed(c), x), + Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x), + Self::ConsoleRedirect(c, x) => Backend::ConsoleRedirect(MaybeOwned::Borrowed(c), x), Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)), } } @@ -128,8 +134,8 @@ impl<'a, T, D> Backend<'a, T, D> { /// a function to a contained value. pub(crate) fn map(self, f: impl FnOnce(T) -> R) -> Backend<'a, R, D> { match self { - Self::Console(c, x) => Backend::Console(c, f(x)), - Self::Web(c, x) => Backend::Web(c, x), + Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)), + Self::ConsoleRedirect(c, x) => Backend::ConsoleRedirect(c, x), Self::Local(l) => Backend::Local(l), } } @@ -139,8 +145,8 @@ impl<'a, T, D, E> Backend<'a, Result, D> { /// This is most useful for error handling. pub(crate) fn transpose(self) -> Result, E> { match self { - Self::Console(c, x) => x.map(|x| Backend::Console(c, x)), - Self::Web(c, x) => Ok(Backend::Web(c, x)), + Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)), + Self::ConsoleRedirect(c, x) => Ok(Backend::ConsoleRedirect(c, x)), Self::Local(l) => Ok(Backend::Local(l)), } } @@ -170,10 +176,12 @@ impl ComputeUserInfo { } } +#[cfg_attr(test, derive(Debug))] pub(crate) enum ComputeCredentialKeys { #[cfg(any(test, feature = "testing"))] Password(Vec), AuthKeys(AuthKeys), + JwtPayload(Vec), None, } @@ -234,7 +242,6 @@ impl AuthenticationConfig { pub(crate) fn check_rate_limit( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, secret: AuthSecret, endpoint: &EndpointId, is_cleartext: bool, @@ -258,7 +265,7 @@ impl AuthenticationConfig { let limit_not_exceeded = self.rate_limiter.check( ( endpoint_int, - MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet), + MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), ), password_weight, ); @@ -290,7 +297,7 @@ impl AuthenticationConfig { /// All authentication flows will emit an AuthenticationOk message if successful. async fn auth_quirks( ctx: &RequestMonitoring, - api: &impl console::Api, + api: &impl control_plane::Api, user_info: ComputeUserInfoMaybeEndpoint, client: &mut stream::PqStream>, allow_cleartext: bool, @@ -332,7 +339,6 @@ async fn auth_quirks( let secret = if let Some(secret) = secret { config.check_rate_limit( ctx, - config, secret, &info.endpoint, unauthenticated_password.is_some() || allow_cleartext, @@ -412,8 +418,8 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> { /// Get username from the credentials. pub(crate) fn get_user(&self) -> &str { match self { - Self::Console(_, user_info) => &user_info.user, - Self::Web(_, ()) => "web", + Self::ControlPlane(_, user_info) => &user_info.user, + Self::ConsoleRedirect(_, ()) => "web", Self::Local(_) => "local", } } @@ -429,7 +435,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> { endpoint_rate_limiter: Arc, ) -> auth::Result> { let res = match self { - Self::Console(api, user_info) => { + Self::ControlPlane(api, user_info) => { info!( user = &*user_info.user, project = user_info.endpoint(), @@ -446,15 +452,15 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> { endpoint_rate_limiter, ) .await?; - Backend::Console(api, credentials) + Backend::ControlPlane(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Self::Web(url, ()) => { + Self::ConsoleRedirect(backend, ()) => { info!("performing web authentication"); - let info = web::authenticate(ctx, config, &url, client).await?; + let info = backend.authenticate(ctx, config, client).await?; - Backend::Web(url, info) + Backend::ConsoleRedirect(backend, info) } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) @@ -472,8 +478,8 @@ impl Backend<'_, ComputeUserInfo, &()> { ctx: &RequestMonitoring, ) -> Result { match self { - Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await, - Self::Web(_, ()) => Ok(Cached::new_uncached(None)), + Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await, + Self::ConsoleRedirect(_, ()) => Ok(Cached::new_uncached(None)), Self::Local(_) => Ok(Cached::new_uncached(None)), } } @@ -483,8 +489,10 @@ impl Backend<'_, ComputeUserInfo, &()> { ctx: &RequestMonitoring, ) -> Result<(CachedAllowedIps, Option), GetAuthInfoError> { match self { - Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, - Self::Web(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + Self::ControlPlane(api, user_info) => { + api.get_allowed_ips_and_secret(ctx, user_info).await + } + Self::ConsoleRedirect(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), } } @@ -495,18 +503,18 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials, NodeInfo> { async fn wake_compute( &self, ctx: &RequestMonitoring, - ) -> Result { + ) -> Result { match self { - Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await, - Self::Web(_, info) => Ok(Cached::new_uncached(info.clone())), + Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await, + Self::ConsoleRedirect(_, info) => Ok(Cached::new_uncached(info.clone())), Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } fn get_keys(&self) -> &ComputeCredentialKeys { match self { - Self::Console(_, creds) => &creds.keys, - Self::Web(_, _) => &ComputeCredentialKeys::None, + Self::ControlPlane(_, creds) => &creds.keys, + Self::ConsoleRedirect(_, _) => &ComputeCredentialKeys::None, Self::Local(_) => &ComputeCredentialKeys::None, } } @@ -517,10 +525,10 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> { async fn wake_compute( &self, ctx: &RequestMonitoring, - ) -> Result { + ) -> Result { match self { - Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await, - Self::Web(_, ()) => { + Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await, + Self::ConsoleRedirect(_, ()) => { unreachable!("web auth flow doesn't support waking the compute") } Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), @@ -529,8 +537,8 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> { fn get_keys(&self) -> &ComputeCredentialKeys { match self { - Self::Console(_, creds) => &creds.keys, - Self::Web(_, ()) => &ComputeCredentialKeys::None, + Self::ControlPlane(_, creds) => &creds.keys, + Self::ConsoleRedirect(_, ()) => &ComputeCredentialKeys::None, Self::Local(_) => &ComputeCredentialKeys::None, } } @@ -553,12 +561,12 @@ mod tests { use crate::{ auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern}, config::AuthenticationConfig, - console::{ + context::RequestMonitoring, + control_plane::{ self, provider::{self, CachedAllowedIps, CachedRoleSecret}, CachedNodeInfo, }, - context::RequestMonitoring, proxy::NeonOptions, rate_limiter::{EndpointRateLimiter, RateBucketInfo}, scram::{threadpool::ThreadPool, ServerSecret}, @@ -572,12 +580,12 @@ mod tests { secret: AuthSecret, } - impl console::Api for Auth { + impl control_plane::Api for Auth { async fn get_role_secret( &self, _ctx: &RequestMonitoring, _user_info: &super::ComputeUserInfo, - ) -> Result { + ) -> Result { Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) } @@ -585,8 +593,10 @@ mod tests { &self, _ctx: &RequestMonitoring, _user_info: &super::ComputeUserInfo, - ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError> - { + ) -> Result< + (CachedAllowedIps, Option), + control_plane::errors::GetAuthInfoError, + > { Ok(( CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())), Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))), @@ -605,7 +615,7 @@ mod tests { &self, _ctx: &RequestMonitoring, _user_info: &super::ComputeUserInfo, - ) -> Result { + ) -> Result { unimplemented!() } } diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index f7e2b5296e..9a5139dfb8 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -3,8 +3,8 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload}; use crate::{ config::TlsServerEndPoint, - console::AuthSecret, context::RequestMonitoring, + control_plane::AuthSecret, intern::EndpointIdInt, sasl, scram::{self, threadpool::ThreadPool}, diff --git a/proxy/src/auth.rs b/proxy/src/auth/mod.rs similarity index 98% rename from proxy/src/auth.rs rename to proxy/src/auth/mod.rs index 13639af3aa..0c8686add2 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth/mod.rs @@ -18,7 +18,7 @@ pub(crate) use flow::*; use tokio::time::error::Elapsed; use crate::{ - console, + control_plane, error::{ReportableError, UserFacingError}, }; use std::{io, net::IpAddr}; @@ -34,7 +34,7 @@ pub(crate) enum AuthErrorImpl { Web(#[from] backend::WebAuthError), #[error(transparent)] - GetAuthInfo(#[from] console::errors::GetAuthInfoError), + GetAuthInfo(#[from] control_plane::errors::GetAuthInfoError), /// SASL protocol errors (includes [SCRAM](crate::scram)). #[error(transparent)] diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index d5ce1e9273..c781af846a 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -6,13 +6,16 @@ use compute_api::spec::LocalProxySpec; use dashmap::DashMap; use futures::future::Either; use proxy::{ - auth::backend::{ - jwt::JwkCache, - local::{LocalBackend, JWKS_ROLE_MAP}, + auth::{ + self, + backend::{ + jwt::JwkCache, + local::{LocalBackend, JWKS_ROLE_MAP}, + }, }, cancellation::CancellationHandlerMain, config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig}, - console::{ + control_plane::{ locks::ApiLocks, messages::{EndpointJwksResponse, JwksSettings}, }, @@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> { let args = LocalProxyCliArgs::parse(); let config = build_config(&args)?; + let auth_backend = build_auth_backend(&args)?; // before we bind to any ports, write the process ID to a file // so that compute-ctl can find our process later @@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> { let task = serverless::task_main( config, + auth_backend, http_listener, shutdown.clone(), Arc::new(CancellationHandlerMain::new( @@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig Ok(Box::leak(Box::new(ProxyConfig { tls_config: None, - auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( - LocalBackend::new(args.compute), - )), metric_collection: None, allow_self_signed_compute: false, http_config, @@ -286,6 +288,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig }))) } +/// auth::Backend is created at proxy startup, and lives forever. +fn build_auth_backend( + args: &LocalProxyCliArgs, +) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> { + let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( + LocalBackend::new(args.compute), + )); + + Ok(Box::leak(Box::new(auth_backend))) +} + async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc) { loop { rx.notified().await; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 0585902c3b..3f4c2df809 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -10,6 +10,7 @@ use futures::future::Either; use proxy::auth; use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::AuthRateLimiter; +use proxy::auth::backend::ConsoleRedirectBackend; use proxy::auth::backend::MaybeOwned; use proxy::cancellation::CancelMap; use proxy::cancellation::CancellationHandler; @@ -19,8 +20,8 @@ use proxy::config::CacheOptions; use proxy::config::HttpConfig; use proxy::config::ProjectInfoCacheOptions; use proxy::config::ProxyProtocolV2; -use proxy::console; use proxy::context::parquet::ParquetUploadArgs; +use proxy::control_plane; use proxy::http; use proxy::http::health_server::AppMetrics; use proxy::metrics::Metrics; @@ -311,8 +312,9 @@ async fn main() -> anyhow::Result<()> { let args = ProxyCliArgs::parse(); let config = build_config(&args)?; + let auth_backend = build_auth_backend(&args)?; - info!("Authentication backend: {}", config.auth_backend); + info!("Authentication backend: {}", auth_backend); info!("Using region: {}", args.aws_region); let region_provider = @@ -462,6 +464,7 @@ async fn main() -> anyhow::Result<()> { if let Some(proxy_listener) = proxy_listener { client_tasks.spawn(proxy::proxy::task_main( config, + auth_backend, proxy_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -472,6 +475,7 @@ async fn main() -> anyhow::Result<()> { if let Some(serverless_listener) = serverless_listener { client_tasks.spawn(serverless::task_main( config, + auth_backend, serverless_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -495,7 +499,7 @@ async fn main() -> anyhow::Result<()> { proxy: proxy::metrics::Metrics::get(), }, )); - maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener)); + maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener)); if let Some(metrics_config) = &config.metric_collection { // TODO: Add gc regardles of the metric collection being enabled. @@ -506,8 +510,8 @@ async fn main() -> anyhow::Result<()> { )); } - if let auth::Backend::Console(api, _) = &config.auth_backend { - if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { + if let auth::Backend::ControlPlane(api, _) = auth_backend { + if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api { match (redis_notifications_client, regional_redis_client.clone()) { (None, None) => {} (client1, client2) => { @@ -610,73 +614,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { bail!("dynamic rate limiter should be disabled"); } - let auth_backend = match &args.auth_backend { - AuthBackendType::Console => { - let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; - let project_info_cache_config: ProjectInfoCacheOptions = - args.project_info_cache.parse()?; - let endpoint_cache_config: config::EndpointCacheConfig = - args.endpoint_cache_config.parse()?; - - info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}"); - info!( - "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}" - ); - info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}"); - let caches = Box::leak(Box::new(console::caches::ApiCaches::new( - wake_compute_cache_config, - project_info_cache_config, - endpoint_cache_config, - ))); - - let config::ConcurrencyLockOptions { - shards, - limiter, - epoch, - timeout, - } = args.wake_compute_lock.parse()?; - info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)"); - let locks = Box::leak(Box::new(console::locks::ApiLocks::new( - "wake_compute_lock", - limiter, - shards, - timeout, - epoch, - &Metrics::get().wake_compute_lock, - )?)); - tokio::spawn(locks.garbage_collect_worker()); - - let url = args.auth_endpoint.parse()?; - let endpoint = http::Endpoint::new(url, http::new_client()); - - let mut wake_compute_rps_limit = args.wake_compute_limit.clone(); - RateBucketInfo::validate(&mut wake_compute_rps_limit)?; - let wake_compute_endpoint_rate_limiter = - Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit)); - let api = console::provider::neon::Api::new( - endpoint, - caches, - locks, - wake_compute_endpoint_rate_limiter, - ); - let api = console::provider::ConsoleBackend::Console(api); - auth::Backend::Console(MaybeOwned::Owned(api), ()) - } - - AuthBackendType::Web => { - let url = args.uri.parse()?; - auth::Backend::Web(MaybeOwned::Owned(url), ()) - } - - #[cfg(feature = "testing")] - AuthBackendType::Postgres => { - let url = args.auth_endpoint.parse()?; - let api = console::provider::mock::Api::new(url, !args.is_private_access_proxy); - let api = console::provider::ConsoleBackend::Postgres(api); - auth::Backend::Console(MaybeOwned::Owned(api), ()) - } - }; - let config::ConcurrencyLockOptions { shards, limiter, @@ -689,7 +626,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { ?epoch, "Using NodeLocks (connect_compute)" ); - let connect_compute_locks = console::locks::ApiLocks::new( + let connect_compute_locks = control_plane::locks::ApiLocks::new( "connect_compute_lock", limiter, shards, @@ -728,7 +665,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let config = Box::leak(Box::new(ProxyConfig { tls_config, - auth_backend, metric_collection, allow_self_signed_compute: args.allow_self_signed_compute, http_config, @@ -748,6 +684,80 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { Ok(config) } +/// auth::Backend is created at proxy startup, and lives forever. +fn build_auth_backend( + args: &ProxyCliArgs, +) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> { + let auth_backend = match &args.auth_backend { + AuthBackendType::Console => { + let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; + let project_info_cache_config: ProjectInfoCacheOptions = + args.project_info_cache.parse()?; + let endpoint_cache_config: config::EndpointCacheConfig = + args.endpoint_cache_config.parse()?; + + info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}"); + info!( + "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}" + ); + info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}"); + let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new( + wake_compute_cache_config, + project_info_cache_config, + endpoint_cache_config, + ))); + + let config::ConcurrencyLockOptions { + shards, + limiter, + epoch, + timeout, + } = args.wake_compute_lock.parse()?; + info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)"); + let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new( + "wake_compute_lock", + limiter, + shards, + timeout, + epoch, + &Metrics::get().wake_compute_lock, + )?)); + tokio::spawn(locks.garbage_collect_worker()); + + let url = args.auth_endpoint.parse()?; + let endpoint = http::Endpoint::new(url, http::new_client()); + + let mut wake_compute_rps_limit = args.wake_compute_limit.clone(); + RateBucketInfo::validate(&mut wake_compute_rps_limit)?; + let wake_compute_endpoint_rate_limiter = + Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit)); + let api = control_plane::provider::neon::Api::new( + endpoint, + caches, + locks, + wake_compute_endpoint_rate_limiter, + ); + let api = control_plane::provider::ControlPlaneBackend::Management(api); + auth::Backend::ControlPlane(MaybeOwned::Owned(api), ()) + } + + AuthBackendType::Web => { + let url = args.uri.parse()?; + auth::Backend::ConsoleRedirect(MaybeOwned::Owned(ConsoleRedirectBackend::new(url)), ()) + } + + #[cfg(feature = "testing")] + AuthBackendType::Postgres => { + let url = args.auth_endpoint.parse()?; + let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy); + let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api); + auth::Backend::ControlPlane(MaybeOwned::Owned(api), ()) + } + }; + + Ok(Box::leak(Box::new(auth_backend))) +} + #[cfg(test)] mod tests { use std::time::Duration; diff --git a/proxy/src/cache.rs b/proxy/src/cache/mod.rs similarity index 100% rename from proxy/src/cache.rs rename to proxy/src/cache/mod.rs diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index ceae74a9a0..b92cedb043 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -16,7 +16,7 @@ use tracing::{debug, info}; use crate::{ auth::IpPattern, config::ProjectInfoCacheOptions, - console::AuthSecret, + control_plane::AuthSecret, intern::{EndpointIdInt, ProjectIdInt, RoleNameInt}, EndpointId, RoleName, }; diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index e3d9ae530a..006804fcd4 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,8 +1,8 @@ use crate::{ auth::parse_endpoint_param, cancellation::CancelClosure, - console::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError}, context::RequestMonitoring, + control_plane::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError}, error::{ReportableError, UserFacingError}, metrics::{Metrics, NumDbConnectionsGuard}, proxy::neon_option, diff --git a/proxy/src/config.rs b/proxy/src/config.rs index e0d666adf7..c068fc50fb 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,9 +1,6 @@ use crate::{ - auth::{ - self, - backend::{jwt::JwkCache, AuthRateLimiter}, - }, - console::locks::ApiLocks, + auth::backend::{jwt::JwkCache, AuthRateLimiter}, + control_plane::locks::ApiLocks, rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}, scram::threadpool::ThreadPool, serverless::{cancel_set::CancelSet, GlobalConnPoolOptions}, @@ -29,7 +26,6 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::Backend<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, @@ -372,7 +368,7 @@ pub struct EndpointCacheConfig { } impl EndpointCacheConfig { - /// Default options for [`crate::console::provider::NodeInfoCache`]. + /// Default options for [`crate::control_plane::provider::NodeInfoCache`]. /// Notice that by default the limiter is empty, which means that cache is disabled. pub const CACHE_DEFAULT_OPTIONS: &'static str = "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s"; @@ -447,7 +443,7 @@ pub struct CacheOptions { } impl CacheOptions { - /// Default options for [`crate::console::provider::NodeInfoCache`]. + /// Default options for [`crate::control_plane::provider::NodeInfoCache`]. pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m"; /// Parse cache options passed via cmdline. @@ -503,7 +499,7 @@ pub struct ProjectInfoCacheOptions { } impl ProjectInfoCacheOptions { - /// Default options for [`crate::console::provider::NodeInfoCache`]. + /// Default options for [`crate::control_plane::provider::NodeInfoCache`]. pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=10000,ttl=4m,max_roles=10,gc_interval=60m"; @@ -622,9 +618,9 @@ pub struct ConcurrencyLockOptions { } impl ConcurrencyLockOptions { - /// Default options for [`crate::console::provider::ApiLocks`]. + /// Default options for [`crate::control_plane::provider::ApiLocks`]. pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0"; - /// Default options for [`crate::console::provider::ApiLocks`]. + /// Default options for [`crate::control_plane::provider::ApiLocks`]. pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str = "shards=64,permits=100,epoch=10m,timeout=10ms"; diff --git a/proxy/src/context.rs b/proxy/src/context/mod.rs similarity index 99% rename from proxy/src/context.rs rename to proxy/src/context/mod.rs index 021659e175..7fb4e7c698 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context/mod.rs @@ -11,7 +11,7 @@ use try_lock::TryLock; use uuid::Uuid; use crate::{ - console::messages::{ColdStartInfo, MetricsAuxInfo}, + control_plane::messages::{ColdStartInfo, MetricsAuxInfo}, error::ErrorKind, intern::{BranchIdInt, ProjectIdInt}, metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting}, diff --git a/proxy/src/console/messages.rs b/proxy/src/control_plane/messages.rs similarity index 99% rename from proxy/src/console/messages.rs rename to proxy/src/control_plane/messages.rs index 1696e229ce..960bb5bc21 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -10,14 +10,14 @@ use crate::proxy::retry::CouldRetry; /// Generic error response with human-readable description. /// Note that we can't always present it to user as is. #[derive(Debug, Deserialize, Clone)] -pub(crate) struct ConsoleError { +pub(crate) struct ControlPlaneError { pub(crate) error: Box, #[serde(skip)] pub(crate) http_status_code: http::StatusCode, pub(crate) status: Option, } -impl ConsoleError { +impl ControlPlaneError { pub(crate) fn get_reason(&self) -> Reason { self.status .as_ref() @@ -51,7 +51,7 @@ impl ConsoleError { } } -impl Display for ConsoleError { +impl Display for ControlPlaneError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let msg: &str = self .status @@ -62,7 +62,7 @@ impl Display for ConsoleError { } } -impl CouldRetry for ConsoleError { +impl CouldRetry for ControlPlaneError { fn could_retry(&self) -> bool { // If the error message does not have a status, // the error is unknown and probably should not retry automatically diff --git a/proxy/src/console/mgmt.rs b/proxy/src/control_plane/mgmt.rs similarity index 98% rename from proxy/src/console/mgmt.rs rename to proxy/src/control_plane/mgmt.rs index ee5f83ee76..2c4b5a9b94 100644 --- a/proxy/src/console/mgmt.rs +++ b/proxy/src/control_plane/mgmt.rs @@ -1,5 +1,5 @@ use crate::{ - console::messages::{DatabaseInfo, KickSession}, + control_plane::messages::{DatabaseInfo, KickSession}, waiters::{self, Waiter, Waiters}, }; use anyhow::Context; diff --git a/proxy/src/console.rs b/proxy/src/control_plane/mod.rs similarity index 100% rename from proxy/src/console.rs rename to proxy/src/control_plane/mod.rs diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/control_plane/provider/mock.rs similarity index 98% rename from proxy/src/console/provider/mock.rs rename to proxy/src/control_plane/provider/mock.rs index b548a0203a..ea2eb79e2a 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/control_plane/provider/mock.rs @@ -10,7 +10,7 @@ use crate::{ use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; use crate::{auth::IpPattern, cache::Cached}; use crate::{ - console::{ + control_plane::{ messages::MetricsAuxInfo, provider::{CachedAllowedIps, CachedRoleSecret}, }, @@ -166,7 +166,7 @@ impl Api { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), branch_id: (&BranchId::from("branch")).into(), - cold_start_info: crate::console::messages::ColdStartInfo::Warm, + cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm, }, allow_self_signed_compute: false, }; diff --git a/proxy/src/console/provider.rs b/proxy/src/control_plane/provider/mod.rs similarity index 90% rename from proxy/src/console/provider.rs rename to proxy/src/control_plane/provider/mod.rs index 95097f2de9..6cc525a324 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/control_plane/provider/mod.rs @@ -2,7 +2,7 @@ pub mod mock; pub mod neon; -use super::messages::{ConsoleError, MetricsAuxInfo}; +use super::messages::{ControlPlaneError, MetricsAuxInfo}; use crate::{ auth::{ backend::{ @@ -28,7 +28,7 @@ use tracing::info; pub(crate) mod errors { use crate::{ - console::messages::{self, ConsoleError, Reason}, + control_plane::messages::{self, ControlPlaneError, Reason}, error::{io_error, ErrorKind, ReportableError, UserFacingError}, proxy::retry::CouldRetry, }; @@ -44,7 +44,7 @@ pub(crate) mod errors { pub(crate) enum ApiError { /// Error returned by the console itself. #[error("{REQUEST_FAILED} with {0}")] - Console(ConsoleError), + ControlPlane(ControlPlaneError), /// Various IO errors like broken pipe or malformed payload. #[error("{REQUEST_FAILED}: {0}")] @@ -55,7 +55,7 @@ pub(crate) mod errors { /// Returns HTTP status code if it's the reason for failure. pub(crate) fn get_reason(&self) -> messages::Reason { match self { - ApiError::Console(e) => e.get_reason(), + ApiError::ControlPlane(e) => e.get_reason(), ApiError::Transport(_) => messages::Reason::Unknown, } } @@ -65,7 +65,7 @@ pub(crate) mod errors { fn to_string_client(&self) -> String { match self { // To minimize risks, only select errors are forwarded to users. - ApiError::Console(c) => c.get_user_facing_message(), + ApiError::ControlPlane(c) => c.get_user_facing_message(), ApiError::Transport(_) => REQUEST_FAILED.to_owned(), } } @@ -74,51 +74,51 @@ pub(crate) mod errors { impl ReportableError for ApiError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - ApiError::Console(e) => match e.get_reason() { + ApiError::ControlPlane(e) => match e.get_reason() { Reason::RoleProtected => ErrorKind::User, Reason::ResourceNotFound => ErrorKind::User, Reason::ProjectNotFound => ErrorKind::User, Reason::EndpointNotFound => ErrorKind::User, Reason::BranchNotFound => ErrorKind::User, Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit, - Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User, - Reason::ActiveTimeQuotaExceeded => ErrorKind::User, - Reason::ComputeTimeQuotaExceeded => ErrorKind::User, - Reason::WrittenDataQuotaExceeded => ErrorKind::User, - Reason::DataTransferQuotaExceeded => ErrorKind::User, - Reason::LogicalSizeQuotaExceeded => ErrorKind::User, + Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota, + Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota, + Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota, + Reason::WrittenDataQuotaExceeded => ErrorKind::Quota, + Reason::DataTransferQuotaExceeded => ErrorKind::Quota, + Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota, Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane, Reason::LockAlreadyTaken => ErrorKind::ControlPlane, Reason::RunningOperations => ErrorKind::ControlPlane, Reason::Unknown => match &e { - ConsoleError { + ControlPlaneError { http_status_code: http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE, .. } => crate::error::ErrorKind::User, - ConsoleError { + ControlPlaneError { http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY, error, .. } if error .contains("compute time quota of non-primary branches is exceeded") => { - crate::error::ErrorKind::User + crate::error::ErrorKind::Quota } - ConsoleError { + ControlPlaneError { http_status_code: http::StatusCode::LOCKED, error, .. } if error.contains("quota exceeded") || error.contains("the limit for current plan reached") => { - crate::error::ErrorKind::User + crate::error::ErrorKind::Quota } - ConsoleError { + ControlPlaneError { http_status_code: http::StatusCode::TOO_MANY_REQUESTS, .. } => crate::error::ErrorKind::ServiceRateLimit, - ConsoleError { .. } => crate::error::ErrorKind::ControlPlane, + ControlPlaneError { .. } => crate::error::ErrorKind::ControlPlane, }, }, ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane, @@ -131,7 +131,7 @@ pub(crate) mod errors { match self { // retry some transport errors Self::Transport(io) => io.could_retry(), - Self::Console(e) => e.could_retry(), + Self::ControlPlane(e) => e.could_retry(), } } } @@ -309,12 +309,13 @@ impl NodeInfo { #[cfg(any(test, feature = "testing"))] ComputeCredentialKeys::Password(password) => self.config.password(password), ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), - ComputeCredentialKeys::None => &mut self.config, + ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config, }; } } -pub(crate) type NodeInfoCache = TimedLru>>; +pub(crate) type NodeInfoCache = + TimedLru>>; pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; @@ -353,28 +354,28 @@ pub(crate) trait Api { #[non_exhaustive] #[derive(Clone)] -pub enum ConsoleBackend { - /// Current Cloud API (V2). - Console(neon::Api), - /// Local mock of Cloud API (V2). +pub enum ControlPlaneBackend { + /// Current Management API (V2). + Management(neon::Api), + /// Local mock control plane. #[cfg(any(test, feature = "testing"))] - Postgres(mock::Api), + PostgresMock(mock::Api), /// Internal testing #[cfg(test)] #[allow(private_interfaces)] Test(Box), } -impl Api for ConsoleBackend { +impl Api for ControlPlaneBackend { async fn get_role_secret( &self, ctx: &RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { match self { - Self::Console(api) => api.get_role_secret(ctx, user_info).await, + Self::Management(api) => api.get_role_secret(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] - Self::Postgres(api) => api.get_role_secret(ctx, user_info).await, + Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await, #[cfg(test)] Self::Test(_) => { unreachable!("this function should never be called in the test backend") @@ -388,9 +389,9 @@ impl Api for ConsoleBackend { user_info: &ComputeUserInfo, ) -> Result<(CachedAllowedIps, Option), errors::GetAuthInfoError> { match self { - Self::Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, + Self::Management(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] - Self::Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, + Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await, #[cfg(test)] Self::Test(api) => api.get_allowed_ips_and_secret(), } @@ -402,9 +403,9 @@ impl Api for ConsoleBackend { endpoint: EndpointId, ) -> anyhow::Result> { match self { - Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await, + Self::Management(api) => api.get_endpoint_jwks(ctx, endpoint).await, #[cfg(any(test, feature = "testing"))] - Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await, + Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await, #[cfg(test)] Self::Test(_api) => Ok(vec![]), } @@ -416,16 +417,16 @@ impl Api for ConsoleBackend { user_info: &ComputeUserInfo, ) -> Result { match self { - Self::Console(api) => api.wake_compute(ctx, user_info).await, + Self::Management(api) => api.wake_compute(ctx, user_info).await, #[cfg(any(test, feature = "testing"))] - Self::Postgres(api) => api.wake_compute(ctx, user_info).await, + Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await, #[cfg(test)] Self::Test(api) => api.wake_compute(), } } } -/// Various caches for [`console`](super). +/// Various caches for [`control_plane`](super). pub struct ApiCaches { /// Cache for the `wake_compute` API method. pub(crate) node_info: NodeInfoCache, @@ -454,7 +455,7 @@ impl ApiCaches { } } -/// Various caches for [`console`](super). +/// Various caches for [`control_plane`](super). pub struct ApiLocks { name: &'static str, node_locks: DashMap>, @@ -577,7 +578,7 @@ impl WakeComputePermit { } } -impl FetchAuthRules for ConsoleBackend { +impl FetchAuthRules for ControlPlaneBackend { async fn fetch_auth_rules( &self, ctx: &RequestMonitoring, diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/control_plane/provider/neon.rs similarity index 97% rename from proxy/src/console/provider/neon.rs rename to proxy/src/control_plane/provider/neon.rs index 2d527f378c..d01878741c 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/control_plane/provider/neon.rs @@ -1,7 +1,7 @@ //! Production console backend. use super::{ - super::messages::{ConsoleError, GetRoleSecret, WakeCompute}, + super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute}, errors::{ApiError, GetAuthInfoError, WakeComputeError}, ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, @@ -9,7 +9,7 @@ use super::{ use crate::{ auth::backend::{jwt::AuthRule, ComputeUserInfo}, compute, - console::messages::{ColdStartInfo, EndpointJwksResponse, Reason}, + control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}, http, metrics::{CacheOutcome, Metrics}, rate_limiter::WakeComputeRateLimiter, @@ -22,7 +22,7 @@ use futures::TryFutureExt; use std::{sync::Arc, time::Duration}; use tokio::time::Instant; use tokio_postgres::config::SslMode; -use tracing::{debug, error, info, info_span, warn, Instrument}; +use tracing::{debug, info, info_span, warn, Instrument}; const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -348,7 +348,7 @@ impl super::Api for Api { let (cached, info) = cached.take_value(); let info = info.map_err(|c| { info!(key = &*key, "found cached wake_compute error"); - WakeComputeError::ApiError(ApiError::Console(*c)) + WakeComputeError::ApiError(ApiError::ControlPlane(*c)) })?; debug!(key = &*key, "found cached compute node info"); @@ -395,9 +395,9 @@ impl super::Api for Api { Ok(cached.map(|()| node)) } Err(err) => match err { - WakeComputeError::ApiError(ApiError::Console(err)) => { + WakeComputeError::ApiError(ApiError::ControlPlane(err)) => { let Some(status) = &err.status else { - return Err(WakeComputeError::ApiError(ApiError::Console(err))); + return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err))); }; let reason = status @@ -407,7 +407,7 @@ impl super::Api for Api { // if we can retry this error, do not cache it. if reason.can_retry() { - return Err(WakeComputeError::ApiError(ApiError::Console(err))); + return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err))); } // at this point, we should only have quota errors. @@ -422,7 +422,7 @@ impl super::Api for Api { Duration::from_secs(30), ); - Err(WakeComputeError::ApiError(ApiError::Console(err))) + Err(WakeComputeError::ApiError(ApiError::ControlPlane(err))) } err => return Err(err), }, @@ -448,7 +448,7 @@ async fn parse_body serde::Deserialize<'a>>( // as the fact that the request itself has failed. let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| { warn!("failed to parse error body: {e}"); - ConsoleError { + ControlPlaneError { error: "reason unclear (malformed error message)".into(), http_status_code: status, status: None, @@ -456,8 +456,8 @@ async fn parse_body serde::Deserialize<'a>>( }); body.http_status_code = status; - error!("console responded with an error ({status}): {body:?}"); - Err(ApiError::Console(body)) + warn!("console responded with an error ({status}): {body:?}"); + Err(ApiError::ControlPlane(body)) } fn parse_host_port(input: &str) -> Option<(&str, u16)> { diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 53f9f75c5b..1cd4dc2c22 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -49,6 +49,10 @@ pub enum ErrorKind { #[label(rename = "serviceratelimit")] ServiceRateLimit, + /// Proxy quota limit violation + #[label(rename = "quota")] + Quota, + /// internal errors Service, @@ -70,6 +74,7 @@ impl ErrorKind { ErrorKind::ClientDisconnect => "clientdisconnect", ErrorKind::RateLimit => "ratelimit", ErrorKind::ServiceRateLimit => "serviceratelimit", + ErrorKind::Quota => "quota", ErrorKind::Service => "service", ErrorKind::ControlPlane => "controlplane", ErrorKind::Postgres => "postgres", diff --git a/proxy/src/http/health_server.rs b/proxy/src/http/health_server.rs index cae9eb5b97..d0352351d5 100644 --- a/proxy/src/http/health_server.rs +++ b/proxy/src/http/health_server.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, bail}; -use hyper::{header::CONTENT_TYPE, Body, Request, Response, StatusCode}; +use hyper0::{header::CONTENT_TYPE, Body, Request, Response, StatusCode}; use measured::{text::BufferedTextEncoder, MetricGroup}; use metrics::NeonMetrics; use std::{ @@ -21,7 +21,7 @@ async fn status_handler(_: Request) -> Result, ApiError> { json_response(StatusCode::OK, "") } -fn make_router(metrics: AppMetrics) -> RouterBuilder { +fn make_router(metrics: AppMetrics) -> RouterBuilder { let state = Arc::new(Mutex::new(PrometheusHandler { encoder: BufferedTextEncoder::new(), metrics, @@ -45,7 +45,7 @@ pub async fn task_main( let service = || RouterService::new(make_router(metrics).build()?); - hyper::Server::from_tcp(http_listener)? + hyper0::Server::from_tcp(http_listener)? .serve(service().map_err(|e| anyhow!(e))?) .await?; diff --git a/proxy/src/http.rs b/proxy/src/http/mod.rs similarity index 99% rename from proxy/src/http.rs rename to proxy/src/http/mod.rs index 14720b5c6b..d8676d5b50 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http/mod.rs @@ -9,7 +9,7 @@ use std::time::Duration; use anyhow::bail; use bytes::Bytes; use http_body_util::BodyExt; -use hyper1::body::Body; +use hyper::body::Body; use serde::de::DeserializeOwned; pub(crate) use reqwest::{Request, Response}; diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 92faab6167..8d274baa10 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -90,15 +90,13 @@ use tokio::task::JoinError; use tokio_util::sync::CancellationToken; use tracing::warn; -extern crate hyper0 as hyper; - pub mod auth; pub mod cache; pub mod cancellation; pub mod compute; pub mod config; -pub mod console; pub mod context; +pub mod control_plane; pub mod error; pub mod http; pub mod intern; diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index c2567e083a..272723a1bc 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -11,7 +11,7 @@ use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec}; use tokio::time::{self, Instant}; -use crate::console::messages::ColdStartInfo; +use crate::control_plane::messages::ColdStartInfo; #[derive(MetricGroup)] #[metric(new(thread_pool: Arc))] diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 3b6c467589..aac7720890 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -3,8 +3,8 @@ use crate::{ compute::COULD_NOT_CONNECT, compute::{self, PostgresConnection}, config::RetryConfig, - console::{self, errors::WakeComputeError, locks::ApiLocks, CachedNodeInfo, NodeInfo}, context::RequestMonitoring, + control_plane::{self, errors::WakeComputeError, locks::ApiLocks, CachedNodeInfo, NodeInfo}, error::ReportableError, metrics::{ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType}, proxy::{ @@ -26,7 +26,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. #[tracing::instrument(name = "invalidate_cache", skip_all)] -pub(crate) fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { +pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -49,7 +49,7 @@ pub(crate) trait ConnectMechanism { async fn connect_once( &self, ctx: &RequestMonitoring, - node_info: &console::CachedNodeInfo, + node_info: &control_plane::CachedNodeInfo, timeout: time::Duration, ) -> Result; @@ -61,7 +61,7 @@ pub(crate) trait ComputeConnectBackend { async fn wake_compute( &self, ctx: &RequestMonitoring, - ) -> Result; + ) -> Result; fn get_keys(&self) -> &ComputeCredentialKeys; } @@ -84,7 +84,7 @@ impl ConnectMechanism for TcpMechanism<'_> { async fn connect_once( &self, ctx: &RequestMonitoring, - node_info: &console::CachedNodeInfo, + node_info: &control_plane::CachedNodeInfo, timeout: time::Duration, ) -> Result { let host = node_info.config.get_host()?; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy/mod.rs similarity index 95% rename from proxy/src/proxy.rs rename to proxy/src/proxy/mod.rs index 7003af2aba..3a43ccb74a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy/mod.rs @@ -35,7 +35,7 @@ use std::sync::Arc; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, Instrument}; +use tracing::{error, info, warn, Instrument}; use self::{ connect_compute::{connect_to_compute, TcpMechanism}, @@ -61,6 +61,7 @@ pub async fn run_until_cancelled( pub async fn task_main( config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, (), ()>, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -95,15 +96,15 @@ pub async fn task_main( connections.spawn(async move { let (socket, peer_addr) = match read_proxy_protocol(socket).await { Err(e) => { - error!("per-client task finished with an error: {e:#}"); + warn!("per-client task finished with an error: {e:#}"); return; } Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => { - error!("missing required proxy protocol header"); + warn!("missing required proxy protocol header"); return; } Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => { - error!("proxy protocol header not supported"); + warn!("proxy protocol header not supported"); return; } Ok((socket, Some(addr))) => (socket, addr.ip()), @@ -129,6 +130,7 @@ pub async fn task_main( let startup = Box::pin( handle_client( config, + auth_backend, &ctx, cancellation_handler, socket, @@ -144,7 +146,7 @@ pub async fn task_main( Err(e) => { // todo: log and push to ctx the error kind ctx.set_error_kind(e.get_error_kind()); - error!(parent: &span, "per-client task finished with an error: {e:#}"); + warn!(parent: &span, "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); @@ -155,7 +157,7 @@ pub async fn task_main( match p.proxy_pass().instrument(span.clone()).await { Ok(()) => {} Err(ErrorSource::Client(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); + warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); } Err(ErrorSource::Compute(e)) => { error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}"); @@ -243,8 +245,10 @@ impl ReportableError for ClientRequestError { } } +#[allow(clippy::too_many_arguments)] pub(crate) async fn handle_client( config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, (), ()>, ctx: &RequestMonitoring, cancellation_handler: Arc, stream: S, @@ -285,8 +289,7 @@ pub(crate) async fn handle_client( let common_names = tls.map(|tls| &tls.common_names); // Extract credentials which we're going to use for auth. - let result = config - .auth_backend + let result = auth_backend .as_ref() .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names)) .transpose(); diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index c17108de0a..497cf4bfd5 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,7 +1,7 @@ use crate::{ cancellation, compute::PostgresConnection, - console::messages::MetricsAuxInfo, + control_plane::messages::MetricsAuxInfo, metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}, stream::Stream, usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}, @@ -71,7 +71,7 @@ impl ProxyPassthrough { pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { - tracing::error!(?err, "could not cancel the query in the database"); + tracing::warn!(?err, "could not cancel the query in the database"); } res } diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests/mod.rs similarity index 94% rename from proxy/src/proxy/tests.rs rename to proxy/src/proxy/tests/mod.rs index 058ec06e02..3861ddc8ed 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -11,9 +11,11 @@ use crate::auth::backend::{ ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend, }; use crate::config::{CertResolver, RetryConfig}; -use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status}; -use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend, NodeInfoCache}; -use crate::console::{self, CachedNodeInfo, NodeInfo}; +use crate::control_plane::messages::{ControlPlaneError, Details, MetricsAuxInfo, Status}; +use crate::control_plane::provider::{ + CachedAllowedIps, CachedRoleSecret, ControlPlaneBackend, NodeInfoCache, +}; +use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::ErrorKind; use crate::{sasl, scram, BranchId, EndpointId, ProjectId}; use anyhow::{bail, Context}; @@ -459,7 +461,7 @@ impl ConnectMechanism for TestConnectMechanism { async fn connect_once( &self, _ctx: &RequestMonitoring, - _node_info: &console::CachedNodeInfo, + _node_info: &control_plane::CachedNodeInfo, _timeout: std::time::Duration, ) -> Result { let mut counter = self.counter.lock().unwrap(); @@ -483,23 +485,23 @@ impl ConnectMechanism for TestConnectMechanism { } impl TestBackend for TestConnectMechanism { - fn wake_compute(&self) -> Result { + fn wake_compute(&self) -> Result { let mut counter = self.counter.lock().unwrap(); let action = self.sequence[*counter]; *counter += 1; match action { ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)), ConnectAction::WakeFail => { - let err = console::errors::ApiError::Console(ConsoleError { + let err = control_plane::errors::ApiError::ControlPlane(ControlPlaneError { http_status_code: StatusCode::BAD_REQUEST, error: "TEST".into(), status: None, }); assert!(!err.could_retry()); - Err(console::errors::WakeComputeError::ApiError(err)) + Err(control_plane::errors::WakeComputeError::ApiError(err)) } ConnectAction::WakeRetry => { - let err = console::errors::ApiError::Console(ConsoleError { + let err = control_plane::errors::ApiError::ControlPlane(ControlPlaneError { http_status_code: StatusCode::BAD_REQUEST, error: "TEST".into(), status: Some(Status { @@ -507,13 +509,15 @@ impl TestBackend for TestConnectMechanism { message: "error".into(), details: Details { error_info: None, - retry_info: Some(console::messages::RetryInfo { retry_delay_ms: 1 }), + retry_info: Some(control_plane::messages::RetryInfo { + retry_delay_ms: 1, + }), user_facing_message: None, }, }), }); assert!(err.could_retry()); - Err(console::errors::WakeComputeError::ApiError(err)) + Err(control_plane::errors::WakeComputeError::ApiError(err)) } x => panic!("expecting action {x:?}, wake_compute is called instead"), } @@ -521,7 +525,7 @@ impl TestBackend for TestConnectMechanism { fn get_allowed_ips_and_secret( &self, - ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError> + ) -> Result<(CachedAllowedIps, Option), control_plane::errors::GetAuthInfoError> { unimplemented!("not used in tests") } @@ -538,7 +542,7 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), branch_id: (&BranchId::from("branch")).into(), - cold_start_info: crate::console::messages::ColdStartInfo::Warm, + cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm, }, allow_self_signed_compute: false, }; @@ -549,8 +553,8 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn fn helper_create_connect_info( mechanism: &TestConnectMechanism, ) -> auth::Backend<'static, ComputeCredentials, &()> { - let user_info = auth::Backend::Console( - MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), + let user_info = auth::Backend::ControlPlane( + MaybeOwned::Owned(ControlPlaneBackend::Test(Box::new(mechanism.clone()))), ComputeCredentials { info: ComputeUserInfo { endpoint: "endpoint".into(), diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 9b8ac6d29d..ba674f5d0d 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,13 +1,13 @@ use crate::config::RetryConfig; -use crate::console::messages::{ConsoleError, Reason}; -use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo}; use crate::context::RequestMonitoring; +use crate::control_plane::messages::{ControlPlaneError, Reason}; +use crate::control_plane::{errors::WakeComputeError, provider::CachedNodeInfo}; use crate::metrics::{ ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType, WakeupFailureKind, }; use crate::proxy::retry::{retry_after, should_retry}; -use hyper1::StatusCode; +use hyper::StatusCode; use tracing::{error, info, warn}; use super::connect_compute::ComputeConnectBackend; @@ -59,11 +59,11 @@ pub(crate) async fn wake_compute( } fn report_error(e: &WakeComputeError, retry: bool) { - use crate::console::errors::ApiError; + use crate::control_plane::errors::ApiError; let kind = match e { WakeComputeError::BadComputeAddress(_) => WakeupFailureKind::BadComputeAddress, WakeComputeError::ApiError(ApiError::Transport(_)) => WakeupFailureKind::ApiTransportError, - WakeComputeError::ApiError(ApiError::Console(e)) => match e.get_reason() { + WakeComputeError::ApiError(ApiError::ControlPlane(e)) => match e.get_reason() { Reason::RoleProtected => WakeupFailureKind::ApiConsoleBadRequest, Reason::ResourceNotFound => WakeupFailureKind::ApiConsoleBadRequest, Reason::ProjectNotFound => WakeupFailureKind::ApiConsoleBadRequest, @@ -80,7 +80,7 @@ fn report_error(e: &WakeComputeError, retry: bool) { Reason::LockAlreadyTaken => WakeupFailureKind::ApiConsoleLocked, Reason::RunningOperations => WakeupFailureKind::ApiConsoleLocked, Reason::Unknown => match e { - ConsoleError { + ControlPlaneError { http_status_code: StatusCode::LOCKED, ref error, .. @@ -89,27 +89,27 @@ fn report_error(e: &WakeComputeError, retry: bool) { { WakeupFailureKind::QuotaExceeded } - ConsoleError { + ControlPlaneError { http_status_code: StatusCode::UNPROCESSABLE_ENTITY, ref error, .. } if error.contains("compute time quota of non-primary branches is exceeded") => { WakeupFailureKind::QuotaExceeded } - ConsoleError { + ControlPlaneError { http_status_code: StatusCode::LOCKED, .. } => WakeupFailureKind::ApiConsoleLocked, - ConsoleError { + ControlPlaneError { http_status_code: StatusCode::BAD_REQUEST, .. } => WakeupFailureKind::ApiConsoleBadRequest, - ConsoleError { + ControlPlaneError { http_status_code, .. } if http_status_code.is_server_error() => { WakeupFailureKind::ApiConsoleOtherServerError } - ConsoleError { .. } => WakeupFailureKind::ApiConsoleOtherError, + ControlPlaneError { .. } => WakeupFailureKind::ApiConsoleOtherError, }, }, WakeComputeError::TooManyConnections => WakeupFailureKind::ApiConsoleLocked, diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter/mod.rs similarity index 100% rename from proxy/src/rate_limiter.rs rename to proxy/src/rate_limiter/mod.rs diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index 2de66b58b1..ccd48f1481 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -6,7 +6,7 @@ use redis::{ ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult, }; use tokio::task::JoinHandle; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use super::elasticache::CredentialsProvider; @@ -89,7 +89,7 @@ impl ConnectionWithCredentialsProvider { return Ok(()); } Err(e) => { - error!("Error during PING: {e:?}"); + warn!("Error during PING: {e:?}"); } } } else { @@ -121,7 +121,7 @@ impl ConnectionWithCredentialsProvider { info!("Connection succesfully established"); } Err(e) => { - error!("Connection is broken. Error during PING: {e:?}"); + warn!("Connection is broken. Error during PING: {e:?}"); } } self.con = Some(con); diff --git a/proxy/src/redis.rs b/proxy/src/redis/mod.rs similarity index 100% rename from proxy/src/redis.rs rename to proxy/src/redis/mod.rs diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 36a3443603..c3af6740cb 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -146,7 +146,7 @@ impl MessageHandler { { Ok(()) => {} Err(e) => { - tracing::error!("failed to cancel session: {e}"); + tracing::warn!("failed to cancel session: {e}"); } } } diff --git a/proxy/src/sasl.rs b/proxy/src/sasl/mod.rs similarity index 100% rename from proxy/src/sasl.rs rename to proxy/src/sasl/mod.rs diff --git a/proxy/src/scram.rs b/proxy/src/scram/mod.rs similarity index 100% rename from proxy/src/scram.rs rename to proxy/src/scram/mod.rs diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 89eeec3e6f..9e49478cf3 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -3,22 +3,24 @@ use std::{io, sync::Arc, time::Duration}; use async_trait::async_trait; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use tokio::net::{lookup_host, TcpStream}; -use tracing::{field::display, info}; +use tokio_postgres::types::ToSql; +use tracing::{debug, field::display, info}; use crate::{ auth::{ + self, backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo}, check_peer_addr_is_in_list, AuthError, }, compute, - config::{AuthenticationConfig, ProxyConfig}, - console::{ + config::ProxyConfig, + context::RequestMonitoring, + control_plane::{ errors::{GetAuthInfoError, WakeComputeError}, locks::ApiLocks, provider::ApiLockError, CachedNodeInfo, }, - context::RequestMonitoring, error::{ErrorKind, ReportableError, UserFacingError}, intern::EndpointIdInt, proxy::{ @@ -26,18 +28,21 @@ use crate::{ retry::{CouldRetry, ShouldRetryWakeCompute}, }, rate_limiter::EndpointRateLimiter, - Host, + EndpointId, Host, }; use super::{ conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}, http_conn_pool::{self, poll_http2_client}, + local_conn_pool::{self, LocalClient, LocalConnPool}, }; pub(crate) struct PoolingBackend { pub(crate) http_conn_pool: Arc, + pub(crate) local_pool: Arc>, pub(crate) pool: Arc>, pub(crate) config: &'static ProxyConfig, + pub(crate) auth_backend: &'static crate::auth::Backend<'static, (), ()>, pub(crate) endpoint_rate_limiter: Arc, } @@ -45,18 +50,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_password( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, user_info: &ComputeUserInfo, password: &[u8], ) -> Result { let user_info = user_info.clone(); - let backend = self - .config - .auth_backend - .as_ref() - .map(|()| user_info.clone()); + let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; - if config.ip_allowlist_check_enabled + if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { return Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); @@ -75,7 +75,6 @@ impl PoolingBackend { let secret = match cached_secret.value.clone() { Some(secret) => self.config.authentication_config.check_rate_limit( ctx, - config, secret, &user_info.endpoint, true, @@ -87,9 +86,13 @@ impl PoolingBackend { } }; let ep = EndpointIdInt::from(&user_info.endpoint); - let auth_outcome = - crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret) - .await?; + let auth_outcome = crate::auth::validate_password_and_exchange( + &self.config.authentication_config.thread_pool, + ep, + password, + secret, + ) + .await?; let res = match auth_outcome { crate::sasl::Outcome::Success(key) => { info!("user successfully authenticated"); @@ -109,13 +112,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_jwt( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, user_info: &ComputeUserInfo, jwt: String, - ) -> Result<(), AuthError> { - match &self.config.auth_backend { - crate::auth::Backend::Console(console, ()) => { - config + ) -> Result { + match &self.auth_backend { + crate::auth::Backend::ControlPlane(console, ()) => { + self.config + .authentication_config .jwks_cache .check_jwt( ctx, @@ -127,13 +130,18 @@ impl PoolingBackend { .await .map_err(|e| AuthError::auth_failed(e.to_string()))?; - Ok(()) + Ok(ComputeCredentials { + info: user_info.clone(), + keys: crate::auth::backend::ComputeCredentialKeys::None, + }) } - crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed( + crate::auth::Backend::ConsoleRedirect(_, ()) => Err(AuthError::auth_failed( "JWT login over web auth proxy is not supported", )), crate::auth::Backend::Local(_) => { - config + let keys = self + .config + .authentication_config .jwks_cache .check_jwt( ctx, @@ -145,8 +153,10 @@ impl PoolingBackend { .await .map_err(|e| AuthError::auth_failed(e.to_string()))?; - // todo: rewrite JWT signature with key shared somehow between local proxy and postgres - Ok(()) + Ok(ComputeCredentials { + info: user_info.clone(), + keys, + }) } } } @@ -176,7 +186,7 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.config.auth_backend.as_ref().map(|()| keys); + let backend = self.auth_backend.as_ref().map(|()| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -208,14 +218,14 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self - .config - .auth_backend - .as_ref() - .map(|()| ComputeCredentials { - info: conn_info.user_info.clone(), - keys: crate::auth::backend::ComputeCredentialKeys::None, - }); + let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials { + info: ComputeUserInfo { + user: conn_info.user_info.user.clone(), + endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)), + options: conn_info.user_info.options.clone(), + }, + keys: crate::auth::backend::ComputeCredentialKeys::None, + }); crate::proxy::connect_compute::connect_to_compute( ctx, &HyperMechanism { @@ -231,6 +241,77 @@ impl PoolingBackend { ) .await } + + /// Connect to postgres over localhost. + /// + /// We expect postgres to be started here, so we won't do any retries. + /// + /// # Panics + /// + /// Panics if called with a non-local_proxy backend. + #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] + pub(crate) async fn connect_to_local_postgres( + &self, + ctx: &RequestMonitoring, + conn_info: ConnInfo, + ) -> Result, HttpConnError> { + if let Some(client) = self.local_pool.get(ctx, &conn_info)? { + return Ok(client); + } + + let conn_id = uuid::Uuid::new_v4(); + tracing::Span::current().record("conn_id", display(conn_id)); + info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); + + let mut node_info = match &self.auth_backend { + auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => { + unreachable!("only local_proxy can connect to local postgres") + } + auth::Backend::Local(local) => local.node_info.clone(), + }; + + let config = node_info + .config + .user(&conn_info.user_info.user) + .dbname(&conn_info.dbname); + + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let (client, connection) = config.connect(tokio_postgres::NoTls).await?; + drop(pause); + + tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); + + let handle = local_conn_pool::poll_client( + self.local_pool.clone(), + ctx, + conn_info, + client, + connection, + conn_id, + node_info.aux.clone(), + ); + + let kid = handle.get_client().get_process_id() as i64; + let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk(); + + debug!(kid, ?jwk, "setting up backend session state"); + + // initiates the auth session + handle + .get_client() + .query( + "select auth.init($1, $2);", + &[ + &kid as &(dyn ToSql + Sync), + &tokio_postgres::types::Json(jwk), + ], + ) + .await?; + + info!(?kid, "backend session state init"); + + Ok(handle) + } } #[derive(Debug, thiserror::Error)] @@ -241,6 +322,8 @@ pub(crate) enum HttpConnError { PostgresConnectionError(#[from] tokio_postgres::Error), #[error("could not connection to local-proxy in compute")] LocalProxyConnectionError(#[from] LocalProxyConnError), + #[error("could not parse JWT payload")] + JwtPayloadError(serde_json::Error), #[error("could not get auth info")] GetAuthInfo(#[from] GetAuthInfoError), @@ -257,7 +340,7 @@ pub(crate) enum LocalProxyConnError { #[error("error with connection to local-proxy")] Io(#[source] std::io::Error), #[error("could not establish h2 connection")] - H2(#[from] hyper1::Error), + H2(#[from] hyper::Error), } impl ReportableError for HttpConnError { @@ -266,6 +349,7 @@ impl ReportableError for HttpConnError { HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, HttpConnError::PostgresConnectionError(p) => p.get_error_kind(), HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute, + HttpConnError::JwtPayloadError(_) => ErrorKind::User, HttpConnError::GetAuthInfo(a) => a.get_error_kind(), HttpConnError::AuthError(a) => a.get_error_kind(), HttpConnError::WakeCompute(w) => w.get_error_kind(), @@ -280,6 +364,7 @@ impl UserFacingError for HttpConnError { HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::PostgresConnectionError(p) => p.to_string(), HttpConnError::LocalProxyConnectionError(p) => p.to_string(), + HttpConnError::JwtPayloadError(p) => p.to_string(), HttpConnError::GetAuthInfo(c) => c.to_string_client(), HttpConnError::AuthError(c) => c.to_string_client(), HttpConnError::WakeCompute(c) => c.to_string_client(), @@ -296,6 +381,7 @@ impl CouldRetry for HttpConnError { HttpConnError::PostgresConnectionError(e) => e.could_retry(), HttpConnError::LocalProxyConnectionError(e) => e.could_retry(), HttpConnError::ConnectionClosedAbruptly(_) => false, + HttpConnError::JwtPayloadError(_) => false, HttpConnError::GetAuthInfo(_) => false, HttpConnError::AuthError(_) => false, HttpConnError::WakeCompute(_) => false, @@ -422,8 +508,12 @@ impl ConnectMechanism for HyperMechanism { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - // let port = node_info.config.get_ports().first().unwrap_or_else(10432); - let res = connect_http2(&host, 10432, timeout).await; + let port = *node_info.config.get_ports().first().ok_or_else(|| { + HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress( + "local-proxy port missing on compute address".into(), + )) + })?; + let res = connect_http2(&host, port, timeout).await; drop(pause); let (client, connection) = permit.release_result(res)?; @@ -481,7 +571,7 @@ async fn connect_http2( }; }; - let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new()) + let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) .timer(TokioTimer::new()) .keep_alive_interval(Duration::from_secs(20)) .keep_alive_while_idle(true) diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index a850ecd2be..2e576e0ded 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -17,7 +17,7 @@ use tokio_postgres::tls::NoTlsStream; use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; use tokio_util::sync::CancellationToken; -use crate::console::messages::{ColdStartInfo, MetricsAuxInfo}; +use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{ @@ -760,7 +760,7 @@ mod tests { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), branch_id: (&BranchId::from("branch")).into(), - cold_start_info: crate::console::messages::ColdStartInfo::Warm, + cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm, }, conn_id: uuid::Uuid::new_v4(), } diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index b31ed22a7c..6d61536f1a 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -1,5 +1,5 @@ use dashmap::DashMap; -use hyper1::client::conn::http2; +use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use parking_lot::RwLock; use rand::Rng; @@ -8,7 +8,7 @@ use std::sync::atomic::{self, AtomicUsize}; use std::{sync::Arc, sync::Weak}; use tokio::net::TcpStream; -use crate::console::messages::{ColdStartInfo, MetricsAuxInfo}; +use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{context::RequestMonitoring, EndpointCacheKey}; @@ -18,9 +18,9 @@ use tracing::{info, info_span, Instrument}; use super::conn_pool::ConnInfo; -pub(crate) type Send = http2::SendRequest; +pub(crate) type Send = http2::SendRequest; pub(crate) type Connect = - http2::Connection, hyper1::body::Incoming, TokioExecutor>; + http2::Connection, hyper::body::Incoming, TokioExecutor>; #[derive(Clone)] struct ConnPoolEntry { diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index d766a46577..87a72ec5f0 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -11,7 +11,7 @@ use serde::Serialize; use utils::http::error::ApiError; /// Like [`ApiError::into_response`] -pub(crate) fn api_error_into_response(this: ApiError) -> Response> { +pub(crate) fn api_error_into_response(this: ApiError) -> Response> { match this { ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status( format!("{err:#?}"), // use debug printing so that we give the cause @@ -67,12 +67,12 @@ impl HttpErrorBody { fn response_from_msg_and_status( msg: String, status: StatusCode, - ) -> Response> { + ) -> Response> { HttpErrorBody { msg }.to_response(status) } /// Same as [`utils::http::error::HttpErrorBody::to_response`] - fn to_response(&self, status: StatusCode) -> Response> { + fn to_response(&self, status: StatusCode) -> Response> { Response::builder() .status(status) .header(http::header::CONTENT_TYPE, "application/json") @@ -90,7 +90,7 @@ impl HttpErrorBody { pub(crate) fn json_response( status: StatusCode, data: T, -) -> Result>, ApiError> { +) -> Result>, ApiError> { let json = serde_json::to_string(&data) .context("Failed to serialize JSON response") .map_err(ApiError::InternalServerError)?; diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs new file mode 100644 index 0000000000..1dde5952e1 --- /dev/null +++ b/proxy/src/serverless/local_conn_pool.rs @@ -0,0 +1,544 @@ +use futures::{future::poll_fn, Future}; +use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding}; +use p256::ecdsa::{Signature, SigningKey}; +use parking_lot::RwLock; +use rand::rngs::OsRng; +use serde_json::Value; +use signature::Signer; +use std::task::{ready, Poll}; +use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; +use tokio::time::Instant; +use tokio_postgres::tls::NoTlsStream; +use tokio_postgres::types::ToSql; +use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; +use tokio_util::sync::CancellationToken; +use typed_json::json; + +use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; +use crate::metrics::Metrics; +use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; +use crate::{context::RequestMonitoring, DbName, RoleName}; + +use tracing::{debug, error, warn, Span}; +use tracing::{info, info_span, Instrument}; + +use super::backend::HttpConnError; +use super::conn_pool::{ClientInnerExt, ConnInfo}; + +struct ConnPoolEntry { + conn: ClientInner, + _last_access: std::time::Instant, +} + +// /// key id for the pg_session_jwt state +// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1); + +// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool +// Number of open connections is limited by the `max_conns_per_endpoint`. +pub(crate) struct EndpointConnPool { + pools: HashMap<(DbName, RoleName), DbUserConnPool>, + total_conns: usize, + max_conns: usize, + global_pool_size_max_conns: usize, +} + +impl EndpointConnPool { + fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option> { + let Self { + pools, total_conns, .. + } = self; + pools + .get_mut(&db_user) + .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns)) + } + + fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool { + let Self { + pools, total_conns, .. + } = self; + if let Some(pool) = pools.get_mut(&db_user) { + let old_len = pool.conns.len(); + pool.conns.retain(|conn| conn.conn.conn_id != conn_id); + let new_len = pool.conns.len(); + let removed = old_len - new_len; + if removed > 0 { + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(removed as i64); + } + *total_conns -= removed; + removed > 0 + } else { + false + } + } + + fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInner) { + let conn_id = client.conn_id; + + if client.is_closed() { + info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed"); + return; + } + let global_max_conn = pool.read().global_pool_size_max_conns; + if pool.read().total_conns >= global_max_conn { + info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full"); + return; + } + + // return connection to the pool + let mut returned = false; + let mut per_db_size = 0; + let total_conns = { + let mut pool = pool.write(); + + if pool.total_conns < pool.max_conns { + let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default(); + pool_entries.conns.push(ConnPoolEntry { + conn: client, + _last_access: std::time::Instant::now(), + }); + + returned = true; + per_db_size = pool_entries.conns.len(); + + pool.total_conns += 1; + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .inc(); + } + + pool.total_conns + }; + + // do logging outside of the mutex + if returned { + info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}"); + } else { + info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}"); + } + } +} + +impl Drop for EndpointConnPool { + fn drop(&mut self) { + if self.total_conns > 0 { + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(self.total_conns as i64); + } + } +} + +pub(crate) struct DbUserConnPool { + conns: Vec>, +} + +impl Default for DbUserConnPool { + fn default() -> Self { + Self { conns: Vec::new() } + } +} + +impl DbUserConnPool { + fn clear_closed_clients(&mut self, conns: &mut usize) -> usize { + let old_len = self.conns.len(); + + self.conns.retain(|conn| !conn.conn.is_closed()); + + let new_len = self.conns.len(); + let removed = old_len - new_len; + *conns -= removed; + removed + } + + fn get_conn_entry(&mut self, conns: &mut usize) -> Option> { + let mut removed = self.clear_closed_clients(conns); + let conn = self.conns.pop(); + if conn.is_some() { + *conns -= 1; + removed += 1; + } + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(removed as i64); + conn + } +} + +pub(crate) struct LocalConnPool { + global_pool: RwLock>, + + config: &'static crate::config::HttpConfig, +} + +impl LocalConnPool { + pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { + Arc::new(Self { + global_pool: RwLock::new(EndpointConnPool { + pools: HashMap::new(), + total_conns: 0, + max_conns: config.pool_options.max_conns_per_endpoint, + global_pool_size_max_conns: config.pool_options.max_total_conns, + }), + config, + }) + } + + pub(crate) fn get_idle_timeout(&self) -> Duration { + self.config.pool_options.idle_timeout + } + + // pub(crate) fn shutdown(&self) { + // let mut pool = self.global_pool.write(); + // pool.pools.clear(); + // pool.total_conns = 0; + // } + + pub(crate) fn get( + self: &Arc, + ctx: &RequestMonitoring, + conn_info: &ConnInfo, + ) -> Result>, HttpConnError> { + let mut client: Option> = None; + if let Some(entry) = self + .global_pool + .write() + .get_conn_entry(conn_info.db_and_user()) + { + client = Some(entry.conn); + } + + // ok return cached connection if found and establish a new one otherwise + if let Some(client) = client { + if client.is_closed() { + info!("local_pool: cached connection '{conn_info}' is closed, opening a new one"); + return Ok(None); + } + tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); + tracing::Span::current().record( + "pid", + tracing::field::display(client.inner.get_process_id()), + ); + info!( + cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), + "local_pool: reusing connection '{conn_info}'" + ); + client.session.send(ctx.session_id())?; + ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); + ctx.success(); + return Ok(Some(LocalClient::new( + client, + conn_info.clone(), + Arc::downgrade(self), + ))); + } + Ok(None) + } +} + +pub(crate) fn poll_client( + global_pool: Arc>, + ctx: &RequestMonitoring, + conn_info: ConnInfo, + client: tokio_postgres::Client, + mut connection: tokio_postgres::Connection, + conn_id: uuid::Uuid, + aux: MetricsAuxInfo, +) -> LocalClient { + let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); + let mut session_id = ctx.session_id(); + let (tx, mut rx) = tokio::sync::watch::channel(session_id); + + let span = info_span!(parent: None, "connection", %conn_id); + let cold_start_info = ctx.cold_start_info(); + span.in_scope(|| { + info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection"); + }); + let pool = Arc::downgrade(&global_pool); + let pool_clone = pool.clone(); + + let db_user = conn_info.db_and_user(); + let idle = global_pool.get_idle_timeout(); + let cancel = CancellationToken::new(); + let cancelled = cancel.clone().cancelled_owned(); + + tokio::spawn( + async move { + let _conn_gauge = conn_gauge; + let mut idle_timeout = pin!(tokio::time::sleep(idle)); + let mut cancelled = pin!(cancelled); + + poll_fn(move |cx| { + if cancelled.as_mut().poll(cx).is_ready() { + info!("connection dropped"); + return Poll::Ready(()) + } + + match rx.has_changed() { + Ok(true) => { + session_id = *rx.borrow_and_update(); + info!(%session_id, "changed session"); + idle_timeout.as_mut().reset(Instant::now() + idle); + } + Err(_) => { + info!("connection dropped"); + return Poll::Ready(()) + } + _ => {} + } + + // 5 minute idle connection timeout + if idle_timeout.as_mut().poll(cx).is_ready() { + idle_timeout.as_mut().reset(Instant::now() + idle); + info!("connection idle"); + if let Some(pool) = pool.clone().upgrade() { + // remove client from pool - should close the connection if it's idle. + // does nothing if the client is currently checked-out and in-use + if pool.global_pool.write().remove_client(db_user.clone(), conn_id) { + info!("idle connection removed"); + } + } + } + + loop { + let message = ready!(connection.poll_message(cx)); + + match message { + Some(Ok(AsyncMessage::Notice(notice))) => { + info!(%session_id, "notice: {}", notice); + } + Some(Ok(AsyncMessage::Notification(notif))) => { + warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); + } + Some(Ok(_)) => { + warn!(%session_id, "unknown message"); + } + Some(Err(e)) => { + error!(%session_id, "connection error: {}", e); + break + } + None => { + info!("connection closed"); + break + } + } + } + + // remove from connection pool + if let Some(pool) = pool.clone().upgrade() { + if pool.global_pool.write().remove_client(db_user.clone(), conn_id) { + info!("closed connection removed"); + } + } + + Poll::Ready(()) + }).await; + + } + .instrument(span)); + + let key = SigningKey::random(&mut OsRng); + + let inner = ClientInner { + inner: client, + session: tx, + cancel, + aux, + conn_id, + key, + jti: 0, + }; + LocalClient::new(inner, conn_info, pool_clone) +} + +struct ClientInner { + inner: C, + session: tokio::sync::watch::Sender, + cancel: CancellationToken, + aux: MetricsAuxInfo, + conn_id: uuid::Uuid, + + // needed for pg_session_jwt state + key: SigningKey, + jti: u64, +} + +impl Drop for ClientInner { + fn drop(&mut self) { + // on client drop, tell the conn to shut down + self.cancel.cancel(); + } +} + +impl ClientInner { + pub(crate) fn is_closed(&self) -> bool { + self.inner.is_closed() + } +} + +impl LocalClient { + pub(crate) fn metrics(&self) -> Arc { + let aux = &self.inner.as_ref().unwrap().aux; + USAGE_METRICS.register(Ids { + endpoint_id: aux.endpoint_id, + branch_id: aux.branch_id, + }) + } +} + +pub(crate) struct LocalClient { + span: Span, + inner: Option>, + conn_info: ConnInfo, + pool: Weak>, +} + +pub(crate) struct Discard<'a, C: ClientInnerExt> { + conn_info: &'a ConnInfo, + pool: &'a mut Weak>, +} + +impl LocalClient { + pub(self) fn new( + inner: ClientInner, + conn_info: ConnInfo, + pool: Weak>, + ) -> Self { + Self { + inner: Some(inner), + span: Span::current(), + conn_info, + pool, + } + } + pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) { + let Self { + inner, + pool, + conn_info, + span: _, + } = self; + let inner = inner.as_mut().expect("client inner should not be removed"); + (&mut inner.inner, Discard { conn_info, pool }) + } + pub(crate) fn key(&self) -> &SigningKey { + let inner = &self + .inner + .as_ref() + .expect("client inner should not be removed"); + &inner.key + } +} + +impl LocalClient { + pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> { + let inner = self + .inner + .as_mut() + .expect("client inner should not be removed"); + inner.jti += 1; + + let kid = inner.inner.get_process_id(); + let header = json!({"kid":kid}).to_string(); + + let mut payload = serde_json::from_slice::>(payload) + .map_err(HttpConnError::JwtPayloadError)?; + payload.insert("jti".to_string(), Value::Number(inner.jti.into())); + let payload = Value::Object(payload).to_string(); + + debug!( + kid, + jti = inner.jti, + ?header, + ?payload, + "signing new ephemeral JWT" + ); + + let token = sign_jwt(&inner.key, header, payload); + + // initiates the auth session + inner.inner.simple_query("discard all").await?; + inner + .inner + .query( + "select auth.jwt_session_init($1)", + &[&token as &(dyn ToSql + Sync)], + ) + .await?; + + info!(kid, jti = inner.jti, "user session state init"); + + Ok(()) + } +} + +fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String { + let header = Base64UrlUnpadded::encode_string(header.as_bytes()); + let payload = Base64UrlUnpadded::encode_string(payload.as_bytes()); + + let message = format!("{header}.{payload}"); + let sig: Signature = sk.sign(message.as_bytes()); + let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes()); + format!("{message}.{base64_sig}") +} + +impl Discard<'_, C> { + pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { + let conn_info = &self.conn_info; + if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { + info!( + "local_pool: throwing away connection '{conn_info}' because connection is not idle" + ); + } + } + pub(crate) fn discard(&mut self) { + let conn_info = &self.conn_info; + if std::mem::take(self.pool).strong_count() > 0 { + info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); + } + } +} + +impl LocalClient { + pub fn get_client(&self) -> &C { + &self + .inner + .as_ref() + .expect("client inner should not be removed") + .inner + } + + fn do_drop(&mut self) -> Option { + let conn_info = self.conn_info.clone(); + let client = self + .inner + .take() + .expect("client inner should not be removed"); + if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() { + let current_span = self.span.clone(); + // return connection to the pool + return Some(move || { + let _span = current_span.enter(); + EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client); + }); + } + None + } +} + +impl Drop for LocalClient { + fn drop(&mut self) { + if let Some(drop) = self.do_drop() { + tokio::task::spawn_blocking(drop); + } + } +} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless/mod.rs similarity index 95% rename from proxy/src/serverless.rs rename to proxy/src/serverless/mod.rs index a7e3fa709b..95f64e972c 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless/mod.rs @@ -8,6 +8,7 @@ mod conn_pool; mod http_conn_pool; mod http_util; mod json; +mod local_conn_pool; mod sql_over_http; mod websocket; @@ -22,7 +23,7 @@ use futures::TryFutureExt; use http::{Method, Response, StatusCode}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty}; -use hyper1::body::Incoming; +use hyper::body::Incoming; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder; use rand::rngs::StdRng; @@ -47,13 +48,14 @@ use std::pin::{pin, Pin}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, warn, Instrument}; +use tracing::{info, warn, Instrument}; use utils::http::error::ApiError; pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub async fn task_main( config: &'static ProxyConfig, + auth_backend: &'static crate::auth::Backend<'static, (), ()>, ws_listener: TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -63,6 +65,7 @@ pub async fn task_main( info!("websocket server has shut down"); } + let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config); let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config); { let conn_pool = Arc::clone(&conn_pool); @@ -105,8 +108,10 @@ pub async fn task_main( let backend = Arc::new(PoolingBackend { http_conn_pool: Arc::clone(&http_conn_pool), + local_pool, pool: Arc::clone(&conn_pool), config, + auth_backend, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), }); let tls_acceptor: Arc = match config.tls_config.as_ref() { @@ -238,7 +243,7 @@ async fn connection_startup( let (conn, peer) = match read_proxy_protocol(conn).await { Ok(c) => c, Err(e) => { - tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); + tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); return None; } }; @@ -302,7 +307,7 @@ async fn connection_handler( let server = Builder::new(TokioExecutor::new()); let conn = server.serve_connection_with_upgrades( hyper_util::rt::TokioIo::new(conn), - hyper1::service::service_fn(move |req: hyper1::Request| { + hyper::service::service_fn(move |req: hyper::Request| { // First HTTP request shares the same session ID let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4); @@ -355,7 +360,7 @@ async fn connection_handler( #[allow(clippy::too_many_arguments)] async fn request_handler( - mut request: hyper1::Request, + mut request: hyper::Request, config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, @@ -365,7 +370,7 @@ async fn request_handler( // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, -) -> Result>, ApiError> { +) -> Result>, ApiError> { let host = request .headers() .get("host") @@ -394,6 +399,7 @@ async fn request_handler( async move { if let Err(e) = websocket::serve_websocket( config, + backend.auth_backend, ctx, websocket, cancellation_handler, @@ -402,7 +408,7 @@ async fn request_handler( ) .await { - error!("error in websocket connection: {e:#}"); + warn!("error in websocket connection: {e:#}"); } } .instrument(span), diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f3a7ed9329..cf3324926c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -12,14 +12,14 @@ use http::Method; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; use http_body_util::Full; -use hyper1::body::Body; -use hyper1::body::Incoming; -use hyper1::header; -use hyper1::http::HeaderName; -use hyper1::http::HeaderValue; -use hyper1::Response; -use hyper1::StatusCode; -use hyper1::{HeaderMap, Request}; +use hyper::body::Body; +use hyper::body::Incoming; +use hyper::header; +use hyper::http::HeaderName; +use hyper::http::HeaderValue; +use hyper::Response; +use hyper::StatusCode; +use hyper::{HeaderMap, Request}; use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; @@ -40,11 +40,12 @@ use url::Url; use urlencoding; use utils::http::error::ApiError; -use crate::auth::backend::ComputeCredentials; +use crate::auth::backend::ComputeCredentialKeys; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; use crate::auth::ComputeUserInfoParseError; use crate::config::AuthenticationConfig; +use crate::config::HttpConfig; use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; @@ -56,20 +57,22 @@ use crate::metrics::Metrics; use crate::proxy::run_until_cancelled; use crate::proxy::NeonOptions; use crate::serverless::backend::HttpConnError; +use crate::usage_metrics::MetricCounter; use crate::usage_metrics::MetricCounterRecorder; use crate::DbName; use crate::RoleName; use super::backend::LocalProxyConnError; use super::backend::PoolingBackend; +use super::conn_pool; use super::conn_pool::AuthData; -use super::conn_pool::Client; use super::conn_pool::ConnInfo; use super::conn_pool::ConnInfoWithAuth; use super::http_util::json_response; use super::json::json_to_pg_text; use super::json::pg_text_row_to_json; use super::json::JsonConversionError; +use super::local_conn_pool; #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] @@ -272,7 +275,7 @@ pub(crate) async fn handle( request: Request, backend: Arc, cancel: CancellationToken, -) -> Result>, ApiError> { +) -> Result>, ApiError> { let result = handle_inner(cancel, config, &ctx, request, backend).await; let mut response = match result { @@ -435,7 +438,7 @@ impl UserFacingError for SqlOverHttpError { #[derive(Debug, thiserror::Error)] pub(crate) enum ReadPayloadError { #[error("could not read the HTTP request body: {0}")] - Read(#[from] hyper1::Error), + Read(#[from] hyper::Error), #[error("could not parse the HTTP request body: {0}")] Parse(#[from] serde_json::Error), } @@ -476,7 +479,7 @@ struct HttpHeaders { } impl HttpHeaders { - fn try_parse(headers: &hyper1::http::HeaderMap) -> Result { + fn try_parse(headers: &hyper::http::HeaderMap) -> Result { // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE); @@ -529,7 +532,7 @@ async fn handle_inner( ctx: &RequestMonitoring, request: Request, backend: Arc, -) -> Result>, SqlOverHttpError> { +) -> Result>, SqlOverHttpError> { let _requeset_gauge = Metrics::get() .proxy .connection_requests @@ -552,7 +555,7 @@ async fn handle_inner( match conn_info.auth { AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => { - handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await + handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await } auth => { handle_db_inner( @@ -577,7 +580,7 @@ async fn handle_db_inner( conn_info: ConnInfo, auth: AuthData, backend: Arc, -) -> Result>, SqlOverHttpError> { +) -> Result>, SqlOverHttpError> { // // Determine the destination and connection params // @@ -620,37 +623,35 @@ async fn handle_db_inner( let authenticate_and_connect = Box::pin( async { + let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_)); + let keys = match auth { AuthData::Password(pw) => { backend - .authenticate_with_password( - ctx, - &config.authentication_config, - &conn_info.user_info, - &pw, - ) + .authenticate_with_password(ctx, &conn_info.user_info, &pw) .await? } AuthData::Jwt(jwt) => { backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.user_info, - jwt, - ) - .await?; - - ComputeCredentials { - info: conn_info.user_info.clone(), - keys: crate::auth::backend::ComputeCredentialKeys::None, - } + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await? + } + }; + + let client = match keys.keys { + ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => { + let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?; + client.set_jwt_session(&payload).await?; + Client::Local(client) + } + _ => { + let client = backend + .connect_to_compute(ctx, conn_info, keys, !allow_pool) + .await?; + Client::Remote(client) } }; - let client = backend - .connect_to_compute(ctx, conn_info, keys, !allow_pool) - .await?; // not strictly necessary to mark success here, // but it's just insurance for if we forget it somewhere else ctx.success(); @@ -680,7 +681,7 @@ async fn handle_db_inner( // Now execute the query and return the result. let json_output = match payload { Payload::Single(stmt) => { - stmt.process(config, cancel, &mut client, parsed_headers) + stmt.process(&config.http_config, cancel, &mut client, parsed_headers) .await? } Payload::Batch(statements) => { @@ -698,7 +699,7 @@ async fn handle_db_inner( } statements - .process(config, cancel, &mut client, parsed_headers) + .process(&config.http_config, cancel, &mut client, parsed_headers) .await? } }; @@ -738,20 +739,14 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[ ]; async fn handle_auth_broker_inner( - config: &'static ProxyConfig, ctx: &RequestMonitoring, request: Request, conn_info: ConnInfo, jwt: String, backend: Arc, -) -> Result>, SqlOverHttpError> { +) -> Result>, SqlOverHttpError> { backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.user_info, - jwt, - ) + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await .map_err(HttpConnError::from)?; @@ -789,9 +784,9 @@ async fn handle_auth_broker_inner( impl QueryData { async fn process( self, - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, - client: &mut Client, + client: &mut Client, parsed_headers: HttpHeaders, ) -> Result { let (inner, mut discard) = client.inner(); @@ -820,7 +815,7 @@ impl QueryData { Either::Right((_cancelled, query)) => { tracing::info!("cancelling query"); if let Err(err) = cancel_token.cancel_query(NoTls).await { - tracing::error!(?err, "could not cancel query"); + tracing::warn!(?err, "could not cancel query"); } // wait for the query cancellation match time::timeout(time::Duration::from_millis(100), query).await { @@ -863,9 +858,9 @@ impl QueryData { impl BatchQueryData { async fn process( self, - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, - client: &mut Client, + client: &mut Client, parsed_headers: HttpHeaders, ) -> Result { info!("starting transaction"); @@ -909,7 +904,7 @@ impl BatchQueryData { } Err(SqlOverHttpError::Cancelled(_)) => { if let Err(err) = cancel_token.cancel_query(NoTls).await { - tracing::error!(?err, "could not cancel query"); + tracing::warn!(?err, "could not cancel query"); } // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe. discard.discard(); @@ -933,7 +928,7 @@ impl BatchQueryData { } async fn query_batch( - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, transaction: &Transaction<'_>, queries: BatchQueryData, @@ -972,7 +967,7 @@ async fn query_batch( } async fn query_to_json( - config: &'static ProxyConfig, + config: &'static HttpConfig, client: &T, data: QueryData, current_size: &mut usize, @@ -993,9 +988,9 @@ async fn query_to_json( rows.push(row); // we don't have a streaming response support yet so this is to prevent OOM // from a malicious query (eg a cross join) - if *current_size > config.http_config.max_response_size_bytes { + if *current_size > config.max_response_size_bytes { return Err(SqlOverHttpError::ResponseTooLarge( - config.http_config.max_response_size_bytes, + config.max_response_size_bytes, )); } } @@ -1058,3 +1053,50 @@ async fn query_to_json( Ok((ready, results)) } + +enum Client { + Remote(conn_pool::Client), + Local(local_conn_pool::LocalClient), +} + +enum Discard<'a> { + Remote(conn_pool::Discard<'a, tokio_postgres::Client>), + Local(local_conn_pool::Discard<'a, tokio_postgres::Client>), +} + +impl Client { + fn metrics(&self) -> Arc { + match self { + Client::Remote(client) => client.metrics(), + Client::Local(local_client) => local_client.metrics(), + } + } + + fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { + match self { + Client::Remote(client) => { + let (c, d) = client.inner(); + (c, Discard::Remote(d)) + } + Client::Local(local_client) => { + let (c, d) = local_client.inner(); + (c, Discard::Local(d)) + } + } + } +} + +impl Discard<'_> { + fn check_idle(&mut self, status: ReadyForQueryStatus) { + match self { + Discard::Remote(discard) => discard.check_idle(status), + Discard::Local(discard) => discard.check_idle(status), + } + } + fn discard(&mut self) { + match self { + Discard::Remote(discard) => discard.discard(), + Discard::Local(discard) => discard.discard(), + } + } +} diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 3d257223b8..fd0f0cac7f 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -12,7 +12,7 @@ use anyhow::Context as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use framed_websockets::{Frame, OpCode, WebSocketServer}; use futures::{Sink, Stream}; -use hyper1::upgrade::OnUpgrade; +use hyper::upgrade::OnUpgrade; use hyper_util::rt::TokioIo; use pin_project_lite::pin_project; @@ -129,6 +129,7 @@ impl AsyncBufRead for WebSocketRw { pub(crate) async fn serve_websocket( config: &'static ProxyConfig, + auth_backend: &'static crate::auth::Backend<'static, (), ()>, ctx: RequestMonitoring, websocket: OnUpgrade, cancellation_handler: Arc, @@ -145,6 +146,7 @@ pub(crate) async fn serve_websocket( let res = Box::pin(handle_client( config, + auth_backend, &ctx, cancellation_handler, WebSocketRw::new(websocket), diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index fd8599bcb3..ee36ed462d 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -27,7 +27,7 @@ use std::{ }; use tokio::io::AsyncWriteExt; use tokio_util::sync::CancellationToken; -use tracing::{error, info, instrument, trace}; +use tracing::{error, info, instrument, trace, warn}; use utils::backoff; use uuid::{NoContext, Timestamp}; @@ -346,7 +346,7 @@ async fn collect_metrics_iteration( error!("metrics endpoint refused the sent metrics: {:?}", res); for metric in chunk.events.iter().filter(|e| e.value > (1u64 << 40)) { // Report if the metric value is suspiciously large - error!("potentially abnormal metric value: {:?}", metric); + warn!("potentially abnormal metric value: {:?}", metric); } } } @@ -485,49 +485,51 @@ async fn upload_events_chunk( #[cfg(test)] mod tests { - use std::{ - net::TcpListener, - sync::{Arc, Mutex}, - }; + use super::*; + use crate::{http, BranchId, EndpointId}; use anyhow::Error; use chrono::Utc; use consumption_metrics::{Event, EventChunk}; - use hyper::{ - service::{make_service_fn, service_fn}, - Body, Response, - }; + use http_body_util::BodyExt; + use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response}; + use hyper_util::rt::TokioIo; + use std::sync::{Arc, Mutex}; + use tokio::net::TcpListener; use url::Url; - use super::*; - use crate::{http, BranchId, EndpointId}; - #[tokio::test] async fn metrics() { - let listener = TcpListener::bind("0.0.0.0:0").unwrap(); + type Report = EventChunk<'static, Event>; + let reports: Arc>> = Arc::default(); - let reports = Arc::new(Mutex::new(vec![])); - let reports2 = reports.clone(); - - let server = hyper::server::Server::from_tcp(listener) - .unwrap() - .serve(make_service_fn(move |_| { - let reports = reports.clone(); - async move { - Ok::<_, Error>(service_fn(move |req| { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn({ + let reports = reports.clone(); + async move { + loop { + if let Ok((stream, _addr)) = listener.accept().await { let reports = reports.clone(); - async move { - let bytes = hyper::body::to_bytes(req.into_body()).await?; - let events: EventChunk<'static, Event> = - serde_json::from_slice(&bytes)?; - reports.lock().unwrap().push(events); - Ok::<_, Error>(Response::new(Body::from(vec![]))) - } - })) + http1::Builder::new() + .serve_connection( + TokioIo::new(stream), + service_fn(move |req: Request| { + let reports = reports.clone(); + async move { + let bytes = req.into_body().collect().await?.to_bytes(); + let events = serde_json::from_slice(&bytes)?; + reports.lock().unwrap().push(events); + Ok::<_, Error>(Response::new(String::new())) + } + }), + ) + .await + .unwrap(); + } } - })); - let addr = server.local_addr(); - tokio::spawn(server); + } + }); let metrics = Metrics::default(); let client = http::new_client(); @@ -536,7 +538,7 @@ mod tests { // no counters have been registered collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await; - let r = std::mem::take(&mut *reports2.lock().unwrap()); + let r = std::mem::take(&mut *reports.lock().unwrap()); assert!(r.is_empty()); // register a new counter @@ -548,7 +550,7 @@ mod tests { // the counter should be observed despite 0 egress collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await; - let r = std::mem::take(&mut *reports2.lock().unwrap()); + let r = std::mem::take(&mut *reports.lock().unwrap()); assert_eq!(r.len(), 1); assert_eq!(r[0].events.len(), 1); assert_eq!(r[0].events[0].value, 0); @@ -558,7 +560,7 @@ mod tests { // egress should be observered collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await; - let r = std::mem::take(&mut *reports2.lock().unwrap()); + let r = std::mem::take(&mut *reports.lock().unwrap()); assert_eq!(r.len(), 1); assert_eq!(r[0].events.len(), 1); assert_eq!(r[0].events[0].value, 1); @@ -568,7 +570,7 @@ mod tests { // we do not observe the counter collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await; - let r = std::mem::take(&mut *reports2.lock().unwrap()); + let r = std::mem::take(&mut *reports.lock().unwrap()); assert!(r.is_empty()); // counter is unregistered diff --git a/pyproject.toml b/pyproject.toml index 556edf5589..9cd315bb96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,5 +97,8 @@ select = [ "I", # isort "W", # pycodestyle "B", # bugbear - "UP032", # f-string + "UP", # pyupgrade ] + +[tool.ruff.lint.pyupgrade] +keep-runtime-typing = true # Remove this stanza when we require Python 3.10 diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 78a3129aba..ec08d02240 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -23,6 +23,7 @@ crc32c.workspace = true fail.workspace = true hex.workspace = true humantime.workspace = true +http.workspace = true hyper0.workspace = true futures.workspace = true once_cell.workspace = true diff --git a/safekeeper/src/auth.rs b/safekeeper/src/auth.rs index c5c9393c00..fdd0830b02 100644 --- a/safekeeper/src/auth.rs +++ b/safekeeper/src/auth.rs @@ -15,15 +15,20 @@ pub fn check_permission(claims: &Claims, tenant_id: Option) -> Result< } Ok(()) } - (Scope::Admin | Scope::PageServerApi | Scope::GenerationsApi | Scope::Scrubber, _) => { - Err(AuthError( - format!( - "JWT scope '{:?}' is ineligible for Safekeeper auth", - claims.scope - ) - .into(), - )) - } + ( + Scope::Admin + | Scope::PageServerApi + | Scope::GenerationsApi + | Scope::Infra + | Scope::Scrubber, + _, + ) => Err(AuthError( + format!( + "JWT scope '{:?}' is ineligible for Safekeeper auth", + claims.scope + ) + .into(), + )), (Scope::SafekeeperData, _) => Ok(()), } } diff --git a/safekeeper/src/metrics.rs b/safekeeper/src/metrics.rs index aa2bafbe92..e8fdddcdc1 100644 --- a/safekeeper/src/metrics.rs +++ b/safekeeper/src/metrics.rs @@ -12,8 +12,8 @@ use metrics::{ core::{AtomicU64, Collector, Desc, GenericCounter, GenericGaugeVec, Opts}, proto::MetricFamily, register_histogram_vec, register_int_counter, register_int_counter_pair, - register_int_counter_pair_vec, register_int_counter_vec, Gauge, HistogramVec, IntCounter, - IntCounterPair, IntCounterPairVec, IntCounterVec, IntGaugeVec, + register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, Gauge, + HistogramVec, IntCounter, IntCounterPair, IntCounterPairVec, IntCounterVec, IntGaugeVec, }; use once_cell::sync::Lazy; @@ -231,6 +231,14 @@ pub(crate) static EVICTION_EVENTS_COMPLETED: Lazy = Lazy::new(|| .expect("Failed to register metric") }); +pub static NUM_EVICTED_TIMELINES: Lazy = Lazy::new(|| { + register_int_gauge!( + "safekeeper_evicted_timelines", + "Number of currently evicted timelines" + ) + .expect("Failed to register metric") +}); + pub const LABEL_UNKNOWN: &str = "unknown"; /// Labels for traffic metrics. diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index fb98534768..3494b0b764 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -631,13 +631,19 @@ impl Timeline { return Err(e); } - self.bootstrap(conf, broker_active_set, partial_backup_rate_limiter); + self.bootstrap( + shared_state, + conf, + broker_active_set, + partial_backup_rate_limiter, + ); Ok(()) } /// Bootstrap new or existing timeline starting background tasks. pub fn bootstrap( self: &Arc, + _shared_state: &mut WriteGuardSharedState<'_>, conf: &SafeKeeperConf, broker_active_set: Arc, partial_backup_rate_limiter: RateLimiter, diff --git a/safekeeper/src/timeline_eviction.rs b/safekeeper/src/timeline_eviction.rs index 5aa4921a92..fae6571277 100644 --- a/safekeeper/src/timeline_eviction.rs +++ b/safekeeper/src/timeline_eviction.rs @@ -15,7 +15,9 @@ use tracing::{debug, info, instrument, warn}; use utils::crashsafe::durable_rename; use crate::{ - metrics::{EvictionEvent, EVICTION_EVENTS_COMPLETED, EVICTION_EVENTS_STARTED}, + metrics::{ + EvictionEvent, EVICTION_EVENTS_COMPLETED, EVICTION_EVENTS_STARTED, NUM_EVICTED_TIMELINES, + }, rate_limit::rand_duration, timeline_manager::{Manager, StateSnapshot}, wal_backup, @@ -93,6 +95,7 @@ impl Manager { } info!("successfully evicted timeline"); + NUM_EVICTED_TIMELINES.inc(); } /// Attempt to restore evicted timeline from remote storage; it must be @@ -128,6 +131,7 @@ impl Manager { tokio::time::Instant::now() + rand_duration(&self.conf.eviction_min_resident); info!("successfully restored evicted timeline"); + NUM_EVICTED_TIMELINES.dec(); } } diff --git a/safekeeper/src/timeline_manager.rs b/safekeeper/src/timeline_manager.rs index f5535c0cea..2129e86baa 100644 --- a/safekeeper/src/timeline_manager.rs +++ b/safekeeper/src/timeline_manager.rs @@ -25,7 +25,10 @@ use utils::lsn::Lsn; use crate::{ control_file::{FileStorage, Storage}, - metrics::{MANAGER_ACTIVE_CHANGES, MANAGER_ITERATIONS_TOTAL, MISC_OPERATION_SECONDS}, + metrics::{ + MANAGER_ACTIVE_CHANGES, MANAGER_ITERATIONS_TOTAL, MISC_OPERATION_SECONDS, + NUM_EVICTED_TIMELINES, + }, rate_limit::{rand_duration, RateLimiter}, recovery::recovery_main, remove_wal::calc_horizon_lsn, @@ -251,6 +254,11 @@ pub async fn main_task( mgr.recovery_task = Some(tokio::spawn(recovery_main(tli, mgr.conf.clone()))); } + // If timeline is evicted, reflect that in the metric. + if mgr.is_offloaded { + NUM_EVICTED_TIMELINES.inc(); + } + let last_state = 'outer: loop { MANAGER_ITERATIONS_TOTAL.inc(); @@ -367,6 +375,11 @@ pub async fn main_task( mgr.update_wal_removal_end(res); } + // If timeline is deleted while evicted decrement the gauge. + if mgr.tli.is_cancelled() && mgr.is_offloaded { + NUM_EVICTED_TIMELINES.dec(); + } + mgr.set_status(Status::Finished); } diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index 6662e18817..866cde3339 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -165,12 +165,14 @@ impl GlobalTimelines { match Timeline::load_timeline(&conf, ttid) { Ok(timeline) => { let tli = Arc::new(timeline); + let mut shared_state = tli.write_shared_state().await; TIMELINES_STATE .lock() .unwrap() .timelines .insert(ttid, tli.clone()); tli.bootstrap( + &mut shared_state, &conf, broker_active_set.clone(), partial_backup_rate_limiter.clone(), @@ -213,6 +215,7 @@ impl GlobalTimelines { match Timeline::load_timeline(&conf, ttid) { Ok(timeline) => { let tli = Arc::new(timeline); + let mut shared_state = tli.write_shared_state().await; // TODO: prevent concurrent timeline creation/loading { @@ -227,8 +230,13 @@ impl GlobalTimelines { state.timelines.insert(ttid, tli.clone()); } - tli.bootstrap(&conf, broker_active_set, partial_backup_rate_limiter); - + tli.bootstrap( + &mut shared_state, + &conf, + broker_active_set, + partial_backup_rate_limiter, + ); + drop(shared_state); Ok(tli) } // If we can't load a timeline, it's bad. Caller will figure it out. diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index ef26ac99c5..6c87e5a926 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -17,7 +17,9 @@ use std::time::Duration; use postgres_ffi::v14::xlog_utils::XLogSegNoOffsetToRecPtr; use postgres_ffi::XLogFileName; use postgres_ffi::{XLogSegNo, PG_TLI}; -use remote_storage::{GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata}; +use remote_storage::{ + DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata, +}; use tokio::fs::File; use tokio::select; @@ -503,8 +505,12 @@ pub async fn read_object( let cancel = CancellationToken::new(); + let opts = DownloadOpts { + byte_start: std::ops::Bound::Included(offset), + ..Default::default() + }; let download = storage - .download_storage_object(Some((offset, None)), file_path, &cancel) + .download(file_path, &opts, &cancel) .await .with_context(|| { format!("Failed to open WAL segment download stream for remote path {file_path:?}") diff --git a/safekeeper/tests/walproposer_sim/safekeeper.rs b/safekeeper/tests/walproposer_sim/safekeeper.rs index a05c2d4559..047b4be8fa 100644 --- a/safekeeper/tests/walproposer_sim/safekeeper.rs +++ b/safekeeper/tests/walproposer_sim/safekeeper.rs @@ -13,7 +13,7 @@ use desim::{ node_os::NodeOs, proto::{AnyMessage, NetEvent, NodeEvent}, }; -use hyper0::Uri; +use http::Uri; use safekeeper::{ safekeeper::{ProposerAcceptorMessage, SafeKeeper, ServerInfo, UNKNOWN_SERVER_VERSION}, state::{TimelinePersistentState, TimelineState}, diff --git a/scripts/benchmark_durations.py b/scripts/benchmark_durations.py index 4ca433679a..a9a90c7370 100755 --- a/scripts/benchmark_durations.py +++ b/scripts/benchmark_durations.py @@ -1,9 +1,10 @@ #! /usr/bin/env python3 +from __future__ import annotations + import argparse import json import logging -from typing import Dict import psycopg2 import psycopg2.extras @@ -110,7 +111,7 @@ def main(args: argparse.Namespace): output = args.output percentile = args.percentile - res: Dict[str, float] = {} + res: dict[str, float] = {} try: logging.info("connecting to the database...") diff --git a/scripts/download_basebackup.py b/scripts/download_basebackup.py index 1f84e41fef..f00ee87eb7 100755 --- a/scripts/download_basebackup.py +++ b/scripts/download_basebackup.py @@ -4,6 +4,9 @@ # # This can be useful in disaster recovery. # + +from __future__ import annotations + import argparse import psycopg2 diff --git a/scripts/flaky_tests.py b/scripts/flaky_tests.py index 919a9278a9..9312f8b3e7 100755 --- a/scripts/flaky_tests.py +++ b/scripts/flaky_tests.py @@ -1,16 +1,21 @@ #! /usr/bin/env python3 +from __future__ import annotations + import argparse import json import logging import os from collections import defaultdict -from typing import Any, DefaultDict, Dict, Optional +from typing import TYPE_CHECKING import psycopg2 import psycopg2.extras import toml +if TYPE_CHECKING: + from typing import Any, Optional + FLAKY_TESTS_QUERY = """ SELECT DISTINCT parent_suite, suite, name @@ -33,7 +38,7 @@ def main(args: argparse.Namespace): build_type = args.build_type pg_version = args.pg_version - res: DefaultDict[str, DefaultDict[str, Dict[str, bool]]] + res: defaultdict[str, defaultdict[str, dict[str, bool]]] res = defaultdict(lambda: defaultdict(dict)) try: @@ -60,7 +65,7 @@ def main(args: argparse.Namespace): pageserver_virtual_file_io_engine_parameter = "" # re-use existing records of flaky tests from before parametrization by compaction_algorithm - def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[Dict[str, Any]]: + def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict[str, Any]]: """Duplicated from parametrize.py""" toml_table = os.getenv("PAGESERVER_DEFAULT_TENANT_CONFIG_COMPACTION_ALGORITHM") if toml_table is None: diff --git a/scripts/force_layer_download.py b/scripts/force_layer_download.py index 5472d86d8f..a4fd3f6132 100644 --- a/scripts/force_layer_download.py +++ b/scripts/force_layer_download.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import asyncio import json @@ -5,11 +7,15 @@ import logging import signal import sys from collections import defaultdict +from collections.abc import Awaitable from dataclasses import dataclass -from typing import Any, Awaitable, Dict, List, Tuple +from typing import TYPE_CHECKING import aiohttp +if TYPE_CHECKING: + from typing import Any + class ClientException(Exception): pass @@ -89,7 +95,7 @@ class Client: class Completed: """The status dict returned by the API""" - status: Dict[str, Any] + status: dict[str, Any] sigint_received = asyncio.Event() @@ -179,7 +185,7 @@ async def main_impl(args, report_out, client: Client): """ Returns OS exit status. """ - tenant_and_timline_ids: List[Tuple[str, str]] = [] + tenant_and_timline_ids: list[tuple[str, str]] = [] # fill tenant_and_timline_ids based on spec for spec in args.what: comps = spec.split(":") @@ -215,14 +221,14 @@ async def main_impl(args, report_out, client: Client): tenant_and_timline_ids = tmp logging.info("create tasks and process them at specified concurrency") - task_q: asyncio.Queue[Tuple[str, Awaitable[Any]]] = asyncio.Queue() + task_q: asyncio.Queue[tuple[str, Awaitable[Any]]] = asyncio.Queue() tasks = { f"{tid}:{tlid}": do_timeline(client, tid, tlid) for tid, tlid in tenant_and_timline_ids } for task in tasks.items(): task_q.put_nowait(task) - result_q: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue() + result_q: asyncio.Queue[tuple[str, Any]] = asyncio.Queue() taskq_handlers = [] for _ in range(0, args.concurrent_tasks): taskq_handlers.append(taskq_handler(task_q, result_q)) diff --git a/scripts/ingest_perf_test_result.py b/scripts/ingest_perf_test_result.py index 35a1e29720..40071c01b0 100644 --- a/scripts/ingest_perf_test_result.py +++ b/scripts/ingest_perf_test_result.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 + +from __future__ import annotations + import argparse import json import logging diff --git a/scripts/ingest_regress_test_result-new-format.py b/scripts/ingest_regress_test_result-new-format.py index 40d7254e00..e0dd0a7189 100644 --- a/scripts/ingest_regress_test_result-new-format.py +++ b/scripts/ingest_regress_test_result-new-format.py @@ -1,5 +1,7 @@ #! /usr/bin/env python3 +from __future__ import annotations + import argparse import dataclasses import json @@ -11,7 +13,6 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import Tuple import backoff import psycopg2 @@ -91,7 +92,7 @@ def create_table(cur): cur.execute(CREATE_TABLE) -def parse_test_name(test_name: str) -> Tuple[str, int, str]: +def parse_test_name(test_name: str) -> tuple[str, int, str]: build_type, pg_version = None, None if match := TEST_NAME_RE.search(test_name): found = match.groupdict() diff --git a/scripts/sk_cleanup_tenants/script.py b/scripts/sk_cleanup_tenants/script.py index c20a4bb830..8af19ae7bd 100644 --- a/scripts/sk_cleanup_tenants/script.py +++ b/scripts/sk_cleanup_tenants/script.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import logging import os diff --git a/storage_broker/Cargo.toml b/storage_broker/Cargo.toml index 849707fbc4..2d19472c36 100644 --- a/storage_broker/Cargo.toml +++ b/storage_broker/Cargo.toml @@ -10,13 +10,16 @@ bench = [] [dependencies] anyhow.workspace = true async-stream.workspace = true +bytes.workspace = true clap = { workspace = true, features = ["derive"] } const_format.workspace = true futures.workspace = true futures-core.workspace = true futures-util.workspace = true humantime.workspace = true -hyper0 = { workspace = true, features = ["full"] } +hyper = { workspace = true, features = ["full"] } +http-body-util.workspace = true +hyper-util = "0.1" once_cell.workspace = true parking_lot.workspace = true prost.workspace = true diff --git a/storage_broker/src/bin/storage_broker.rs b/storage_broker/src/bin/storage_broker.rs index 9c56e9fab5..1fbb651656 100644 --- a/storage_broker/src/bin/storage_broker.rs +++ b/storage_broker/src/bin/storage_broker.rs @@ -10,16 +10,15 @@ //! //! Only safekeeper message is supported, but it is not hard to add something //! else with generics. - -extern crate hyper0 as hyper; - use clap::{command, Parser}; use futures_core::Stream; use futures_util::StreamExt; +use http_body_util::Full; +use hyper::body::Incoming; use hyper::header::CONTENT_TYPE; -use hyper::server::conn::AddrStream; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, StatusCode}; +use hyper::service::service_fn; +use hyper::{Method, StatusCode}; +use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use parking_lot::RwLock; use std::collections::HashMap; use std::convert::Infallible; @@ -27,9 +26,11 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; +use tokio::net::TcpListener; use tokio::sync::broadcast; use tokio::sync::broadcast::error::RecvError; use tokio::time; +use tonic::body::{self, empty_body, BoxBody}; use tonic::codegen::Service; use tonic::transport::server::Connected; use tonic::Code; @@ -48,9 +49,7 @@ use storage_broker::proto::{ FilterTenantTimelineId, MessageType, SafekeeperDiscoveryRequest, SafekeeperDiscoveryResponse, SafekeeperTimelineInfo, SubscribeByFilterRequest, SubscribeSafekeeperInfoRequest, TypedMessage, }; -use storage_broker::{ - parse_proto_ttid, EitherBody, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_LISTEN_ADDR, -}; +use storage_broker::{parse_proto_ttid, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_LISTEN_ADDR}; use utils::id::TenantTimelineId; use utils::logging::{self, LogFormat}; use utils::sentry_init::init_sentry; @@ -602,8 +601,8 @@ impl BrokerService for Broker { // We serve only metrics and healthcheck through http1. async fn http1_handler( - req: hyper::Request, -) -> Result, Infallible> { + req: hyper::Request, +) -> Result, Infallible> { let resp = match (req.method(), req.uri().path()) { (&Method::GET, "/metrics") => { let mut buffer = vec![]; @@ -614,16 +613,16 @@ async fn http1_handler( hyper::Response::builder() .status(StatusCode::OK) .header(CONTENT_TYPE, encoder.format_type()) - .body(Body::from(buffer)) + .body(body::boxed(Full::new(bytes::Bytes::from(buffer)))) .unwrap() } (&Method::GET, "/status") => hyper::Response::builder() .status(StatusCode::OK) - .body(Body::empty()) + .body(empty_body()) .unwrap(), _ => hyper::Response::builder() .status(StatusCode::NOT_FOUND) - .body(Body::empty()) + .body(empty_body()) .unwrap(), }; Ok(resp) @@ -665,52 +664,76 @@ async fn main() -> Result<(), Box> { }; let storage_broker_server = BrokerServiceServer::new(storage_broker_impl); - info!("listening on {}", &args.listen_addr); - // grpc is served along with http1 for metrics on a single port, hence we // don't use tonic's Server. - hyper::Server::bind(&args.listen_addr) - .http2_keep_alive_interval(Some(args.http2_keepalive_interval)) - .serve(make_service_fn(move |conn: &AddrStream| { - let storage_broker_server_cloned = storage_broker_server.clone(); - let connect_info = conn.connect_info(); - async move { - Ok::<_, Infallible>(service_fn(move |mut req| { - // That's what tonic's MakeSvc.call does to pass conninfo to - // the request handler (and where its request.remote_addr() - // expects it to find). - req.extensions_mut().insert(connect_info.clone()); - - // Technically this second clone is not needed, but consume - // by async block is apparently unavoidable. BTW, error - // message is enigmatic, see - // https://github.com/rust-lang/rust/issues/68119 - // - // We could get away without async block at all, but then we - // need to resort to futures::Either to merge the result, - // which doesn't caress an eye as well. - let mut storage_broker_server_svc = storage_broker_server_cloned.clone(); - async move { - if req.headers().get("content-type").map(|x| x.as_bytes()) - == Some(b"application/grpc") - { - let res_resp = storage_broker_server_svc.call(req).await; - // Grpc and http1 handlers have slightly different - // Response types: it is UnsyncBoxBody for the - // former one (not sure why) and plain hyper::Body - // for the latter. Both implement HttpBody though, - // and EitherBody is used to merge them. - res_resp.map(|resp| resp.map(EitherBody::Left)) - } else { - let res_resp = http1_handler(req).await; - res_resp.map(|resp| resp.map(EitherBody::Right)) - } - } - })) + let tcp_listener = TcpListener::bind(&args.listen_addr).await?; + info!("listening on {}", &args.listen_addr); + loop { + let (stream, addr) = match tcp_listener.accept().await { + Ok(v) => v, + Err(e) => { + info!("couldn't accept connection: {e}"); + continue; } - })) - .await?; - Ok(()) + }; + + let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + builder.http1().timer(TokioTimer::new()); + builder + .http2() + .timer(TokioTimer::new()) + .keep_alive_interval(Some(args.http2_keepalive_interval)) + // This matches the tonic server default. It allows us to support production-like workloads. + .max_concurrent_streams(None); + + let storage_broker_server_cloned = storage_broker_server.clone(); + let connect_info = stream.connect_info(); + let service_fn_ = async move { + service_fn(move |mut req| { + // That's what tonic's MakeSvc.call does to pass conninfo to + // the request handler (and where its request.remote_addr() + // expects it to find). + req.extensions_mut().insert(connect_info.clone()); + + // Technically this second clone is not needed, but consume + // by async block is apparently unavoidable. BTW, error + // message is enigmatic, see + // https://github.com/rust-lang/rust/issues/68119 + // + // We could get away without async block at all, but then we + // need to resort to futures::Either to merge the result, + // which doesn't caress an eye as well. + let mut storage_broker_server_svc = storage_broker_server_cloned.clone(); + async move { + if req.headers().get("content-type").map(|x| x.as_bytes()) + == Some(b"application/grpc") + { + let res_resp = storage_broker_server_svc.call(req).await; + // Grpc and http1 handlers have slightly different + // Response types: it is UnsyncBoxBody for the + // former one (not sure why) and plain hyper::Body + // for the latter. Both implement HttpBody though, + // and `Either` is used to merge them. + res_resp.map(|resp| resp.map(http_body_util::Either::Left)) + } else { + let res_resp = http1_handler(req).await; + res_resp.map(|resp| resp.map(http_body_util::Either::Right)) + } + } + }) + } + .await; + + tokio::task::spawn(async move { + let res = builder + .serve_connection(TokioIo::new(stream), service_fn_) + .await; + + if let Err(e) = res { + info!("error serving connection from {addr}: {e}"); + } + }); + } } #[cfg(test)] diff --git a/storage_broker/src/lib.rs b/storage_broker/src/lib.rs index 447591f898..bc632a39f7 100644 --- a/storage_broker/src/lib.rs +++ b/storage_broker/src/lib.rs @@ -1,8 +1,3 @@ -extern crate hyper0 as hyper; - -use hyper::body::HttpBody; -use std::pin::Pin; -use std::task::{Context, Poll}; use std::time::Duration; use tonic::codegen::StdError; use tonic::transport::{ClientTlsConfig, Endpoint}; @@ -96,56 +91,3 @@ pub fn parse_proto_ttid(proto_ttid: &ProtoTenantTimelineId) -> Result; - -// Provides impl HttpBody for two different types implementing it. Inspired by -// https://github.com/hyperium/tonic/blob/master/examples/src/hyper_warp/server.rs -pub enum EitherBody { - Left(A), - Right(B), -} - -impl HttpBody for EitherBody -where - A: HttpBody + Send + Unpin, - B: HttpBody + Send + Unpin, - A::Error: Into, - B::Error: Into, -{ - type Data = A::Data; - type Error = Box; - - fn is_end_stream(&self) -> bool { - match self { - EitherBody::Left(b) => b.is_end_stream(), - EitherBody::Right(b) => b.is_end_stream(), - } - } - - fn poll_data( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.get_mut() { - EitherBody::Left(b) => Pin::new(b).poll_data(cx).map(map_option_err), - EitherBody::Right(b) => Pin::new(b).poll_data(cx).map(map_option_err), - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.get_mut() { - EitherBody::Left(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into), - EitherBody::Right(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into), - } - } -} - -fn map_option_err>(err: Option>) -> Option> { - err.map(|e| e.map_err(Into::into)) -} diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 4dd8badd03..46b6f4f2bf 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -636,7 +636,7 @@ async fn handle_tenant_list( } async fn handle_node_register(req: Request) -> Result, ApiError> { - check_permissions(&req, Scope::Admin)?; + check_permissions(&req, Scope::Infra)?; let mut req = match maybe_forward(req).await { ForwardOutcome::Forwarded(res) => { @@ -1182,7 +1182,7 @@ async fn handle_get_safekeeper(req: Request) -> Result, Api /// Assumes information is only relayed to storage controller after first selecting an unique id on /// control plane database, which means we have an id field in the request and payload. async fn handle_upsert_safekeeper(mut req: Request) -> Result, ApiError> { - check_permissions(&req, Scope::Admin)?; + check_permissions(&req, Scope::Infra)?; let body = json_request::(&mut req).await?; let id = parse_request_param::(&req, "id")?; diff --git a/storage_controller/src/reconciler.rs b/storage_controller/src/reconciler.rs index 4864a021fe..9d2182d44c 100644 --- a/storage_controller/src/reconciler.rs +++ b/storage_controller/src/reconciler.rs @@ -22,7 +22,7 @@ use utils::sync::gate::GateGuard; use crate::compute_hook::{ComputeHook, NotifyError}; use crate::node::Node; -use crate::tenant_shard::{IntentState, ObservedState, ObservedStateLocation}; +use crate::tenant_shard::{IntentState, ObservedState, ObservedStateDelta, ObservedStateLocation}; const DEFAULT_HEATMAP_PERIOD: &str = "60s"; @@ -45,8 +45,15 @@ pub(super) struct Reconciler { pub(crate) reconciler_config: ReconcilerConfig, pub(crate) config: TenantConfig, + + /// Observed state from the point of view of the reconciler. + /// This gets updated as the reconciliation makes progress. pub(crate) observed: ObservedState, + /// Snapshot of the observed state at the point when the reconciler + /// was spawned. + pub(crate) original_observed: ObservedState, + pub(crate) service_config: service::Config, /// A hook to notify the running postgres instances when we change the location @@ -846,6 +853,39 @@ impl Reconciler { } } + /// Compare the observed state snapshot from when the reconcile was created + /// with the final observed state in order to generate observed state deltas. + pub(crate) fn observed_deltas(&self) -> Vec { + let mut deltas = Vec::default(); + + for (node_id, location) in &self.observed.locations { + let previous_location = self.original_observed.locations.get(node_id); + let do_upsert = match previous_location { + // Location config changed for node + Some(prev) if location.conf != prev.conf => true, + // New location config for node + None => true, + // Location config has not changed for node + _ => false, + }; + + if do_upsert { + deltas.push(ObservedStateDelta::Upsert(Box::new(( + *node_id, + location.clone(), + )))); + } + } + + for node_id in self.original_observed.locations.keys() { + if !self.observed.locations.contains_key(node_id) { + deltas.push(ObservedStateDelta::Delete(*node_id)); + } + } + + deltas + } + /// Keep trying to notify the compute indefinitely, only dropping out if: /// - the node `origin` becomes unavailable -> Ok(()) /// - the node `origin` no longer has our tenant shard attached -> Ok(()) diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 180ab5f0c5..cc735dc27e 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -28,8 +28,8 @@ use crate::{ reconciler::{ReconcileError, ReconcileUnits, ReconcilerConfig, ReconcilerConfigBuilder}, scheduler::{MaySchedule, ScheduleContext, ScheduleError, ScheduleMode}, tenant_shard::{ - MigrateAttachment, ReconcileNeeded, ReconcilerStatus, ScheduleOptimization, - ScheduleOptimizationAction, + MigrateAttachment, ObservedStateDelta, ReconcileNeeded, ReconcilerStatus, + ScheduleOptimization, ScheduleOptimizationAction, }, }; use anyhow::Context; @@ -966,6 +966,8 @@ impl Service { let res = self.heartbeater.heartbeat(nodes).await; if let Ok(deltas) = res { + let mut to_handle = Vec::default(); + for (node_id, state) in deltas.0 { let new_availability = match state { PageserverState::Available { utilization, .. } => { @@ -997,14 +999,27 @@ impl Service { } }; + let node_lock = trace_exclusive_lock( + &self.node_op_locks, + node_id, + NodeOperations::Configure, + ) + .await; + // This is the code path for geniune availability transitions (i.e node // goes unavailable and/or comes back online). let res = self - .node_configure(node_id, Some(new_availability), None) + .node_state_configure(node_id, Some(new_availability), None, &node_lock) .await; match res { - Ok(()) => {} + Ok(transition) => { + // Keep hold of the lock until the availability transitions + // have been handled in + // [`Service::handle_node_availability_transitions`] in order avoid + // racing with [`Service::external_node_configure`]. + to_handle.push((node_id, node_lock, transition)); + } Err(ApiError::NotFound(_)) => { // This should be rare, but legitimate since the heartbeats are done // on a snapshot of the nodes. @@ -1014,13 +1029,37 @@ impl Service { // Transition to active involves reconciling: if a node responds to a heartbeat then // becomes unavailable again, we may get an error here. tracing::error!( - "Failed to update node {} after heartbeat round: {}", + "Failed to update node state {} after heartbeat round: {}", node_id, err ); } } } + + // We collected all the transitions above and now we handle them. + let res = self.handle_node_availability_transitions(to_handle).await; + if let Err(errs) = res { + for (node_id, err) in errs { + match err { + ApiError::NotFound(_) => { + // This should be rare, but legitimate since the heartbeats are done + // on a snapshot of the nodes. + tracing::info!( + "Node {} was not found after heartbeat round", + node_id + ); + } + err => { + tracing::error!( + "Failed to handle availability transition for {} after heartbeat round: {}", + node_id, + err + ); + } + } + } + } } } } @@ -1033,7 +1072,7 @@ impl Service { tenant_id=%result.tenant_shard_id.tenant_id, shard_id=%result.tenant_shard_id.shard_slug(), sequence=%result.sequence ))] - fn process_result(&self, mut result: ReconcileResult) { + fn process_result(&self, result: ReconcileResult) { let mut locked = self.inner.write().unwrap(); let (nodes, tenants, _scheduler) = locked.parts_mut(); let Some(tenant) = tenants.get_mut(&result.tenant_shard_id) else { @@ -1055,22 +1094,27 @@ impl Service { // In case a node was deleted while this reconcile is in flight, filter it out of the update we will // make to the tenant - result - .observed - .locations - .retain(|node_id, _loc| nodes.contains_key(node_id)); + let deltas = result.observed_deltas.into_iter().flat_map(|delta| { + // In case a node was deleted while this reconcile is in flight, filter it out of the update we will + // make to the tenant + let node = nodes.get(delta.node_id())?; + + if node.is_available() { + return Some(delta); + } + + // In case a node became unavailable concurrently with the reconcile, observed + // locations on it are now uncertain. By convention, set them to None in order + // for them to get refreshed when the node comes back online. + Some(ObservedStateDelta::Upsert(Box::new(( + node.get_id(), + ObservedStateLocation { conf: None }, + )))) + }); match result.result { Ok(()) => { - for (node_id, loc) in &result.observed.locations { - if let Some(conf) = &loc.conf { - tracing::info!("Updating observed location {}: {:?}", node_id, conf); - } else { - tracing::info!("Setting observed location {} to None", node_id,) - } - } - - tenant.observed = result.observed; + tenant.apply_observed_deltas(deltas); tenant.waiter.advance(result.sequence); } Err(e) => { @@ -1092,9 +1136,10 @@ impl Service { // so that waiters will see the correct error after waiting. tenant.set_last_error(result.sequence, e); - for (node_id, o) in result.observed.locations { - tenant.observed.locations.insert(node_id, o); - } + // Skip deletions on reconcile failures + let upsert_deltas = + deltas.filter(|delta| matches!(delta, ObservedStateDelta::Upsert(_))); + tenant.apply_observed_deltas(upsert_deltas); } } @@ -5299,15 +5344,17 @@ impl Service { Ok(()) } - pub(crate) async fn node_configure( + /// Configure in-memory and persistent state of a node as requested + /// + /// Note that this function does not trigger any immediate side effects in response + /// to the changes. That part is handled by [`Self::handle_node_availability_transition`]. + async fn node_state_configure( &self, node_id: NodeId, availability: Option, scheduling: Option, - ) -> Result<(), ApiError> { - let _node_lock = - trace_exclusive_lock(&self.node_op_locks, node_id, NodeOperations::Configure).await; - + node_lock: &TracingExclusiveGuard, + ) -> Result { if let Some(scheduling) = scheduling { // Scheduling is a persistent part of Node: we must write updates to the database before // applying them in memory @@ -5336,7 +5383,7 @@ impl Service { }; if matches!(availability_transition, AvailabilityTransition::ToActive) { - self.node_activate_reconcile(activate_node, &_node_lock) + self.node_activate_reconcile(activate_node, node_lock) .await?; } availability_transition @@ -5346,7 +5393,7 @@ impl Service { // Apply changes from the request to our in-memory state for the Node let mut locked = self.inner.write().unwrap(); - let (nodes, tenants, scheduler) = locked.parts_mut(); + let (nodes, _tenants, scheduler) = locked.parts_mut(); let mut new_nodes = (**nodes).clone(); @@ -5356,8 +5403,8 @@ impl Service { )); }; - if let Some(availability) = availability.as_ref() { - node.set_availability(availability.clone()); + if let Some(availability) = availability { + node.set_availability(availability); } if let Some(scheduling) = scheduling { @@ -5368,11 +5415,30 @@ impl Service { scheduler.node_upsert(node); let new_nodes = Arc::new(new_nodes); + locked.nodes = new_nodes; + Ok(availability_transition) + } + + /// Handle availability transition of one node + /// + /// Note that you should first call [`Self::node_state_configure`] to update + /// the in-memory state referencing that node. If you need to handle more than one transition + /// consider using [`Self::handle_node_availability_transitions`]. + async fn handle_node_availability_transition( + &self, + node_id: NodeId, + transition: AvailabilityTransition, + _node_lock: &TracingExclusiveGuard, + ) -> Result<(), ApiError> { // Modify scheduling state for any Tenants that are affected by a change in the node's availability state. - match availability_transition { + match transition { AvailabilityTransition::ToOffline => { tracing::info!("Node {} transition to offline", node_id); + + let mut locked = self.inner.write().unwrap(); + let (nodes, tenants, scheduler) = locked.parts_mut(); + let mut tenants_affected: usize = 0; for (tenant_shard_id, tenant_shard) in tenants { @@ -5382,14 +5448,14 @@ impl Service { observed_loc.conf = None; } - if new_nodes.len() == 1 { + if nodes.len() == 1 { // Special case for single-node cluster: there is no point trying to reschedule // any tenant shards: avoid doing so, in order to avoid spewing warnings about // failures to schedule them. continue; } - if !new_nodes + if !nodes .values() .any(|n| matches!(n.may_schedule(), MaySchedule::Yes(_))) { @@ -5415,10 +5481,7 @@ impl Service { tracing::warn!(%tenant_shard_id, "Scheduling error when marking pageserver {} offline: {e}", node_id); } Ok(()) => { - if self - .maybe_reconcile_shard(tenant_shard, &new_nodes) - .is_some() - { + if self.maybe_reconcile_shard(tenant_shard, nodes).is_some() { tenants_affected += 1; }; } @@ -5433,9 +5496,13 @@ impl Service { } AvailabilityTransition::ToActive => { tracing::info!("Node {} transition to active", node_id); + + let mut locked = self.inner.write().unwrap(); + let (nodes, tenants, _scheduler) = locked.parts_mut(); + // When a node comes back online, we must reconcile any tenant that has a None observed // location on the node. - for tenant_shard in locked.tenants.values_mut() { + for tenant_shard in tenants.values_mut() { // If a reconciliation is already in progress, rely on the previous scheduling // decision and skip triggering a new reconciliation. if tenant_shard.reconciler.is_some() { @@ -5444,7 +5511,7 @@ impl Service { if let Some(observed_loc) = tenant_shard.observed.locations.get_mut(&node_id) { if observed_loc.conf.is_none() { - self.maybe_reconcile_shard(tenant_shard, &new_nodes); + self.maybe_reconcile_shard(tenant_shard, nodes); } } } @@ -5465,11 +5532,54 @@ impl Service { } } - locked.nodes = new_nodes; - Ok(()) } + /// Handle availability transition for multiple nodes + /// + /// Note that you should first call [`Self::node_state_configure`] for + /// all nodes being handled here for the handling to use fresh in-memory state. + async fn handle_node_availability_transitions( + &self, + transitions: Vec<( + NodeId, + TracingExclusiveGuard, + AvailabilityTransition, + )>, + ) -> Result<(), Vec<(NodeId, ApiError)>> { + let mut errors = Vec::default(); + for (node_id, node_lock, transition) in transitions { + let res = self + .handle_node_availability_transition(node_id, transition, &node_lock) + .await; + if let Err(err) = res { + errors.push((node_id, err)); + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } + } + + pub(crate) async fn node_configure( + &self, + node_id: NodeId, + availability: Option, + scheduling: Option, + ) -> Result<(), ApiError> { + let node_lock = + trace_exclusive_lock(&self.node_op_locks, node_id, NodeOperations::Configure).await; + + let transition = self + .node_state_configure(node_id, availability, scheduling, &node_lock) + .await?; + self.handle_node_availability_transition(node_id, transition, &node_lock) + .await + } + /// Wrapper around [`Self::node_configure`] which only allows changes while there is no ongoing /// operation for HTTP api. pub(crate) async fn external_node_configure( diff --git a/storage_controller/src/tenant_shard.rs b/storage_controller/src/tenant_shard.rs index 2e85580e08..8a7ff866e6 100644 --- a/storage_controller/src/tenant_shard.rs +++ b/storage_controller/src/tenant_shard.rs @@ -425,6 +425,22 @@ pub(crate) enum ReconcileNeeded { Yes, } +/// Pending modification to the observed state of a tenant shard. +/// Produced by [`Reconciler::observed_deltas`] and applied in [`crate::service::Service::process_result`]. +pub(crate) enum ObservedStateDelta { + Upsert(Box<(NodeId, ObservedStateLocation)>), + Delete(NodeId), +} + +impl ObservedStateDelta { + pub(crate) fn node_id(&self) -> &NodeId { + match self { + Self::Upsert(up) => &up.0, + Self::Delete(nid) => nid, + } + } +} + /// When a reconcile task completes, it sends this result object /// to be applied to the primary TenantShard. pub(crate) struct ReconcileResult { @@ -437,7 +453,7 @@ pub(crate) struct ReconcileResult { pub(crate) tenant_shard_id: TenantShardId, pub(crate) generation: Option, - pub(crate) observed: ObservedState, + pub(crate) observed_deltas: Vec, /// Set [`TenantShard::pending_compute_notification`] from this flag pub(crate) pending_compute_notification: bool, @@ -1123,7 +1139,7 @@ impl TenantShard { result, tenant_shard_id: reconciler.tenant_shard_id, generation: reconciler.generation, - observed: reconciler.observed, + observed_deltas: reconciler.observed_deltas(), pending_compute_notification: reconciler.compute_notify_failure, } } @@ -1177,6 +1193,7 @@ impl TenantShard { reconciler_config, config: self.config.clone(), observed: self.observed.clone(), + original_observed: self.observed.clone(), compute_hook: compute_hook.clone(), service_config: service_config.clone(), _gate_guard: gate_guard, @@ -1437,6 +1454,62 @@ impl TenantShard { .map(|(node_id, gen)| (node_id, Generation::new(gen))) .collect() } + + /// Update the observed state of the tenant by applying incremental deltas + /// + /// Deltas are generated by reconcilers via [`Reconciler::observed_deltas`]. + /// They are then filtered in [`crate::service::Service::process_result`]. + pub(crate) fn apply_observed_deltas( + &mut self, + deltas: impl Iterator, + ) { + for delta in deltas { + match delta { + ObservedStateDelta::Upsert(ups) => { + let (node_id, loc) = *ups; + + // If the generation of the observed location in the delta is lagging + // behind the current one, then we have a race condition and cannot + // be certain about the true observed state. Set the observed state + // to None in order to reflect this. + let crnt_gen = self + .observed + .locations + .get(&node_id) + .and_then(|loc| loc.conf.as_ref()) + .and_then(|conf| conf.generation); + let new_gen = loc.conf.as_ref().and_then(|conf| conf.generation); + match (crnt_gen, new_gen) { + (Some(crnt), Some(new)) if crnt_gen > new_gen => { + tracing::warn!( + "Skipping observed state update {}: {:?} and using None due to stale generation ({} > {})", + node_id, loc, crnt, new + ); + + self.observed + .locations + .insert(node_id, ObservedStateLocation { conf: None }); + + continue; + } + _ => {} + } + + if let Some(conf) = &loc.conf { + tracing::info!("Updating observed location {}: {:?}", node_id, conf); + } else { + tracing::info!("Setting observed location {} to None", node_id,) + } + + self.observed.locations.insert(node_id, loc); + } + ObservedStateDelta::Delete(node_id) => { + tracing::info!("Deleting observed location {}", node_id); + self.observed.locations.remove(&node_id); + } + } + } + } } #[cfg(test)] diff --git a/storage_scrubber/src/scan_pageserver_metadata.rs b/storage_scrubber/src/scan_pageserver_metadata.rs index c1ea589f7f..cb3299d413 100644 --- a/storage_scrubber/src/scan_pageserver_metadata.rs +++ b/storage_scrubber/src/scan_pageserver_metadata.rs @@ -317,9 +317,8 @@ pub async fn scan_pageserver_metadata( tenant_timeline_results.push((ttid, data)); } - let tenant_id = tenant_id.expect("Must be set if results are present"); - if !tenant_timeline_results.is_empty() { + let tenant_id = tenant_id.expect("Must be set if results are present"); analyze_tenant( &remote_client, tenant_id, diff --git a/test_runner/README.md b/test_runner/README.md index d754e60d17..e087241c1f 100644 --- a/test_runner/README.md +++ b/test_runner/README.md @@ -64,10 +64,12 @@ By default performance tests are excluded. To run them explicitly pass performan Useful environment variables: `NEON_BIN`: The directory where neon binaries can be found. +`COMPATIBILITY_NEON_BIN`: The directory where the previous version of Neon binaries can be found `POSTGRES_DISTRIB_DIR`: The directory where postgres distribution can be found. Since pageserver supports several postgres versions, `POSTGRES_DISTRIB_DIR` must contain a subdirectory for each version with naming convention `v{PG_VERSION}/`. Inside that dir, a `bin/postgres` binary should be present. +`COMPATIBILITY_POSTGRES_DISTRIB_DIR`: The directory where the prevoius version of postgres distribution can be found. `DEFAULT_PG_VERSION`: The version of Postgres to use, This is used to construct full path to the postgres binaries. Format is 2-digit major version nubmer, i.e. `DEFAULT_PG_VERSION=16` @@ -294,6 +296,16 @@ def test_foobar2(neon_env_builder: NeonEnvBuilder): client.timeline_detail(tenant_id=tenant_id, timeline_id=timeline_id) ``` +All the test which rely on NeonEnvBuilder, can check the various version combinations of the components. +To do this yuo may want to add the parametrize decorator with the function fixtures.utils.allpairs_versions() +E.g. + +```python +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_something( +... +``` + For more information about pytest fixtures, see https://docs.pytest.org/en/stable/fixture.html At the end of a test, all the nodes in the environment are automatically stopped, so you diff --git a/test_runner/cloud_regress/test_cloud_regress.py b/test_runner/cloud_regress/test_cloud_regress.py index de71357232..715d4a4881 100644 --- a/test_runner/cloud_regress/test_cloud_regress.py +++ b/test_runner/cloud_regress/test_cloud_regress.py @@ -2,6 +2,8 @@ Run the regression tests on the cloud instance of Neon """ +from __future__ import annotations + from pathlib import Path from typing import Any diff --git a/test_runner/conftest.py b/test_runner/conftest.py index 996ca4d652..4a3194c691 100644 --- a/test_runner/conftest.py +++ b/test_runner/conftest.py @@ -1,9 +1,12 @@ +from __future__ import annotations + pytest_plugins = ( "fixtures.pg_version", "fixtures.parametrize", "fixtures.httpserver", "fixtures.compute_reconfigure", "fixtures.storage_controller_proxy", + "fixtures.paths", "fixtures.neon_fixtures", "fixtures.benchmark_fixture", "fixtures.pg_stats", diff --git a/test_runner/fixtures/__init__.py b/test_runner/fixtures/__init__.py index e69de29bb2..9d48db4f9f 100644 --- a/test_runner/fixtures/__init__.py +++ b/test_runner/fixtures/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/test_runner/fixtures/benchmark_fixture.py b/test_runner/fixtures/benchmark_fixture.py index 0c36cd6ef7..74fe39ef53 100644 --- a/test_runner/fixtures/benchmark_fixture.py +++ b/test_runner/fixtures/benchmark_fixture.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import calendar import dataclasses import enum @@ -8,9 +10,7 @@ import timeit from contextlib import contextmanager from datetime import datetime from pathlib import Path - -# Type-related stuff -from typing import Callable, ClassVar, Dict, Iterator, Optional +from typing import TYPE_CHECKING import allure import pytest @@ -23,6 +23,11 @@ from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import NeonPageserver +if TYPE_CHECKING: + from collections.abc import Iterator, Mapping + from typing import Callable, Optional + + """ This file contains fixtures for micro-benchmarks. @@ -136,20 +141,30 @@ class PgBenchRunResult: ) +# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171 +# +# This used to be a class variable on PgBenchInitResult. However later versions +# of Python complain: +# +# ValueError: mutable default for field EXTRACTORS is not allowed: use default_factory +# +# When you do what the error tells you to do, it seems to fail our Python 3.9 +# test environment. So let's just move it to a private module constant, and move +# on. +_PGBENCH_INIT_EXTRACTORS: Mapping[str, re.Pattern[str]] = { + "drop_tables": re.compile(r"drop tables (\d+\.\d+) s"), + "create_tables": re.compile(r"create tables (\d+\.\d+) s"), + "client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"), + "server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"), + "vacuum": re.compile(r"vacuum (\d+\.\d+) s"), + "primary_keys": re.compile(r"primary keys (\d+\.\d+) s"), + "foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"), + "total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench +} + + @dataclasses.dataclass class PgBenchInitResult: - # Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171 - EXTRACTORS: ClassVar[Dict[str, re.Pattern]] = { # type: ignore[type-arg] - "drop_tables": re.compile(r"drop tables (\d+\.\d+) s"), - "create_tables": re.compile(r"create tables (\d+\.\d+) s"), - "client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"), - "server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"), - "vacuum": re.compile(r"vacuum (\d+\.\d+) s"), - "primary_keys": re.compile(r"primary keys (\d+\.\d+) s"), - "foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"), - "total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench - } - total: Optional[float] drop_tables: Optional[float] create_tables: Optional[float] @@ -175,10 +190,10 @@ class PgBenchInitResult: last_line = stderr.splitlines()[-1] - timings: Dict[str, Optional[float]] = {} + timings: dict[str, Optional[float]] = {} last_line_items = re.split(r"\(|\)|,", last_line) for item in last_line_items: - for key, regex in cls.EXTRACTORS.items(): + for key, regex in _PGBENCH_INIT_EXTRACTORS.items(): if (m := regex.match(item.strip())) is not None: if key in timings: raise RuntimeError( @@ -385,7 +400,7 @@ class NeonBenchmarker: self, pageserver: NeonPageserver, metric_name: str, - label_filters: Optional[Dict[str, str]] = None, + label_filters: Optional[dict[str, str]] = None, ) -> int: """Fetch the value of given int counter from pageserver metrics.""" all_metrics = pageserver.http_client().get_metrics() diff --git a/test_runner/fixtures/common_types.py b/test_runner/fixtures/common_types.py index d8390138c9..0ea7148f50 100644 --- a/test_runner/fixtures/common_types.py +++ b/test_runner/fixtures/common_types.py @@ -1,10 +1,18 @@ +from __future__ import annotations + import random from dataclasses import dataclass from enum import Enum from functools import total_ordering -from typing import Any, Dict, Type, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar + +from typing_extensions import override + +if TYPE_CHECKING: + from typing import Any, Union + + T = TypeVar("T", bound="Id") -T = TypeVar("T", bound="Id") DEFAULT_WAL_SEG_SIZE = 16 * 1024 * 1024 @@ -25,38 +33,41 @@ class Lsn: self.lsn_int = (int(left, 16) << 32) + int(right, 16) assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF + @override def __str__(self) -> str: """Convert lsn from int to standard hex notation.""" return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}" + @override def __repr__(self) -> str: return f'Lsn("{str(self)}")' def __int__(self) -> int: return self.lsn_int - def __lt__(self, other: Any) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int < other.lsn_int - def __gt__(self, other: Any) -> bool: + def __gt__(self, other: object) -> bool: if not isinstance(other, Lsn): raise NotImplementedError return self.lsn_int > other.lsn_int - def __eq__(self, other: Any) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int == other.lsn_int # Returns the difference between two Lsns, in bytes - def __sub__(self, other: Any) -> int: + def __sub__(self, other: object) -> int: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int - other.lsn_int - def __add__(self, other: Union[int, "Lsn"]) -> "Lsn": + def __add__(self, other: Union[int, Lsn]) -> Lsn: if isinstance(other, int): return Lsn(self.lsn_int + other) elif isinstance(other, Lsn): @@ -64,13 +75,14 @@ class Lsn: else: raise NotImplementedError + @override def __hash__(self) -> int: return hash(self.lsn_int) def as_int(self) -> int: return self.lsn_int - def segment_lsn(self, seg_sz: int = DEFAULT_WAL_SEG_SIZE) -> "Lsn": + def segment_lsn(self, seg_sz: int = DEFAULT_WAL_SEG_SIZE) -> Lsn: return Lsn(self.lsn_int - (self.lsn_int % seg_sz)) def segno(self, seg_sz: int = DEFAULT_WAL_SEG_SIZE) -> int: @@ -110,48 +122,57 @@ class Id: self.id = bytearray.fromhex(x) assert len(self.id) == 16 + @override def __str__(self) -> str: return self.id.hex() - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self.id < other.id - def __eq__(self, other) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self.id == other.id + @override def __hash__(self) -> int: return hash(str(self.id)) @classmethod - def generate(cls: Type[T]) -> T: + def generate(cls: type[T]) -> T: """Generate a random ID""" return cls(random.randbytes(16).hex()) class TenantId(Id): + @override def __repr__(self) -> str: return f'`TenantId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() class NodeId(Id): + @override def __repr__(self) -> str: return f'`NodeId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() class TimelineId(Id): + @override def __repr__(self) -> str: return f'TimelineId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() @@ -162,7 +183,7 @@ class TenantTimelineId: timeline_id: TimelineId @classmethod - def from_json(cls, d: Dict[str, Any]) -> "TenantTimelineId": + def from_json(cls, d: dict[str, Any]) -> TenantTimelineId: return TenantTimelineId( tenant_id=TenantId(d["tenant_id"]), timeline_id=TimelineId(d["timeline_id"]), @@ -181,7 +202,7 @@ class TenantShardId: assert self.shard_number < self.shard_count or self.shard_count == 0 @classmethod - def parse(cls: Type[TTenantShardId], input) -> TTenantShardId: + def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId: if len(input) == 32: return cls( tenant_id=TenantId(input), @@ -197,6 +218,7 @@ class TenantShardId: else: raise ValueError(f"Invalid TenantShardId '{input}'") + @override def __str__(self): if self.shard_count > 0: return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}" @@ -204,22 +226,25 @@ class TenantShardId: # Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id) return str(self.tenant_id) + @override def __repr__(self): return self.__str__() def _tuple(self) -> tuple[TenantId, int, int]: return (self.tenant_id, self.shard_number, self.shard_count) - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._tuple() < other._tuple() - def __eq__(self, other) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._tuple() == other._tuple() + @override def __hash__(self) -> int: return hash(self._tuple()) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index fb9c2d2b86..2195ae8225 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -1,14 +1,18 @@ +from __future__ import annotations + import os import time from abc import ABC, abstractmethod +from collections.abc import Iterator from contextlib import _GeneratorContextManager, contextmanager # Type-related stuff from pathlib import Path -from typing import Dict, Iterator, List +from typing import TYPE_CHECKING import pytest from _pytest.fixtures import FixtureRequest +from typing_extensions import override from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.log_helper import log @@ -22,6 +26,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pg_stats import PgStatTable +if TYPE_CHECKING: + from collections.abc import Iterator + class PgCompare(ABC): """Common interface of all postgres implementations, useful for benchmarks. @@ -63,16 +70,16 @@ class PgCompare(ABC): @contextmanager @abstractmethod - def record_pageserver_writes(self, out_name): + def record_pageserver_writes(self, out_name: str): pass @contextmanager @abstractmethod - def record_duration(self, out_name): + def record_duration(self, out_name: str): pass @contextmanager - def record_pg_stats(self, pg_stats: List[PgStatTable]) -> Iterator[None]: + def record_pg_stats(self, pg_stats: list[PgStatTable]) -> Iterator[None]: init_data = self._retrieve_pg_stats(pg_stats) yield @@ -82,8 +89,8 @@ class PgCompare(ABC): for k in set(init_data) & set(data): self.zenbenchmark.record(k, data[k] - init_data[k], "", MetricReport.HIGHER_IS_BETTER) - def _retrieve_pg_stats(self, pg_stats: List[PgStatTable]) -> Dict[str, int]: - results: Dict[str, int] = {} + def _retrieve_pg_stats(self, pg_stats: list[PgStatTable]) -> dict[str, int]: + results: dict[str, int] = {} with self.pg.connect().cursor() as cur: for pg_stat in pg_stats: @@ -120,28 +127,34 @@ class NeonCompare(PgCompare): self._pg = self.env.endpoints.create_start("main", "main", self.tenant) @property + @override def pg(self) -> PgProtocol: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg_bin + @override def flush(self, compact: bool = True, gc: bool = True): wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline) self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact) if gc: self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0) + @override def compact(self): self.pageserver_http_client.timeline_compact( self.tenant, self.timeline, wait_until_uploaded=True ) + @override def report_peak_memory_use(self): self.zenbenchmark.record( "peak_mem", @@ -150,6 +163,7 @@ class NeonCompare(PgCompare): report=MetricReport.LOWER_IS_BETTER, ) + @override def report_size(self): timeline_size = self.zenbenchmark.get_timeline_size( self.env.repo_dir, self.tenant, self.timeline @@ -183,9 +197,11 @@ class NeonCompare(PgCompare): "num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER ) + @override def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name) + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @@ -209,26 +225,33 @@ class VanillaCompare(PgCompare): self.cur = self.conn.cursor() @property + @override def pg(self) -> VanillaPostgres: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg.pg_bin + @override def flush(self, compact: bool = False, gc: bool = False): self.cur.execute("checkpoint") + @override def compact(self): pass + @override def report_peak_memory_use(self): pass # TODO find something + @override def report_size(self): data_size = self.pg.get_subdir_size(Path("base")) self.zenbenchmark.record( @@ -243,6 +266,7 @@ class VanillaCompare(PgCompare): def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @@ -259,28 +283,35 @@ class RemoteCompare(PgCompare): self.cur = self.conn.cursor() @property + @override def pg(self) -> PgProtocol: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg.pg_bin - def flush(self): + @override + def flush(self, compact: bool = False, gc: bool = False): # TODO: flush the remote pageserver pass + @override def compact(self): pass + @override def report_peak_memory_use(self): # TODO: get memory usage from remote pageserver pass + @override def report_size(self): # TODO: get storage size from remote pageserver pass @@ -289,6 +320,7 @@ class RemoteCompare(PgCompare): def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) diff --git a/test_runner/fixtures/compute_reconfigure.py b/test_runner/fixtures/compute_reconfigure.py index 66fc35b6aa..6354b7f833 100644 --- a/test_runner/fixtures/compute_reconfigure.py +++ b/test_runner/fixtures/compute_reconfigure.py @@ -1,25 +1,31 @@ +from __future__ import annotations + import concurrent.futures -from typing import Any +from typing import TYPE_CHECKING import pytest +from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response from fixtures.common_types import TenantId from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Any, Callable, Optional + class ComputeReconfigure: - def __init__(self, server): + def __init__(self, server: HTTPServer): self.server = server self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach" - self.workloads = {} - self.on_notify = None + self.workloads: dict[TenantId, Any] = {} + self.on_notify: Optional[Callable[[Any], None]] = None - def register_workload(self, workload): + def register_workload(self, workload: Any): self.workloads[workload.tenant_id] = workload - def register_on_notify(self, fn): + def register_on_notify(self, fn: Optional[Callable[[Any], None]]): """ Add some extra work during a notification, like sleeping to slow things down, or logging what was notified. @@ -28,7 +34,7 @@ class ComputeReconfigure: @pytest.fixture(scope="function") -def compute_reconfigure_listener(make_httpserver): +def compute_reconfigure_listener(make_httpserver: HTTPServer): """ This fixture exposes an HTTP listener for the storage controller to submit compute notifications to us, instead of updating neon_local endpoints itself. @@ -46,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver): # accept a healthy rate of calls into notify-attach. reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1) - def handler(request: Request): + def handler(request: Request) -> Response: assert request.json is not None body: dict[str, Any] = request.json log.info(f"notify-attach request: {body}") diff --git a/test_runner/fixtures/endpoint/__init__.py b/test_runner/fixtures/endpoint/__init__.py index e69de29bb2..9d48db4f9f 100644 --- a/test_runner/fixtures/endpoint/__init__.py +++ b/test_runner/fixtures/endpoint/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index 42f0539c19..26895df8a6 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import requests from requests.adapters import HTTPAdapter @@ -21,3 +23,8 @@ class EndpointHttpClient(requests.Session): res = self.get(f"http://localhost:{self.port}/database_schema?database={database}") res.raise_for_status() return res.text + + def installed_extensions(self): + res = self.get(f"http://localhost:{self.port}/installed_extensions") + res.raise_for_status() + return res.json() diff --git a/test_runner/fixtures/flaky.py b/test_runner/fixtures/flaky.py index d13f3318b0..01634a29c5 100644 --- a/test_runner/fixtures/flaky.py +++ b/test_runner/fixtures/flaky.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import json +from collections.abc import MutableMapping from pathlib import Path -from typing import Any, List, MutableMapping, cast +from typing import TYPE_CHECKING, cast import pytest from _pytest.config import Config @@ -10,6 +13,11 @@ from allure_pytest.utils import allure_name, allure_suite_labels from fixtures.log_helper import log +if TYPE_CHECKING: + from collections.abc import MutableMapping + from typing import Any + + """ The plugin reruns flaky tests. It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py` @@ -27,7 +35,7 @@ def pytest_addoption(parser: Parser): ) -def pytest_collection_modifyitems(config: Config, items: List[pytest.Item]): +def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]): if not config.getoption("--flaky-tests-json"): return @@ -66,5 +74,5 @@ def pytest_collection_modifyitems(config: Config, items: List[pytest.Item]): # - [2] https://github.com/pytest-dev/pytest-timeout/issues/142 timeout_marker = item.get_closest_marker("timeout") if timeout_marker is not None: - kwargs = cast(MutableMapping[str, Any], timeout_marker.kwargs) + kwargs = cast("MutableMapping[str, Any]", timeout_marker.kwargs) kwargs["func_only"] = True diff --git a/test_runner/fixtures/httpserver.py b/test_runner/fixtures/httpserver.py index a321d59266..f653fd804c 100644 --- a/test_runner/fixtures/httpserver.py +++ b/test_runner/fixtures/httpserver.py @@ -1,8 +1,15 @@ -from typing import Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING import pytest from pytest_httpserver import HTTPServer +if TYPE_CHECKING: + from collections.abc import Iterator + + from fixtures.port_distributor import PortDistributor + # TODO: mypy fails with: # Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined] # from fixtures.neon_fixtures import PortDistributor @@ -17,7 +24,7 @@ def httpserver_ssl_context(): @pytest.fixture(scope="function") -def make_httpserver(httpserver_listen_address, httpserver_ssl_context): +def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]: host, port = httpserver_listen_address if not host: host = HTTPServer.DEFAULT_LISTEN_HOST @@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context): @pytest.fixture(scope="function") -def httpserver(make_httpserver): +def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]: server = make_httpserver yield server server.clear() @pytest.fixture(scope="function") -def httpserver_listen_address(port_distributor) -> Tuple[str, int]: +def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]: port = port_distributor.get_port() return ("localhost", port) diff --git a/test_runner/fixtures/log_helper.py b/test_runner/fixtures/log_helper.py index 17f2402391..ebf5c8d803 100644 --- a/test_runner/fixtures/log_helper.py +++ b/test_runner/fixtures/log_helper.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import logging.config @@ -29,7 +31,7 @@ LOGGING = { } -def getLogger(name="root") -> logging.Logger: +def getLogger(name: str = "root") -> logging.Logger: """Method to get logger for tests. Should be used to get correctly initialized logger.""" diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index 005dc6cb0d..e056ea77d4 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -1,23 +1,28 @@ +from __future__ import annotations + from collections import defaultdict -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING from prometheus_client.parser import text_string_to_metric_families from prometheus_client.samples import Sample from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Optional + class Metrics: - metrics: Dict[str, List[Sample]] + metrics: dict[str, list[Sample]] name: str def __init__(self, name: str = ""): self.metrics = defaultdict(list) self.name = name - def query_all(self, name: str, filter: Optional[Dict[str, str]] = None) -> List[Sample]: + def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]: filter = filter or {} - res = [] + res: list[Sample] = [] for sample in self.metrics[name]: try: @@ -27,7 +32,7 @@ class Metrics: pass return res - def query_one(self, name: str, filter: Optional[Dict[str, str]] = None) -> Sample: + def query_one(self, name: str, filter: Optional[dict[str, str]] = None) -> Sample: res = self.query_all(name, filter or {}) assert len(res) == 1, f"expected single sample for {name} {filter}, found {res}" return res[0] @@ -43,7 +48,7 @@ class MetricsGetter: raise NotImplementedError() def get_metric_value( - self, name: str, filter: Optional[Dict[str, str]] = None + self, name: str, filter: Optional[dict[str, str]] = None ) -> Optional[float]: metrics = self.get_metrics() results = metrics.query_all(name, filter=filter) @@ -54,8 +59,8 @@ class MetricsGetter: return results[0].value def get_metrics_values( - self, names: list[str], filter: Optional[Dict[str, str]] = None, absence_ok=False - ) -> Dict[str, float]: + self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False + ) -> dict[str, float]: """ When fetching multiple named metrics, it is more efficient to use this than to call `get_metric_value` repeatedly. @@ -97,7 +102,7 @@ def parse_metrics(text: str, name: str = "") -> Metrics: return metrics -def histogram(prefix_without_trailing_underscore: str) -> List[str]: +def histogram(prefix_without_trailing_underscore: str) -> list[str]: assert not prefix_without_trailing_underscore.endswith("_") return [f"{prefix_without_trailing_underscore}_{x}" for x in ["bucket", "count", "sum"]] @@ -107,7 +112,7 @@ def counter(name: str) -> str: return f"{name}_total" -PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS: Tuple[str, ...] = ( +PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS: tuple[str, ...] = ( "pageserver_remote_timeline_client_calls_started_total", "pageserver_remote_timeline_client_calls_finished_total", "pageserver_remote_physical_size", @@ -115,7 +120,7 @@ PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS: Tuple[str, ...] = ( "pageserver_remote_timeline_client_bytes_finished_total", ) -PAGESERVER_GLOBAL_METRICS: Tuple[str, ...] = ( +PAGESERVER_GLOBAL_METRICS: tuple[str, ...] = ( "pageserver_storage_operations_seconds_global_count", "pageserver_storage_operations_seconds_global_sum", "pageserver_storage_operations_seconds_global_bucket", @@ -147,7 +152,7 @@ PAGESERVER_GLOBAL_METRICS: Tuple[str, ...] = ( counter("pageserver_tenant_throttling_count_global"), ) -PAGESERVER_PER_TENANT_METRICS: Tuple[str, ...] = ( +PAGESERVER_PER_TENANT_METRICS: tuple[str, ...] = ( "pageserver_current_logical_size", "pageserver_resident_physical_size", "pageserver_io_operations_bytes_total", diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 0636cfad06..5934baccff 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -6,12 +6,12 @@ from typing import TYPE_CHECKING, cast import requests if TYPE_CHECKING: - from typing import Any, Dict, Literal, Optional, Union + from typing import Any, Literal, Optional from fixtures.pg_version import PgVersion -def connection_parameters_to_env(params: Dict[str, str]) -> Dict[str, str]: +def connection_parameters_to_env(params: dict[str, str]) -> dict[str, str]: return { "PGHOST": params["host"], "PGDATABASE": params["database"], @@ -25,9 +25,7 @@ class NeonAPI: self.__neon_api_key = neon_api_key self.__neon_api_base_url = neon_api_base_url.strip("/") - def __request( - self, method: Union[str, bytes], endpoint: str, **kwargs: Any - ) -> requests.Response: + def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response: if "headers" not in kwargs: kwargs["headers"] = {} kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}" @@ -41,8 +39,8 @@ class NeonAPI: branch_name: Optional[str] = None, branch_role_name: Optional[str] = None, branch_database_name: Optional[str] = None, - ) -> Dict[str, Any]: - data: Dict[str, Any] = { + ) -> dict[str, Any]: + data: dict[str, Any] = { "project": { "branch": {}, }, @@ -70,9 +68,9 @@ class NeonAPI: assert resp.status_code == 201 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) - def get_project_details(self, project_id: str) -> Dict[str, Any]: + def get_project_details(self, project_id: str) -> dict[str, Any]: resp = self.__request( "GET", f"/projects/{project_id}", @@ -82,12 +80,12 @@ class NeonAPI: }, ) assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) def delete_project( self, project_id: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: resp = self.__request( "DELETE", f"/projects/{project_id}", @@ -99,13 +97,13 @@ class NeonAPI: assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) def start_endpoint( self, project_id: str, endpoint_id: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: resp = self.__request( "POST", f"/projects/{project_id}/endpoints/{endpoint_id}/start", @@ -116,13 +114,13 @@ class NeonAPI: assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) def suspend_endpoint( self, project_id: str, endpoint_id: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: resp = self.__request( "POST", f"/projects/{project_id}/endpoints/{endpoint_id}/suspend", @@ -133,13 +131,13 @@ class NeonAPI: assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) def restart_endpoint( self, project_id: str, endpoint_id: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: resp = self.__request( "POST", f"/projects/{project_id}/endpoints/{endpoint_id}/restart", @@ -150,16 +148,16 @@ class NeonAPI: assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) def create_endpoint( self, project_id: str, branch_id: str, endpoint_type: Literal["read_write", "read_only"], - settings: Dict[str, Any], - ) -> Dict[str, Any]: - data: Dict[str, Any] = { + settings: dict[str, Any], + ) -> dict[str, Any]: + data: dict[str, Any] = { "endpoint": { "branch_id": branch_id, }, @@ -182,7 +180,7 @@ class NeonAPI: assert resp.status_code == 201 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) def get_connection_uri( self, @@ -192,7 +190,7 @@ class NeonAPI: database_name: str = "neondb", role_name: str = "neondb_owner", pooled: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: resp = self.__request( "GET", f"/projects/{project_id}/connection_uri", @@ -210,9 +208,9 @@ class NeonAPI: assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) - def get_branches(self, project_id: str) -> Dict[str, Any]: + def get_branches(self, project_id: str) -> dict[str, Any]: resp = self.__request( "GET", f"/projects/{project_id}/branches", @@ -223,9 +221,9 @@ class NeonAPI: assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) - def get_endpoints(self, project_id: str) -> Dict[str, Any]: + def get_endpoints(self, project_id: str) -> dict[str, Any]: resp = self.__request( "GET", f"/projects/{project_id}/endpoints", @@ -236,9 +234,9 @@ class NeonAPI: assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) - def get_operations(self, project_id: str) -> Dict[str, Any]: + def get_operations(self, project_id: str) -> dict[str, Any]: resp = self.__request( "GET", f"/projects/{project_id}/operations", @@ -250,7 +248,7 @@ class NeonAPI: assert resp.status_code == 200 - return cast("Dict[str, Any]", resp.json()) + return cast("dict[str, Any]", resp.json()) def wait_for_operation_to_finish(self, project_id: str): has_running = True diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py index c27d22620e..0d3dcd1671 100644 --- a/test_runner/fixtures/neon_cli.py +++ b/test_runner/fixtures/neon_cli.py @@ -9,15 +9,7 @@ import tempfile import textwrap from itertools import chain, product from pathlib import Path -from typing import ( - Any, - Dict, - List, - Optional, - Tuple, - TypeVar, - cast, -) +from typing import TYPE_CHECKING, cast import toml @@ -27,7 +19,15 @@ from fixtures.pageserver.common_types import IndexPartDump from fixtures.pg_version import PgVersion from fixtures.utils import AuxFileStore -T = TypeVar("T") +if TYPE_CHECKING: + from typing import ( + Any, + Optional, + TypeVar, + cast, + ) + + T = TypeVar("T") class AbstractNeonCli(abc.ABC): @@ -37,7 +37,7 @@ class AbstractNeonCli(abc.ABC): Do not use directly, use specific subclasses instead. """ - def __init__(self, extra_env: Optional[Dict[str, str]], binpath: Path): + def __init__(self, extra_env: Optional[dict[str, str]], binpath: Path): self.extra_env = extra_env self.binpath = binpath @@ -45,11 +45,11 @@ class AbstractNeonCli(abc.ABC): def raw_cli( self, - arguments: List[str], - extra_env_vars: Optional[Dict[str, str]] = None, + arguments: list[str], + extra_env_vars: Optional[dict[str, str]] = None, check_return_code=True, timeout=None, - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: """ Run the command with the specified arguments. @@ -92,9 +92,8 @@ class AbstractNeonCli(abc.ABC): args, env=env_vars, check=False, - universal_newlines=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + text=True, + capture_output=True, timeout=timeout, ) except subprocess.TimeoutExpired as e: @@ -118,7 +117,7 @@ class AbstractNeonCli(abc.ABC): if len(lines) < 2: log.debug(f"Run {res.args} success: {stripped}") else: - log.debug("Run %s success:\n%s" % (res.args, textwrap.indent(stripped, indent))) + log.debug("Run %s success:\n%s", res.args, textwrap.indent(stripped, indent)) elif check_return_code: # this way command output will be in recorded and shown in CI in failure message indent = indent * 2 @@ -175,7 +174,7 @@ class NeonLocalCli(AbstractNeonCli): def __init__( self, - extra_env: Optional[Dict[str, str]], + extra_env: Optional[dict[str, str]], binpath: Path, repo_dir: Path, pg_distrib_dir: Path, @@ -197,7 +196,7 @@ class NeonLocalCli(AbstractNeonCli): tenant_id: TenantId, timeline_id: TimelineId, pg_version: PgVersion, - conf: Optional[Dict[str, Any]] = None, + conf: Optional[dict[str, Any]] = None, shard_count: Optional[int] = None, shard_stripe_size: Optional[int] = None, placement_policy: Optional[str] = None, @@ -258,7 +257,7 @@ class NeonLocalCli(AbstractNeonCli): res = self.raw_cli(["tenant", "set-default", "--tenant-id", str(tenant_id)]) res.check_returncode() - def tenant_config(self, tenant_id: TenantId, conf: Dict[str, str]): + def tenant_config(self, tenant_id: TenantId, conf: dict[str, str]): """ Update tenant config. """ @@ -274,7 +273,7 @@ class NeonLocalCli(AbstractNeonCli): res = self.raw_cli(args) res.check_returncode() - def tenant_list(self) -> "subprocess.CompletedProcess[str]": + def tenant_list(self) -> subprocess.CompletedProcess[str]: res = self.raw_cli(["tenant", "list"]) res.check_returncode() return res @@ -368,7 +367,7 @@ class NeonLocalCli(AbstractNeonCli): res = self.raw_cli(cmd) res.check_returncode() - def timeline_list(self, tenant_id: TenantId) -> List[Tuple[str, TimelineId]]: + def timeline_list(self, tenant_id: TenantId) -> list[tuple[str, TimelineId]]: """ Returns a list of (branch_name, timeline_id) tuples out of parsed `neon timeline list` CLI output. """ @@ -389,9 +388,9 @@ class NeonLocalCli(AbstractNeonCli): def init( self, - init_config: Dict[str, Any], + init_config: dict[str, Any], force: Optional[str] = None, - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: with tempfile.NamedTemporaryFile(mode="w+") as init_config_tmpfile: init_config_tmpfile.write(toml.dumps(init_config)) init_config_tmpfile.flush() @@ -434,29 +433,28 @@ class NeonLocalCli(AbstractNeonCli): def pageserver_start( self, id: int, - extra_env_vars: Optional[Dict[str, str]] = None, + extra_env_vars: Optional[dict[str, str]] = None, timeout_in_seconds: Optional[int] = None, - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: start_args = ["pageserver", "start", f"--id={id}"] if timeout_in_seconds is not None: start_args.append(f"--start-timeout={timeout_in_seconds}s") return self.raw_cli(start_args, extra_env_vars=extra_env_vars) - def pageserver_stop(self, id: int, immediate=False) -> "subprocess.CompletedProcess[str]": + def pageserver_stop(self, id: int, immediate=False) -> subprocess.CompletedProcess[str]: cmd = ["pageserver", "stop", f"--id={id}"] if immediate: cmd.extend(["-m", "immediate"]) - log.info(f"Stopping pageserver with {cmd}") return self.raw_cli(cmd) def safekeeper_start( self, id: int, - extra_opts: Optional[List[str]] = None, - extra_env_vars: Optional[Dict[str, str]] = None, + extra_opts: Optional[list[str]] = None, + extra_env_vars: Optional[dict[str, str]] = None, timeout_in_seconds: Optional[int] = None, - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: if extra_opts is not None: extra_opts = [f"-e={opt}" for opt in extra_opts] else: @@ -469,7 +467,7 @@ class NeonLocalCli(AbstractNeonCli): def safekeeper_stop( self, id: Optional[int] = None, immediate=False - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: args = ["safekeeper", "stop"] if id is not None: args.append(str(id)) @@ -479,13 +477,13 @@ class NeonLocalCli(AbstractNeonCli): def storage_broker_start( self, timeout_in_seconds: Optional[int] = None - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: cmd = ["storage_broker", "start"] if timeout_in_seconds is not None: cmd.append(f"--start-timeout={timeout_in_seconds}s") return self.raw_cli(cmd) - def storage_broker_stop(self) -> "subprocess.CompletedProcess[str]": + def storage_broker_stop(self) -> subprocess.CompletedProcess[str]: cmd = ["storage_broker", "stop"] return self.raw_cli(cmd) @@ -501,7 +499,7 @@ class NeonLocalCli(AbstractNeonCli): lsn: Optional[Lsn] = None, pageserver_id: Optional[int] = None, allow_multiple=False, - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: args = [ "endpoint", "create", @@ -534,12 +532,12 @@ class NeonLocalCli(AbstractNeonCli): def endpoint_start( self, endpoint_id: str, - safekeepers: Optional[List[int]] = None, + safekeepers: Optional[list[int]] = None, remote_ext_config: Optional[str] = None, pageserver_id: Optional[int] = None, allow_multiple=False, basebackup_request_tries: Optional[int] = None, - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: args = [ "endpoint", "start", @@ -568,9 +566,9 @@ class NeonLocalCli(AbstractNeonCli): endpoint_id: str, tenant_id: Optional[TenantId] = None, pageserver_id: Optional[int] = None, - safekeepers: Optional[List[int]] = None, + safekeepers: Optional[list[int]] = None, check_return_code=True, - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: args = ["endpoint", "reconfigure", endpoint_id] if tenant_id is not None: args.extend(["--tenant-id", str(tenant_id)]) @@ -586,7 +584,7 @@ class NeonLocalCli(AbstractNeonCli): destroy=False, check_return_code=True, mode: Optional[str] = None, - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: args = [ "endpoint", "stop", @@ -602,7 +600,7 @@ class NeonLocalCli(AbstractNeonCli): def mappings_map_branch( self, name: str, tenant_id: TenantId, timeline_id: TimelineId - ) -> "subprocess.CompletedProcess[str]": + ) -> subprocess.CompletedProcess[str]: """ Map tenant id and timeline id to a neon_local branch name. They do not have to exist. Usually needed when creating branches via PageserverHttpClient and not neon_local. @@ -623,10 +621,10 @@ class NeonLocalCli(AbstractNeonCli): return self.raw_cli(args, check_return_code=True) - def start(self, check_return_code=True) -> "subprocess.CompletedProcess[str]": + def start(self, check_return_code=True) -> subprocess.CompletedProcess[str]: return self.raw_cli(["start"], check_return_code=check_return_code) - def stop(self, check_return_code=True) -> "subprocess.CompletedProcess[str]": + def stop(self, check_return_code=True) -> subprocess.CompletedProcess[str]: return self.raw_cli(["stop"], check_return_code=check_return_code) @@ -638,7 +636,7 @@ class WalCraft(AbstractNeonCli): COMMAND = "wal_craft" - def postgres_config(self) -> List[str]: + def postgres_config(self) -> list[str]: res = self.raw_cli(["print-postgres-config"]) res.check_returncode() return res.stdout.split("\n") diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index df88af88ed..7789855fe4 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -13,28 +13,15 @@ import threading import time import uuid from collections import defaultdict +from collections.abc import Iterable, Iterator from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import datetime from enum import Enum -from fcntl import LOCK_EX, LOCK_UN, flock from functools import cached_property from pathlib import Path from types import TracebackType -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, cast from urllib.parse import quote, urlparse import asyncpg @@ -71,6 +58,7 @@ from fixtures.pageserver.http import PageserverHttpClient from fixtures.pageserver.utils import ( wait_for_last_record_lsn, ) +from fixtures.paths import get_test_repo_dir, shared_snapshot_dir from fixtures.pg_version import PgVersion from fixtures.port_distributor import PortDistributor from fixtures.remote_storage import ( @@ -87,11 +75,10 @@ from fixtures.safekeeper.http import SafekeeperHttpClient from fixtures.safekeeper.utils import wait_walreceivers_absent from fixtures.utils import ( ATTACHMENT_NAME_REGEX, + COMPONENT_BINARIES, allure_add_grafana_links, - allure_attach_from_dir, assert_no_errors, get_dir_size, - get_self_dir, print_gc_result, subprocess_capture, wait_until, @@ -100,7 +87,19 @@ from fixtures.utils import AuxFileStore as AuxFileStore # reexport from .neon_api import NeonAPI, NeonApiEndpoint -T = TypeVar("T") +if TYPE_CHECKING: + from typing import ( + Any, + Callable, + Optional, + TypeVar, + Union, + ) + + from fixtures.paths import SnapshotDirLocked + + T = TypeVar("T") + """ This file contains pytest fixtures. A fixture is a test resource that can be @@ -119,67 +118,13 @@ Don't import functions from this file, or pytest will emit warnings. Instead put directly-importable functions into utils.py or another separate file. """ -Env = Dict[str, str] +Env = dict[str, str] -DEFAULT_OUTPUT_DIR: str = "test_output" DEFAULT_BRANCH_NAME: str = "main" BASE_PORT: int = 15000 -@pytest.fixture(scope="session") -def base_dir() -> Iterator[Path]: - # find the base directory (currently this is the git root) - base_dir = get_self_dir().parent.parent - log.info(f"base_dir is {base_dir}") - - yield base_dir - - -@pytest.fixture(scope="function") -def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]: - if os.getenv("REMOTE_ENV"): - # we are in remote env and do not have neon binaries locally - # this is the case for benchmarks run on self-hosted runner - return - - # Find the neon binaries. - if env_neon_bin := os.environ.get("NEON_BIN"): - binpath = Path(env_neon_bin) - else: - binpath = base_dir / "target" / build_type - log.info(f"neon_binpath is {binpath}") - - if not (binpath / "pageserver").exists(): - raise Exception(f"neon binaries not found at '{binpath}'") - - yield binpath - - -@pytest.fixture(scope="session") -def pg_distrib_dir(base_dir: Path) -> Iterator[Path]: - if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"): - distrib_dir = Path(env_postgres_bin).resolve() - else: - distrib_dir = base_dir / "pg_install" - - log.info(f"pg_distrib_dir is {distrib_dir}") - yield distrib_dir - - -@pytest.fixture(scope="session") -def top_output_dir(base_dir: Path) -> Iterator[Path]: - # Compute the top-level directory for all tests. - if env_test_output := os.environ.get("TEST_OUTPUT"): - output_dir = Path(env_test_output).resolve() - else: - output_dir = base_dir / DEFAULT_OUTPUT_DIR - output_dir.mkdir(exist_ok=True) - - log.info(f"top_output_dir is {output_dir}") - yield output_dir - - @pytest.fixture(scope="session") def neon_api_key() -> str: api_key = os.getenv("NEON_API_KEY") @@ -251,7 +196,7 @@ class PgProtocol: """ return str(make_dsn(**self.conn_options(**kwargs))) - def conn_options(self, **kwargs: Any) -> Dict[str, Any]: + def conn_options(self, **kwargs: Any) -> dict[str, Any]: """ Construct a dictionary of connection options from default values and extra parameters. An option can be dropped from the returning dictionary by None-valued extra parameter. @@ -320,7 +265,7 @@ class PgProtocol: conn_options["server_settings"] = {key: val} return await asyncpg.connect(**conn_options) - def safe_psql(self, query: str, **kwargs: Any) -> List[Tuple[Any, ...]]: + def safe_psql(self, query: str, **kwargs: Any) -> list[tuple[Any, ...]]: """ Execute query against the node and return all rows. This method passes all extra params to connstr. @@ -329,12 +274,12 @@ class PgProtocol: def safe_psql_many( self, queries: Iterable[str], log_query=True, **kwargs: Any - ) -> List[List[Tuple[Any, ...]]]: + ) -> list[list[tuple[Any, ...]]]: """ Execute queries against the node and return all rows. This method passes all extra params to connstr. """ - result: List[List[Any]] = [] + result: list[list[Any]] = [] with closing(self.connect(**kwargs)) as conn: with conn.cursor() as cur: for query in queries: @@ -372,15 +317,18 @@ class NeonEnvBuilder: run_id: uuid.UUID, mock_s3_server: MockS3Server, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, test_name: str, top_output_dir: Path, test_output_dir: Path, + combination, test_overlay_dir: Optional[Path] = None, pageserver_remote_storage: Optional[RemoteStorage] = None, # toml that will be decomposed into `--config-override` flags during `pageserver --init` - pageserver_config_override: Optional[str | Callable[[Dict[str, Any]], None]] = None, + pageserver_config_override: Optional[str | Callable[[dict[str, Any]], None]] = None, num_safekeepers: int = 1, num_pageservers: int = 1, # Use non-standard SK ids to check for various parsing bugs @@ -395,10 +343,10 @@ class NeonEnvBuilder: initial_timeline: Optional[TimelineId] = None, pageserver_virtual_file_io_engine: Optional[str] = None, pageserver_aux_file_policy: Optional[AuxFileStore] = None, - pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]] = None, + pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]] = None, safekeeper_extra_opts: Optional[list[str]] = None, storage_controller_port_override: Optional[int] = None, - pageserver_io_buffer_alignment: Optional[int] = None, + pageserver_virtual_file_io_mode: Optional[str] = None, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -430,7 +378,7 @@ class NeonEnvBuilder: self.enable_scrub_on_exit = True self.test_output_dir = test_output_dir self.test_overlay_dir = test_overlay_dir - self.overlay_mounts_created_by_us: List[Tuple[str, Path]] = [] + self.overlay_mounts_created_by_us: list[tuple[str, Path]] = [] self.config_init_force: Optional[str] = None self.top_output_dir = top_output_dir self.control_plane_compute_hook_api: Optional[str] = None @@ -439,7 +387,7 @@ class NeonEnvBuilder: self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine self.pageserver_default_tenant_config_compaction_algorithm: Optional[ - Dict[str, Any] + dict[str, Any] ] = pageserver_default_tenant_config_compaction_algorithm if self.pageserver_default_tenant_config_compaction_algorithm is not None: log.debug( @@ -452,12 +400,25 @@ class NeonEnvBuilder: self.storage_controller_port_override = storage_controller_port_override - self.pageserver_io_buffer_alignment = pageserver_io_buffer_alignment + self.pageserver_virtual_file_io_mode = pageserver_virtual_file_io_mode assert test_name.startswith( "test_" ), "Unexpectedly instantiated from outside a test function" self.test_name = test_name + self.compatibility_neon_binpath = compatibility_neon_binpath + self.compatibility_pg_distrib_dir = compatibility_pg_distrib_dir + self.version_combination = combination + self.mixdir = self.test_output_dir / "mixdir_neon" + if self.version_combination is not None: + assert ( + self.compatibility_neon_binpath is not None + ), "the environment variable COMPATIBILITY_NEON_BIN is required when using mixed versions" + assert ( + self.compatibility_pg_distrib_dir is not None + ), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required when using mixed versions" + self.mixdir.mkdir(mode=0o755, exist_ok=True) + self._mix_versions() def init_configs(self, default_remote_storage_if_missing: bool = True) -> NeonEnv: # Cannot create more than one environment from one builder @@ -469,7 +430,7 @@ class NeonEnvBuilder: def init_start( self, - initial_tenant_conf: Optional[Dict[str, Any]] = None, + initial_tenant_conf: Optional[dict[str, Any]] = None, default_remote_storage_if_missing: bool = True, initial_tenant_shard_count: Optional[int] = None, initial_tenant_shard_stripe_size: Optional[int] = None, @@ -658,6 +619,21 @@ class NeonEnvBuilder: return self.env + def _mix_versions(self): + assert self.version_combination is not None, "version combination must be set" + for component, paths in COMPONENT_BINARIES.items(): + directory = ( + self.neon_binpath + if self.version_combination[component] == "new" + else self.compatibility_neon_binpath + ) + for filename in paths: + destination = self.mixdir / filename + destination.symlink_to(directory / filename) + if self.version_combination["compute"] == "old": + self.pg_distrib_dir = self.compatibility_pg_distrib_dir + self.neon_binpath = self.mixdir + def overlay_mount(self, ident: str, srcdir: Path, dstdir: Path): """ Mount `srcdir` as an overlayfs mount at `dstdir`. @@ -824,7 +800,7 @@ class NeonEnvBuilder: overlayfs_mounts = {mountpoint for _, mountpoint in self.overlay_mounts_created_by_us} - directories_to_clean: List[Path] = [] + directories_to_clean: list[Path] = [] for test_entry in Path(self.repo_dir).glob("**/*"): if test_entry in overlayfs_mounts: continue @@ -855,12 +831,12 @@ class NeonEnvBuilder: if isinstance(x, S3Storage): x.do_cleanup() - def __enter__(self) -> "NeonEnvBuilder": + def __enter__(self) -> NeonEnvBuilder: return self def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ): @@ -971,8 +947,8 @@ class NeonEnv: self.port_distributor = config.port_distributor self.s3_mock_server = config.mock_s3_server self.endpoints = EndpointFactory(self) - self.safekeepers: List[Safekeeper] = [] - self.pageservers: List[NeonPageserver] = [] + self.safekeepers: list[Safekeeper] = [] + self.pageservers: list[NeonPageserver] = [] self.broker = NeonBroker(self) self.pageserver_remote_storage = config.pageserver_remote_storage self.safekeepers_remote_storage = config.safekeepers_remote_storage @@ -1041,10 +1017,10 @@ class NeonEnv: self.pageserver_virtual_file_io_engine = config.pageserver_virtual_file_io_engine self.pageserver_aux_file_policy = config.pageserver_aux_file_policy - self.pageserver_io_buffer_alignment = config.pageserver_io_buffer_alignment + self.pageserver_virtual_file_io_mode = config.pageserver_virtual_file_io_mode # Create the neon_local's `NeonLocalInitConf` - cfg: Dict[str, Any] = { + cfg: dict[str, Any] = { "default_tenant_id": str(self.initial_tenant), "broker": { "listen_addr": self.broker.listen_addr(), @@ -1073,7 +1049,7 @@ class NeonEnv: http=self.port_distributor.get_port(), ) - ps_cfg: Dict[str, Any] = { + ps_cfg: dict[str, Any] = { "id": ps_id, "listen_pg_addr": f"localhost:{pageserver_port.pg}", "listen_http_addr": f"localhost:{pageserver_port.http}", @@ -1105,7 +1081,8 @@ class NeonEnv: for key, value in override.items(): ps_cfg[key] = value - ps_cfg["io_buffer_alignment"] = self.pageserver_io_buffer_alignment + if self.pageserver_virtual_file_io_mode is not None: + ps_cfg["virtual_file_io_mode"] = self.pageserver_virtual_file_io_mode # Create a corresponding NeonPageserver object self.pageservers.append( @@ -1121,7 +1098,7 @@ class NeonEnv: http=self.port_distributor.get_port(), ) id = config.safekeepers_id_start + i # assign ids sequentially - sk_cfg: Dict[str, Any] = { + sk_cfg: dict[str, Any] = { "id": id, "pg_port": port.pg, "pg_tenant_only_port": port.pg_tenant_only, @@ -1286,9 +1263,8 @@ class NeonEnv: res = subprocess.run( [bin_pageserver, "--version"], check=True, - universal_newlines=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + text=True, + capture_output=True, ) return res.stdout @@ -1331,13 +1307,13 @@ class NeonEnv: self, tenant_id: Optional[TenantId] = None, timeline_id: Optional[TimelineId] = None, - conf: Optional[Dict[str, Any]] = None, + conf: Optional[dict[str, Any]] = None, shard_count: Optional[int] = None, shard_stripe_size: Optional[int] = None, placement_policy: Optional[str] = None, set_default: bool = False, aux_file_policy: Optional[AuxFileStore] = None, - ) -> Tuple[TenantId, TimelineId]: + ) -> tuple[TenantId, TimelineId]: """ Creates a new tenant, returns its id and its initial timeline's id. """ @@ -1358,7 +1334,7 @@ class NeonEnv: return tenant_id, timeline_id - def config_tenant(self, tenant_id: Optional[TenantId], conf: Dict[str, str]): + def config_tenant(self, tenant_id: Optional[TenantId], conf: dict[str, str]): """ Update tenant config. """ @@ -1406,12 +1382,14 @@ def neon_simple_env( top_output_dir: Path, test_output_dir: Path, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, pageserver_virtual_file_io_engine: str, pageserver_aux_file_policy: Optional[AuxFileStore], - pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]], - pageserver_io_buffer_alignment: Optional[int], + pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]], + pageserver_virtual_file_io_mode: Optional[str], ) -> Iterator[NeonEnv]: """ Simple Neon environment, with no authentication and no safekeepers. @@ -1421,6 +1399,11 @@ def neon_simple_env( # Create the environment in the per-test output directory repo_dir = get_test_repo_dir(request, top_output_dir) + combination = ( + request._pyfuncitem.callspec.params["combination"] + if "combination" in request._pyfuncitem.callspec.params + else None + ) with NeonEnvBuilder( top_output_dir=top_output_dir, @@ -1428,7 +1411,9 @@ def neon_simple_env( port_distributor=port_distributor, mock_s3_server=mock_s3_server, neon_binpath=neon_binpath, + compatibility_neon_binpath=compatibility_neon_binpath, pg_distrib_dir=pg_distrib_dir, + compatibility_pg_distrib_dir=compatibility_pg_distrib_dir, pg_version=pg_version, run_id=run_id, preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")), @@ -1437,7 +1422,8 @@ def neon_simple_env( pageserver_virtual_file_io_engine=pageserver_virtual_file_io_engine, pageserver_aux_file_policy=pageserver_aux_file_policy, pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm, - pageserver_io_buffer_alignment=pageserver_io_buffer_alignment, + pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode, + combination=combination, ) as builder: env = builder.init_start() @@ -1451,17 +1437,19 @@ def neon_env_builder( port_distributor: PortDistributor, mock_s3_server: MockS3Server, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, run_id: uuid.UUID, request: FixtureRequest, test_overlay_dir: Path, top_output_dir: Path, pageserver_virtual_file_io_engine: str, - pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]], + pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]], pageserver_aux_file_policy: Optional[AuxFileStore], record_property: Callable[[str, object], None], - pageserver_io_buffer_alignment: Optional[int], + pageserver_virtual_file_io_mode: Optional[str], ) -> Iterator[NeonEnvBuilder]: """ Fixture to create a Neon environment for test. @@ -1478,6 +1466,11 @@ def neon_env_builder( # Create the environment in the test-specific output dir repo_dir = os.path.join(test_output_dir, "repo") + combination = ( + request._pyfuncitem.callspec.params["combination"] + if "combination" in request._pyfuncitem.callspec.params + else None + ) # Return the builder to the caller with NeonEnvBuilder( @@ -1486,7 +1479,10 @@ def neon_env_builder( port_distributor=port_distributor, mock_s3_server=mock_s3_server, neon_binpath=neon_binpath, + compatibility_neon_binpath=compatibility_neon_binpath, pg_distrib_dir=pg_distrib_dir, + compatibility_pg_distrib_dir=compatibility_pg_distrib_dir, + combination=combination, pg_version=pg_version, run_id=run_id, preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")), @@ -1496,7 +1492,7 @@ def neon_env_builder( test_overlay_dir=test_overlay_dir, pageserver_aux_file_policy=pageserver_aux_file_policy, pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm, - pageserver_io_buffer_alignment=pageserver_io_buffer_alignment, + pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode, ) as builder: yield builder # Propogate `preserve_database_files` to make it possible to use in other fixtures, @@ -1520,7 +1516,7 @@ class LogUtils: def assert_log_contains( self, pattern: str, offset: None | LogCursor = None - ) -> Tuple[str, LogCursor]: + ) -> tuple[str, LogCursor]: """Convenient for use inside wait_until()""" res = self.log_contains(pattern, offset=offset) @@ -1529,7 +1525,7 @@ class LogUtils: def log_contains( self, pattern: str, offset: None | LogCursor = None - ) -> Optional[Tuple[str, LogCursor]]: + ) -> Optional[tuple[str, LogCursor]]: """Check that the log contains a line that matches the given regex""" logfile = self.logfile if not logfile.exists(): @@ -1610,7 +1606,7 @@ class NeonStorageController(MetricsGetter, LogUtils): self.running = True return self - def stop(self, immediate: bool = False) -> "NeonStorageController": + def stop(self, immediate: bool = False) -> NeonStorageController: if self.running: self.env.neon_cli.storage_controller_stop(immediate) self.running = False @@ -1672,7 +1668,7 @@ class NeonStorageController(MetricsGetter, LogUtils): return resp - def headers(self, scope: Optional[TokenScope]) -> Dict[str, str]: + def headers(self, scope: Optional[TokenScope]) -> dict[str, str]: headers = {} if self.auth_enabled and scope is not None: jwt_token = self.env.auth_keys.generate_token(scope=scope) @@ -1858,13 +1854,13 @@ class NeonStorageController(MetricsGetter, LogUtils): tenant_id: TenantId, shard_count: Optional[int] = None, shard_stripe_size: Optional[int] = None, - tenant_config: Optional[Dict[Any, Any]] = None, - placement_policy: Optional[Union[Dict[Any, Any] | str]] = None, + tenant_config: Optional[dict[Any, Any]] = None, + placement_policy: Optional[Union[dict[Any, Any] | str]] = None, ): """ Use this rather than pageserver_api() when you need to include shard parameters """ - body: Dict[str, Any] = {"new_tenant_id": str(tenant_id)} + body: dict[str, Any] = {"new_tenant_id": str(tenant_id)} if shard_count is not None: shard_params = {"count": shard_count} @@ -2080,8 +2076,8 @@ class NeonStorageController(MetricsGetter, LogUtils): time.sleep(backoff) - def metadata_health_update(self, healthy: List[TenantShardId], unhealthy: List[TenantShardId]): - body: Dict[str, Any] = { + def metadata_health_update(self, healthy: list[TenantShardId], unhealthy: list[TenantShardId]): + body: dict[str, Any] = { "healthy_tenant_shards": [str(t) for t in healthy], "unhealthy_tenant_shards": [str(t) for t in unhealthy], } @@ -2102,7 +2098,7 @@ class NeonStorageController(MetricsGetter, LogUtils): return response.json() def metadata_health_list_outdated(self, duration: str): - body: Dict[str, Any] = {"not_scrubbed_for": duration} + body: dict[str, Any] = {"not_scrubbed_for": duration} response = self.request( "POST", @@ -2136,7 +2132,7 @@ class NeonStorageController(MetricsGetter, LogUtils): response.raise_for_status() return response.json() - def configure_failpoints(self, config_strings: Tuple[str, str] | List[Tuple[str, str]]): + def configure_failpoints(self, config_strings: tuple[str, str] | list[tuple[str, str]]): if isinstance(config_strings, tuple): pairs = [config_strings] else: @@ -2153,13 +2149,13 @@ class NeonStorageController(MetricsGetter, LogUtils): log.info(f"Got failpoints request response code {res.status_code}") res.raise_for_status() - def get_tenants_placement(self) -> defaultdict[str, Dict[str, Any]]: + def get_tenants_placement(self) -> defaultdict[str, dict[str, Any]]: """ Get the intent and observed placements of all tenants known to the storage controller. """ tenants = self.tenant_list() - tenant_placement: defaultdict[str, Dict[str, Any]] = defaultdict( + tenant_placement: defaultdict[str, dict[str, Any]] = defaultdict( lambda: { "observed": {"attached": None, "secondary": []}, "intent": {"attached": None, "secondary": []}, @@ -2266,12 +2262,12 @@ class NeonStorageController(MetricsGetter, LogUtils): response.raise_for_status() return [TenantShardId.parse(tid) for tid in response.json()["updated"]] - def __enter__(self) -> "NeonStorageController": + def __enter__(self) -> NeonStorageController: return self def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ): @@ -2280,7 +2276,7 @@ class NeonStorageController(MetricsGetter, LogUtils): class NeonProxiedStorageController(NeonStorageController): def __init__(self, env: NeonEnv, proxy_port: int, auth_enabled: bool): - super(NeonProxiedStorageController, self).__init__(env, proxy_port, auth_enabled) + super().__init__(env, proxy_port, auth_enabled) self.instances: dict[int, dict[str, Any]] = {} def start( @@ -2299,7 +2295,7 @@ class NeonProxiedStorageController(NeonStorageController): def stop_instance( self, immediate: bool = False, instance_id: Optional[int] = None - ) -> "NeonStorageController": + ) -> NeonStorageController: assert instance_id in self.instances if self.instances[instance_id]["running"]: self.env.neon_cli.storage_controller_stop(immediate, instance_id) @@ -2308,7 +2304,7 @@ class NeonProxiedStorageController(NeonStorageController): self.running = any(meta["running"] for meta in self.instances.values()) return self - def stop(self, immediate: bool = False) -> "NeonStorageController": + def stop(self, immediate: bool = False) -> NeonStorageController: for iid, details in self.instances.items(): if details["running"]: self.env.neon_cli.storage_controller_stop(immediate, iid) @@ -2327,7 +2323,7 @@ class NeonProxiedStorageController(NeonStorageController): def log_contains( self, pattern: str, offset: None | LogCursor = None - ) -> Optional[Tuple[str, LogCursor]]: + ) -> Optional[tuple[str, LogCursor]]: raise NotImplementedError() @@ -2359,7 +2355,7 @@ class NeonPageserver(PgProtocol, LogUtils): # env.pageserver.allowed_errors.append(".*could not open garage door.*") # # The entries in the list are regular experessions. - self.allowed_errors: List[str] = list(DEFAULT_PAGESERVER_ALLOWED_ERRORS) + self.allowed_errors: list[str] = list(DEFAULT_PAGESERVER_ALLOWED_ERRORS) def timeline_dir( self, @@ -2384,19 +2380,19 @@ class NeonPageserver(PgProtocol, LogUtils): def config_toml_path(self) -> Path: return self.workdir / "pageserver.toml" - def edit_config_toml(self, edit_fn: Callable[[Dict[str, Any]], T]) -> T: + def edit_config_toml(self, edit_fn: Callable[[dict[str, Any]], T]) -> T: """ Edit the pageserver's config toml file in place. """ path = self.config_toml_path - with open(path, "r") as f: + with open(path) as f: config = toml.load(f) res = edit_fn(config) with open(path, "w") as f: toml.dump(config, f) return res - def patch_config_toml_nonrecursive(self, patch: Dict[str, Any]) -> Dict[str, Any]: + def patch_config_toml_nonrecursive(self, patch: dict[str, Any]) -> dict[str, Any]: """ Non-recursively merge the given `patch` dict into the existing config toml, using `dict.update()`. Returns the replaced values. @@ -2405,7 +2401,7 @@ class NeonPageserver(PgProtocol, LogUtils): """ replacements = {} - def doit(config: Dict[str, Any]): + def doit(config: dict[str, Any]): while len(patch) > 0: key, new = patch.popitem() old = config.get(key, None) @@ -2417,9 +2413,9 @@ class NeonPageserver(PgProtocol, LogUtils): def start( self, - extra_env_vars: Optional[Dict[str, str]] = None, + extra_env_vars: Optional[dict[str, str]] = None, timeout_in_seconds: Optional[int] = None, - ) -> "NeonPageserver": + ) -> NeonPageserver: """ Start the page server. `overrides` allows to add some config to this pageserver start. @@ -2445,7 +2441,7 @@ class NeonPageserver(PgProtocol, LogUtils): return self - def stop(self, immediate: bool = False) -> "NeonPageserver": + def stop(self, immediate: bool = False) -> NeonPageserver: """ Stop the page server. Returns self. @@ -2493,12 +2489,12 @@ class NeonPageserver(PgProtocol, LogUtils): wait_until(20, 0.5, complete) - def __enter__(self) -> "NeonPageserver": + def __enter__(self) -> NeonPageserver: return self def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ): @@ -2545,7 +2541,7 @@ class NeonPageserver(PgProtocol, LogUtils): def tenant_attach( self, tenant_id: TenantId, - config: None | Dict[str, Any] = None, + config: None | dict[str, Any] = None, generation: Optional[int] = None, override_storage_controller_generation: bool = False, ): @@ -2584,7 +2580,7 @@ class NeonPageserver(PgProtocol, LogUtils): ) -> dict[str, Any]: path = self.tenant_dir(tenant_shard_id) / "config-v1" log.info(f"Reading location conf from {path}") - bytes = open(path, "r").read() + bytes = open(path).read() try: decoded: dict[str, Any] = toml.loads(bytes) return decoded @@ -2595,7 +2591,7 @@ class NeonPageserver(PgProtocol, LogUtils): def tenant_create( self, tenant_id: TenantId, - conf: Optional[Dict[str, Any]] = None, + conf: Optional[dict[str, Any]] = None, auth_token: Optional[str] = None, generation: Optional[int] = None, ) -> TenantId: @@ -2661,7 +2657,7 @@ class PgBin: self.env = os.environ.copy() self.env["LD_LIBRARY_PATH"] = str(self.pg_lib_dir) - def _fixpath(self, command: List[str]): + def _fixpath(self, command: list[str]): if "/" not in str(command[0]): command[0] = str(self.pg_bin_path / command[0]) @@ -2681,7 +2677,7 @@ class PgBin: def run_nonblocking( self, - command: List[str], + command: list[str], env: Optional[Env] = None, cwd: Optional[Union[str, Path]] = None, ) -> subprocess.Popen[Any]: @@ -2705,7 +2701,7 @@ class PgBin: def run( self, - command: List[str], + command: list[str], env: Optional[Env] = None, cwd: Optional[Union[str, Path]] = None, ) -> None: @@ -2728,7 +2724,7 @@ class PgBin: def run_capture( self, - command: List[str], + command: list[str], env: Optional[Env] = None, cwd: Optional[str] = None, with_command_header=True, @@ -2841,14 +2837,14 @@ class VanillaPostgres(PgProtocol): ] ) - def configure(self, options: List[str]): + def configure(self, options: list[str]): """Append lines into postgresql.conf file.""" assert not self.running with open(os.path.join(self.pgdatadir, "postgresql.conf"), "a") as conf_file: conf_file.write("\n".join(options)) conf_file.write("\n") - def edit_hba(self, hba: List[str]): + def edit_hba(self, hba: list[str]): """Prepend hba lines into pg_hba.conf file.""" assert not self.running with open(os.path.join(self.pgdatadir, "pg_hba.conf"), "r+") as conf_file: @@ -2876,12 +2872,12 @@ class VanillaPostgres(PgProtocol): """Return size of pgdatadir subdirectory in bytes.""" return get_dir_size(self.pgdatadir / subdir) - def __enter__(self) -> "VanillaPostgres": + def __enter__(self) -> VanillaPostgres: return self def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ): @@ -2911,7 +2907,7 @@ class RemotePostgres(PgProtocol): # The remote server is assumed to be running already self.running = True - def configure(self, options: List[str]): + def configure(self, options: list[str]): raise Exception("cannot change configuration of remote Posgres instance") def start(self): @@ -2925,12 +2921,12 @@ class RemotePostgres(PgProtocol): # See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE raise Exception("cannot get size of a Postgres instance") - def __enter__(self) -> "RemotePostgres": + def __enter__(self) -> RemotePostgres: return self def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ): @@ -3266,7 +3262,7 @@ class NeonProxy(PgProtocol): def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ): @@ -3404,7 +3400,7 @@ class Endpoint(PgProtocol, LogUtils): self.http_port = http_port self.check_stop_result = check_stop_result # passed to endpoint create and endpoint reconfigure - self.active_safekeepers: List[int] = list(map(lambda sk: sk.id, env.safekeepers)) + self.active_safekeepers: list[int] = list(map(lambda sk: sk.id, env.safekeepers)) # path to conf is /endpoints//pgdata/postgresql.conf # Semaphore is set to 1 when we start, and acquire'd back to zero when we stop @@ -3427,10 +3423,10 @@ class Endpoint(PgProtocol, LogUtils): endpoint_id: Optional[str] = None, hot_standby: bool = False, lsn: Optional[Lsn] = None, - config_lines: Optional[List[str]] = None, + config_lines: Optional[list[str]] = None, pageserver_id: Optional[int] = None, allow_multiple: bool = False, - ) -> "Endpoint": + ) -> Endpoint: """ Create a new Postgres endpoint. Returns self. @@ -3473,10 +3469,10 @@ class Endpoint(PgProtocol, LogUtils): self, remote_ext_config: Optional[str] = None, pageserver_id: Optional[int] = None, - safekeepers: Optional[List[int]] = None, + safekeepers: Optional[list[int]] = None, allow_multiple: bool = False, basebackup_request_tries: Optional[int] = None, - ) -> "Endpoint": + ) -> Endpoint: """ Start the Postgres instance. Returns self. @@ -3489,8 +3485,6 @@ class Endpoint(PgProtocol, LogUtils): if safekeepers is not None: self.active_safekeepers = safekeepers - log.info(f"Starting postgres endpoint {self.endpoint_id}") - self.env.neon_cli.endpoint_start( self.endpoint_id, safekeepers=self.active_safekeepers, @@ -3525,7 +3519,7 @@ class Endpoint(PgProtocol, LogUtils): """Path to the postgresql.conf in the endpoint directory (not the one in pgdata)""" return self.endpoint_path() / "postgresql.conf" - def config(self, lines: List[str]) -> "Endpoint": + def config(self, lines: list[str]) -> Endpoint: """ Add lines to postgresql.conf. Lines should be an array of valid postgresql.conf rows. @@ -3539,7 +3533,7 @@ class Endpoint(PgProtocol, LogUtils): return self - def edit_hba(self, hba: List[str]): + def edit_hba(self, hba: list[str]): """Prepend hba lines into pg_hba.conf file.""" with open(os.path.join(self.pg_data_dir_path(), "pg_hba.conf"), "r+") as conf_file: data = conf_file.read() @@ -3554,7 +3548,7 @@ class Endpoint(PgProtocol, LogUtils): return self._running._value > 0 def reconfigure( - self, pageserver_id: Optional[int] = None, safekeepers: Optional[List[int]] = None + self, pageserver_id: Optional[int] = None, safekeepers: Optional[list[int]] = None ): assert self.endpoint_id is not None # If `safekeepers` is not None, they are remember them as active and use @@ -3569,7 +3563,7 @@ class Endpoint(PgProtocol, LogUtils): """Update the endpoint.json file used by control_plane.""" # Read config config_path = os.path.join(self.endpoint_path(), "endpoint.json") - with open(config_path, "r") as f: + with open(config_path) as f: data_dict: dict[str, Any] = json.load(f) # Write it back updated @@ -3602,8 +3596,8 @@ class Endpoint(PgProtocol, LogUtils): def stop( self, mode: str = "fast", - sks_wait_walreceiver_gone: Optional[tuple[List[Safekeeper], TimelineId]] = None, - ) -> "Endpoint": + sks_wait_walreceiver_gone: Optional[tuple[list[Safekeeper], TimelineId]] = None, + ) -> Endpoint: """ Stop the Postgres instance if it's running. @@ -3637,7 +3631,7 @@ class Endpoint(PgProtocol, LogUtils): return self - def stop_and_destroy(self, mode: str = "immediate") -> "Endpoint": + def stop_and_destroy(self, mode: str = "immediate") -> Endpoint: """ Stop the Postgres instance, then destroy the endpoint. Returns self. @@ -3659,19 +3653,17 @@ class Endpoint(PgProtocol, LogUtils): endpoint_id: Optional[str] = None, hot_standby: bool = False, lsn: Optional[Lsn] = None, - config_lines: Optional[List[str]] = None, + config_lines: Optional[list[str]] = None, remote_ext_config: Optional[str] = None, pageserver_id: Optional[int] = None, - allow_multiple=False, + allow_multiple: bool = False, basebackup_request_tries: Optional[int] = None, - ) -> "Endpoint": + ) -> Endpoint: """ Create an endpoint, apply config, and start Postgres. Returns self. """ - started_at = time.time() - self.create( branch_name=branch_name, endpoint_id=endpoint_id, @@ -3687,16 +3679,14 @@ class Endpoint(PgProtocol, LogUtils): basebackup_request_tries=basebackup_request_tries, ) - log.info(f"Postgres startup took {time.time() - started_at} seconds") - return self - def __enter__(self) -> "Endpoint": + def __enter__(self) -> Endpoint: return self def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ): @@ -3727,7 +3717,7 @@ class EndpointFactory: def __init__(self, env: NeonEnv): self.env = env self.num_instances: int = 0 - self.endpoints: List[Endpoint] = [] + self.endpoints: list[Endpoint] = [] def create_start( self, @@ -3736,7 +3726,7 @@ class EndpointFactory: tenant_id: Optional[TenantId] = None, lsn: Optional[Lsn] = None, hot_standby: bool = False, - config_lines: Optional[List[str]] = None, + config_lines: Optional[list[str]] = None, remote_ext_config: Optional[str] = None, pageserver_id: Optional[int] = None, basebackup_request_tries: Optional[int] = None, @@ -3768,7 +3758,7 @@ class EndpointFactory: tenant_id: Optional[TenantId] = None, lsn: Optional[Lsn] = None, hot_standby: bool = False, - config_lines: Optional[List[str]] = None, + config_lines: Optional[list[str]] = None, pageserver_id: Optional[int] = None, ) -> Endpoint: ep = Endpoint( @@ -3792,7 +3782,7 @@ class EndpointFactory: pageserver_id=pageserver_id, ) - def stop_all(self, fail_on_error=True) -> "EndpointFactory": + def stop_all(self, fail_on_error=True) -> EndpointFactory: exception = None for ep in self.endpoints: try: @@ -3807,7 +3797,7 @@ class EndpointFactory: return self def new_replica( - self, origin: Endpoint, endpoint_id: str, config_lines: Optional[List[str]] = None + self, origin: Endpoint, endpoint_id: str, config_lines: Optional[list[str]] = None ): branch_name = origin.branch_name assert origin in self.endpoints @@ -3823,7 +3813,7 @@ class EndpointFactory: ) def new_replica_start( - self, origin: Endpoint, endpoint_id: str, config_lines: Optional[List[str]] = None + self, origin: Endpoint, endpoint_id: str, config_lines: Optional[list[str]] = None ): branch_name = origin.branch_name assert origin in self.endpoints @@ -3861,7 +3851,7 @@ class Safekeeper(LogUtils): port: SafekeeperPort, id: int, running: bool = False, - extra_opts: Optional[List[str]] = None, + extra_opts: Optional[list[str]] = None, ): self.env = env self.port = port @@ -3887,8 +3877,8 @@ class Safekeeper(LogUtils): self.extra_opts = extra_opts def start( - self, extra_opts: Optional[List[str]] = None, timeout_in_seconds: Optional[int] = None - ) -> "Safekeeper": + self, extra_opts: Optional[list[str]] = None, timeout_in_seconds: Optional[int] = None + ) -> Safekeeper: if extra_opts is None: # Apply either the extra_opts passed in, or the ones from our constructor: we do not merge the two. extra_opts = self.extra_opts @@ -3923,8 +3913,7 @@ class Safekeeper(LogUtils): break # success return self - def stop(self, immediate: bool = False) -> "Safekeeper": - log.info(f"Stopping safekeeper {self.id}") + def stop(self, immediate: bool = False) -> Safekeeper: self.env.neon_cli.safekeeper_stop(self.id, immediate) self.running = False return self @@ -3935,8 +3924,8 @@ class Safekeeper(LogUtils): assert not self.log_contains("timeout while acquiring WalResidentTimeline guard") def append_logical_message( - self, tenant_id: TenantId, timeline_id: TimelineId, request: Dict[str, Any] - ) -> Dict[str, Any]: + self, tenant_id: TenantId, timeline_id: TimelineId, request: dict[str, Any] + ) -> dict[str, Any]: """ Send JSON_CTRL query to append LogicalMessage to WAL and modify safekeeper state. It will construct LogicalMessage from provided @@ -3989,7 +3978,7 @@ class Safekeeper(LogUtils): def pull_timeline( self, srcs: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ pull_timeline from srcs to self. """ @@ -4008,7 +3997,7 @@ class Safekeeper(LogUtils): def timeline_dir(self, tenant_id, timeline_id) -> Path: return self.data_dir / str(tenant_id) / str(timeline_id) - # List partial uploaded segments of this safekeeper. Works only for + # list partial uploaded segments of this safekeeper. Works only for # RemoteStorageKind.LOCAL_FS. def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId): tline_path = ( @@ -4025,7 +4014,7 @@ class Safekeeper(LogUtils): mysegs = [s for s in segs if f"sk{self.id}" in s] return mysegs - def list_segments(self, tenant_id, timeline_id) -> List[str]: + def list_segments(self, tenant_id, timeline_id) -> list[str]: """ Get list of segment names of the given timeline. """ @@ -4130,7 +4119,7 @@ class StorageScrubber: self.log_dir = log_dir def scrubber_cli( - self, args: list[str], timeout, extra_env: Optional[Dict[str, str]] = None + self, args: list[str], timeout, extra_env: Optional[dict[str, str]] = None ) -> str: assert isinstance(self.env.pageserver_remote_storage, S3Storage) s3_storage = self.env.pageserver_remote_storage @@ -4177,10 +4166,10 @@ class StorageScrubber: def scan_metadata_safekeeper( self, - timeline_lsns: List[Dict[str, Any]], + timeline_lsns: list[dict[str, Any]], cloud_admin_api_url: str, cloud_admin_api_token: str, - ) -> Tuple[bool, Any]: + ) -> tuple[bool, Any]: extra_env = { "CLOUD_ADMIN_API_URL": cloud_admin_api_url, "CLOUD_ADMIN_API_TOKEN": cloud_admin_api_token, @@ -4193,9 +4182,9 @@ class StorageScrubber: self, post_to_storage_controller: bool = False, node_kind: NodeKind = NodeKind.PAGESERVER, - timeline_lsns: Optional[List[Dict[str, Any]]] = None, - extra_env: Optional[Dict[str, str]] = None, - ) -> Tuple[bool, Any]: + timeline_lsns: Optional[list[dict[str, Any]]] = None, + extra_env: Optional[dict[str, str]] = None, + ) -> tuple[bool, Any]: """ Returns the health status and the metadata summary. """ @@ -4256,44 +4245,6 @@ class StorageScrubber: raise -def _get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str) -> Path: - """Compute the path to a working directory for an individual test.""" - test_name = request.node.name - test_dir = top_output_dir / f"{prefix}{test_name.replace('/', '-')}" - - # We rerun flaky tests multiple times, use a separate directory for each run. - if (suffix := getattr(request.node, "execution_count", None)) is not None: - test_dir = test_dir.parent / f"{test_dir.name}-{suffix}" - - log.info(f"get_test_output_dir is {test_dir}") - # make mypy happy - assert isinstance(test_dir, Path) - return test_dir - - -def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - """ - The working directory for a test. - """ - return _get_test_dir(request, top_output_dir, "") - - -def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - """ - Directory that contains `upperdir` and `workdir` for overlayfs mounts - that a test creates. See `NeonEnvBuilder.overlay_mount`. - """ - return _get_test_dir(request, top_output_dir, "overlay-") - - -def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path: - return top_output_dir / "shared-snapshots" / snapshot_name - - -def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - return get_test_output_dir(request, top_output_dir) / "repo" - - def pytest_addoption(parser: Parser): parser.addoption( "--preserve-database-files", @@ -4303,154 +4254,11 @@ def pytest_addoption(parser: Parser): ) -SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] +SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile( r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)" ) -# This is autouse, so the test output directory always gets created, even -# if a test doesn't put anything there. -# -# NB: we request the overlay dir fixture so the fixture does its cleanups -@pytest.fixture(scope="function", autouse=True) -def test_output_dir( - request: FixtureRequest, top_output_dir: Path, test_overlay_dir: Path -) -> Iterator[Path]: - """Create the working directory for an individual test.""" - - # one directory per test - test_dir = get_test_output_dir(request, top_output_dir) - log.info(f"test_output_dir is {test_dir}") - shutil.rmtree(test_dir, ignore_errors=True) - test_dir.mkdir() - - yield test_dir - - # Allure artifacts creation might involve the creation of `.tar.zst` archives, - # which aren't going to be used if Allure results collection is not enabled - # (i.e. --alluredir is not set). - # Skip `allure_attach_from_dir` in this case - if not request.config.getoption("--alluredir"): - return - - preserve_database_files = False - for k, v in request.node.user_properties: - # NB: the neon_env_builder fixture uses this fixture (test_output_dir). - # So, neon_env_builder's cleanup runs before here. - # The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property. - if k == "preserve_database_files": - assert isinstance(v, bool) - preserve_database_files = v - - allure_attach_from_dir(test_dir, preserve_database_files) - - -class FileAndThreadLock: - def __init__(self, path: Path): - self.path = path - self.thread_lock = threading.Lock() - self.fd: Optional[int] = None - - def __enter__(self): - self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY) - # lock thread lock before file lock so that there's no race - # around flocking / funlocking the file lock - self.thread_lock.acquire() - flock(self.fd, LOCK_EX) - - def __exit__(self, exc_type, exc_value, exc_traceback): - assert self.fd is not None - assert self.thread_lock.locked() # ... by us - flock(self.fd, LOCK_UN) - self.thread_lock.release() - os.close(self.fd) - self.fd = None - - -class SnapshotDirLocked: - def __init__(self, parent: SnapshotDir): - self._parent = parent - - def is_initialized(self): - # TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized. - # Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed. - return self._parent._marker_file_path.exists() - - def set_initialized(self): - self._parent._marker_file_path.write_text("") - - @property - def path(self) -> Path: - return self._parent._path / "snapshot" - - -class SnapshotDir: - _path: Path - - def __init__(self, path: Path): - self._path = path - assert self._path.is_dir() - self._lock = FileAndThreadLock(self._lock_file_path) - - @property - def _lock_file_path(self) -> Path: - return self._path / "initializing.flock" - - @property - def _marker_file_path(self) -> Path: - return self._path / "initialized.marker" - - def __enter__(self) -> SnapshotDirLocked: - self._lock.__enter__() - return SnapshotDirLocked(self) - - def __exit__(self, exc_type, exc_value, exc_traceback): - self._lock.__exit__(exc_type, exc_value, exc_traceback) - - -def shared_snapshot_dir(top_output_dir, ident: str) -> SnapshotDir: - snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident) - snapshot_dir_path.mkdir(exist_ok=True, parents=True) - return SnapshotDir(snapshot_dir_path) - - -@pytest.fixture(scope="function") -def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]: - """ - Idempotently create a test's overlayfs mount state directory. - If the functionality isn't enabled via env var, returns None. - - The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc). - """ - - if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None: - return None - - overlay_dir = get_test_overlay_dir(request, top_output_dir) - log.info(f"test_overlay_dir is {overlay_dir}") - - overlay_dir.mkdir(exist_ok=True) - # unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir` - for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)): - cmd = ["sudo", "umount", str(mountpoint)] - log.info( - f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}" - ) - subprocess.run(cmd, capture_output=True, check=True) - # the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work. - cmd = ["sudo", "rm", "-rf", str(overlay_dir)] - subprocess.run(cmd, capture_output=True, check=True) - - overlay_dir.mkdir() - - return overlay_dir - - # no need to clean up anything: on clean shutdown, - # NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup - # and on unclean shutdown, this function will take care of it - # on the next test run - - SKIP_DIRS = frozenset( ( "pg_wal", @@ -4502,7 +4310,7 @@ def should_skip_file(filename: str) -> bool: # # Test helpers # -def list_files_to_compare(pgdata_dir: Path) -> List[str]: +def list_files_to_compare(pgdata_dir: Path) -> list[str]: pgdata_files = [] for root, _dirs, filenames in os.walk(pgdata_dir): for filename in filenames: diff --git a/test_runner/fixtures/overlayfs.py b/test_runner/fixtures/overlayfs.py index 3e2f661893..ea11cd272c 100644 --- a/test_runner/fixtures/overlayfs.py +++ b/test_runner/fixtures/overlayfs.py @@ -1,8 +1,13 @@ +from __future__ import annotations + from pathlib import Path -from typing import Iterator +from typing import TYPE_CHECKING import psutil +if TYPE_CHECKING: + from collections.abc import Iterator + def iter_mounts_beneath(topdir: Path) -> Iterator[Path]: """ diff --git a/test_runner/fixtures/pageserver/__init__.py b/test_runner/fixtures/pageserver/__init__.py index e69de29bb2..9d48db4f9f 100644 --- a/test_runner/fixtures/pageserver/__init__.py +++ b/test_runner/fixtures/pageserver/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/test_runner/fixtures/pageserver/allowed_errors.py b/test_runner/fixtures/pageserver/allowed_errors.py index f8d9a51c91..fa85563e35 100755 --- a/test_runner/fixtures/pageserver/allowed_errors.py +++ b/test_runner/fixtures/pageserver/allowed_errors.py @@ -1,14 +1,16 @@ #! /usr/bin/env python3 +from __future__ import annotations + import argparse import re import sys -from typing import Iterable, List, Tuple +from collections.abc import Iterable def scan_pageserver_log_for_errors( - input: Iterable[str], allowed_errors: List[str] -) -> List[Tuple[int, str]]: + input: Iterable[str], allowed_errors: list[str] +) -> list[tuple[int, str]]: error_or_warn = re.compile(r"\s(ERROR|WARN)") errors = [] for lineno, line in enumerate(input, start=1): @@ -113,7 +115,7 @@ DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS = [ def _check_allowed_errors(input): - allowed_errors: List[str] = list(DEFAULT_PAGESERVER_ALLOWED_ERRORS) + allowed_errors: list[str] = list(DEFAULT_PAGESERVER_ALLOWED_ERRORS) # add any test specifics here; cli parsing is not provided for the # difficulty of copypasting regexes as arguments without any quoting diff --git a/test_runner/fixtures/pageserver/common_types.py b/test_runner/fixtures/pageserver/common_types.py index a6c327a8a0..2319701e0b 100644 --- a/test_runner/fixtures/pageserver/common_types.py +++ b/test_runner/fixtures/pageserver/common_types.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import re from dataclasses import dataclass -from typing import Any, Dict, Tuple, Union +from typing import TYPE_CHECKING, Union from fixtures.common_types import KEY_MAX, KEY_MIN, Key, Lsn +if TYPE_CHECKING: + from typing import Any + @dataclass class IndexLayerMetadata: @@ -53,7 +58,7 @@ IMAGE_LAYER_FILE_NAME = re.compile( ) -def parse_image_layer(f_name: str) -> Tuple[int, int, int]: +def parse_image_layer(f_name: str) -> tuple[int, int, int]: """Parse an image layer file name. Return key start, key end, and snapshot lsn""" match = IMAGE_LAYER_FILE_NAME.match(f_name) @@ -68,7 +73,7 @@ DELTA_LAYER_FILE_NAME = re.compile( ) -def parse_delta_layer(f_name: str) -> Tuple[int, int, int, int]: +def parse_delta_layer(f_name: str) -> tuple[int, int, int, int]: """Parse a delta layer file name. Return key start, key end, lsn start, and lsn end""" match = DELTA_LAYER_FILE_NAME.match(f_name) if match is None: @@ -121,11 +126,11 @@ def is_future_layer(layer_file_name: LayerName, disk_consistent_lsn: Lsn): @dataclass class IndexPartDump: - layer_metadata: Dict[LayerName, IndexLayerMetadata] + layer_metadata: dict[LayerName, IndexLayerMetadata] disk_consistent_lsn: Lsn @classmethod - def from_json(cls, d: Dict[str, Any]) -> "IndexPartDump": + def from_json(cls, d: dict[str, Any]) -> IndexPartDump: return IndexPartDump( layer_metadata={ parse_layer_file_name(n): IndexLayerMetadata(v["file_size"], v["generation"]) diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index 49ad54d456..aa4435af4e 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -4,7 +4,7 @@ import time from collections import defaultdict from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any import requests from requests.adapters import HTTPAdapter @@ -16,6 +16,9 @@ from fixtures.metrics import Metrics, MetricsGetter, parse_metrics from fixtures.pg_version import PgVersion from fixtures.utils import Fn +if TYPE_CHECKING: + from typing import Optional, Union + class PageserverApiException(Exception): def __init__(self, message, status_code: int): @@ -43,7 +46,7 @@ class InMemoryLayerInfo: lsn_end: Optional[str] @classmethod - def from_json(cls, d: Dict[str, Any]) -> InMemoryLayerInfo: + def from_json(cls, d: dict[str, Any]) -> InMemoryLayerInfo: return InMemoryLayerInfo( kind=d["kind"], lsn_start=d["lsn_start"], @@ -64,7 +67,7 @@ class HistoricLayerInfo: visible: bool @classmethod - def from_json(cls, d: Dict[str, Any]) -> HistoricLayerInfo: + def from_json(cls, d: dict[str, Any]) -> HistoricLayerInfo: # instead of parsing the key range lets keep the definition of "L0" in pageserver l0_ness = d.get("l0") assert l0_ness is None or isinstance(l0_ness, bool) @@ -86,53 +89,53 @@ class HistoricLayerInfo: @dataclass class LayerMapInfo: - in_memory_layers: List[InMemoryLayerInfo] - historic_layers: List[HistoricLayerInfo] + in_memory_layers: list[InMemoryLayerInfo] + historic_layers: list[HistoricLayerInfo] @classmethod - def from_json(cls, d: Dict[str, Any]) -> LayerMapInfo: + def from_json(cls, d: dict[str, Any]) -> LayerMapInfo: info = LayerMapInfo(in_memory_layers=[], historic_layers=[]) json_in_memory_layers = d["in_memory_layers"] - assert isinstance(json_in_memory_layers, List) + assert isinstance(json_in_memory_layers, list) for json_in_memory_layer in json_in_memory_layers: info.in_memory_layers.append(InMemoryLayerInfo.from_json(json_in_memory_layer)) json_historic_layers = d["historic_layers"] - assert isinstance(json_historic_layers, List) + assert isinstance(json_historic_layers, list) for json_historic_layer in json_historic_layers: info.historic_layers.append(HistoricLayerInfo.from_json(json_historic_layer)) return info - def kind_count(self) -> Dict[str, int]: - counts: Dict[str, int] = defaultdict(int) + def kind_count(self) -> dict[str, int]: + counts: dict[str, int] = defaultdict(int) for inmem_layer in self.in_memory_layers: counts[inmem_layer.kind] += 1 for hist_layer in self.historic_layers: counts[hist_layer.kind] += 1 return counts - def delta_layers(self) -> List[HistoricLayerInfo]: + def delta_layers(self) -> list[HistoricLayerInfo]: return [x for x in self.historic_layers if x.kind == "Delta"] - def image_layers(self) -> List[HistoricLayerInfo]: + def image_layers(self) -> list[HistoricLayerInfo]: return [x for x in self.historic_layers if x.kind == "Image"] - def delta_l0_layers(self) -> List[HistoricLayerInfo]: + def delta_l0_layers(self) -> list[HistoricLayerInfo]: return [x for x in self.historic_layers if x.kind == "Delta" and x.l0] - def historic_by_name(self) -> Set[str]: + def historic_by_name(self) -> set[str]: return set(x.layer_file_name for x in self.historic_layers) @dataclass class TenantConfig: - tenant_specific_overrides: Dict[str, Any] - effective_config: Dict[str, Any] + tenant_specific_overrides: dict[str, Any] + effective_config: dict[str, Any] @classmethod - def from_json(cls, d: Dict[str, Any]) -> TenantConfig: + def from_json(cls, d: dict[str, Any]) -> TenantConfig: return TenantConfig( tenant_specific_overrides=d["tenant_specific_overrides"], effective_config=d["effective_config"], @@ -209,7 +212,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def check_status(self): self.get(f"http://localhost:{self.port}/v1/status").raise_for_status() - def configure_failpoints(self, config_strings: Tuple[str, str] | List[Tuple[str, str]]): + def configure_failpoints(self, config_strings: tuple[str, str] | list[tuple[str, str]]): self.is_testing_enabled_or_skip() if isinstance(config_strings, tuple): @@ -233,7 +236,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): res = self.post(f"http://localhost:{self.port}/v1/reload_auth_validation_keys") self.verbose_error(res) - def tenant_list(self) -> List[Dict[Any, Any]]: + def tenant_list(self) -> list[dict[Any, Any]]: res = self.get(f"http://localhost:{self.port}/v1/tenant") self.verbose_error(res) res_json = res.json() @@ -244,7 +247,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): self, tenant_id: Union[TenantId, TenantShardId], generation: int, - config: None | Dict[str, Any] = None, + config: None | dict[str, Any] = None, ): config = config or {} @@ -324,7 +327,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def tenant_status( self, tenant_id: Union[TenantId, TenantShardId], activate: bool = False - ) -> Dict[Any, Any]: + ) -> dict[Any, Any]: """ :activate: hint the server not to accelerate activation of this tenant in response to this query. False by default for tests, because they generally want to observed the @@ -378,8 +381,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def patch_tenant_config_client_side( self, tenant_id: TenantId, - inserts: Optional[Dict[str, Any]] = None, - removes: Optional[List[str]] = None, + inserts: Optional[dict[str, Any]] = None, + removes: Optional[list[str]] = None, ): current = self.tenant_config(tenant_id).tenant_specific_overrides if inserts is not None: @@ -394,7 +397,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def tenant_size_and_modelinputs( self, tenant_id: Union[TenantId, TenantShardId] - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: """ Returns the tenant size, together with the model inputs as the second tuple item. """ @@ -424,7 +427,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): tenant_id: Union[TenantId, TenantShardId], timestamp: datetime, done_if_after: datetime, - shard_counts: Optional[List[int]] = None, + shard_counts: Optional[list[int]] = None, ): """ Issues a request to perform time travel operations on the remote storage @@ -432,7 +435,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): if shard_counts is None: shard_counts = [] - body: Dict[str, Any] = { + body: dict[str, Any] = { "shard_counts": shard_counts, } res = self.put( @@ -446,7 +449,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): tenant_id: Union[TenantId, TenantShardId], include_non_incremental_logical_size: bool = False, include_timeline_dir_layer_file_size_sum: bool = False, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: params = {} if include_non_incremental_logical_size: params["include-non-incremental-logical-size"] = "true" @@ -470,8 +473,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter): ancestor_start_lsn: Optional[Lsn] = None, existing_initdb_timeline_id: Optional[TimelineId] = None, **kwargs, - ) -> Dict[Any, Any]: - body: Dict[str, Any] = { + ) -> dict[Any, Any]: + body: dict[str, Any] = { "new_timeline_id": str(new_timeline_id), "ancestor_start_lsn": str(ancestor_start_lsn) if ancestor_start_lsn else None, "ancestor_timeline_id": str(ancestor_timeline_id) if ancestor_timeline_id else None, @@ -504,7 +507,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): include_timeline_dir_layer_file_size_sum: bool = False, force_await_initial_logical_size: bool = False, **kwargs, - ) -> Dict[Any, Any]: + ) -> dict[Any, Any]: params = {} if include_non_incremental_logical_size: params["include-non-incremental-logical-size"] = "true" @@ -844,7 +847,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): ) if len(res) != 2: return None - inc, dec = [res[metric] for metric in metrics] + inc, dec = (res[metric] for metric in metrics) queue_count = int(inc) - int(dec) assert queue_count >= 0 return queue_count @@ -883,9 +886,9 @@ class PageserverHttpClient(requests.Session, MetricsGetter): self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, - batch_size: int | None = None, + batch_size: Optional[int] = None, **kwargs, - ) -> Set[TimelineId]: + ) -> set[TimelineId]: params = {} if batch_size is not None: params["batch_size"] = batch_size diff --git a/test_runner/fixtures/pageserver/many_tenants.py b/test_runner/fixtures/pageserver/many_tenants.py index 97e63ed4ba..37b4246d40 100644 --- a/test_runner/fixtures/pageserver/many_tenants.py +++ b/test_runner/fixtures/pageserver/many_tenants.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import concurrent.futures -from typing import Any, Callable, Dict, Tuple +from typing import TYPE_CHECKING import fixtures.pageserver.remote_storage from fixtures.common_types import TenantId, TimelineId @@ -10,10 +12,13 @@ from fixtures.neon_fixtures import ( ) from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind +if TYPE_CHECKING: + from typing import Any, Callable + def single_timeline( neon_env_builder: NeonEnvBuilder, - setup_template: Callable[[NeonEnv], Tuple[TenantId, TimelineId, Dict[str, Any]]], + setup_template: Callable[[NeonEnv], tuple[TenantId, TimelineId, dict[str, Any]]], ncopies: int, ) -> NeonEnv: """ diff --git a/test_runner/fixtures/pageserver/remote_storage.py b/test_runner/fixtures/pageserver/remote_storage.py index bc54fc4c8d..54acb9ce50 100644 --- a/test_runner/fixtures/pageserver/remote_storage.py +++ b/test_runner/fixtures/pageserver/remote_storage.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import concurrent.futures import os import queue import shutil import threading from pathlib import Path -from typing import Any, List, Tuple +from typing import TYPE_CHECKING from fixtures.common_types import TenantId, TimelineId from fixtures.neon_fixtures import NeonEnv @@ -14,6 +16,9 @@ from fixtures.pageserver.common_types import ( ) from fixtures.remote_storage import LocalFsStorage +if TYPE_CHECKING: + from typing import Any + def duplicate_one_tenant(env: NeonEnv, template_tenant: TenantId, new_tenant: TenantId): remote_storage = env.pageserver_remote_storage @@ -50,13 +55,13 @@ def duplicate_one_tenant(env: NeonEnv, template_tenant: TenantId, new_tenant: Te return None -def duplicate_tenant(env: NeonEnv, template_tenant: TenantId, ncopies: int) -> List[TenantId]: +def duplicate_tenant(env: NeonEnv, template_tenant: TenantId, ncopies: int) -> list[TenantId]: assert isinstance(env.pageserver_remote_storage, LocalFsStorage) def work(tenant_id): duplicate_one_tenant(env, template_tenant, tenant_id) - new_tenants: List[TenantId] = [TenantId.generate() for _ in range(0, ncopies)] + new_tenants: list[TenantId] = [TenantId.generate() for _ in range(0, ncopies)] with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: executor.map(work, new_tenants) return new_tenants @@ -79,7 +84,7 @@ def local_layer_name_from_remote_name(remote_name: str) -> str: def copy_all_remote_layer_files_to_local_tenant_dir( - env: NeonEnv, tenant_timelines: List[Tuple[TenantId, TimelineId]] + env: NeonEnv, tenant_timelines: list[tuple[TenantId, TimelineId]] ): remote_storage = env.pageserver_remote_storage assert isinstance(remote_storage, LocalFsStorage) diff --git a/test_runner/fixtures/pageserver/utils.py b/test_runner/fixtures/pageserver/utils.py index a74fef6a60..377a95fbeb 100644 --- a/test_runner/fixtures/pageserver/utils.py +++ b/test_runner/fixtures/pageserver/utils.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import time -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING from mypy_boto3_s3.type_defs import ( DeleteObjectOutputTypeDef, @@ -14,6 +16,9 @@ from fixtures.pageserver.http import PageserverApiException, PageserverHttpClien from fixtures.remote_storage import RemoteStorage, RemoteStorageKind, S3Storage from fixtures.utils import wait_until +if TYPE_CHECKING: + from typing import Any, Optional, Union + def assert_tenant_state( pageserver_http: PageserverHttpClient, @@ -66,7 +71,7 @@ def wait_for_upload( ) -def _tenant_in_expected_state(tenant_info: Dict[str, Any], expected_state: str): +def _tenant_in_expected_state(tenant_info: dict[str, Any], expected_state: str): if tenant_info["state"]["slug"] == expected_state: return True if tenant_info["state"]["slug"] == "Broken": @@ -80,7 +85,7 @@ def wait_until_tenant_state( expected_state: str, iterations: int, period: float = 1.0, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Does not use `wait_until` for debugging purposes """ @@ -136,7 +141,7 @@ def wait_until_timeline_state( expected_state: str, iterations: int, period: float = 1.0, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Does not use `wait_until` for debugging purposes """ @@ -147,7 +152,7 @@ def wait_until_timeline_state( if isinstance(timeline["state"], str): if timeline["state"] == expected_state: return timeline - elif isinstance(timeline, Dict): + elif isinstance(timeline, dict): if timeline["state"].get(expected_state): return timeline @@ -235,7 +240,7 @@ def wait_for_upload_queue_empty( # this is `started left join finished`; if match, subtracting start from finished, resulting in queue depth remaining_labels = ["shard_id", "file_kind", "op_kind"] - tl: List[Tuple[Any, float]] = [] + tl: list[tuple[Any, float]] = [] for s in started: found = False for f in finished: @@ -302,7 +307,7 @@ def assert_prefix_empty( assert remote_storage is not None response = list_prefix(remote_storage, prefix) keys = response["KeyCount"] - objects: List[ObjectTypeDef] = response.get("Contents", []) + objects: list[ObjectTypeDef] = response.get("Contents", []) common_prefixes = response.get("CommonPrefixes", []) is_mock_s3 = isinstance(remote_storage, S3Storage) and not remote_storage.cleanup @@ -430,7 +435,7 @@ def enable_remote_storage_versioning( return response -def many_small_layers_tenant_config() -> Dict[str, Any]: +def many_small_layers_tenant_config() -> dict[str, Any]: """ Create a new dict to avoid issues with deleting from the global value. In python, the global is mutable. diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index 2c8e71526c..4114c2fcb3 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import os -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING import allure import pytest @@ -7,7 +9,16 @@ import toml from _pytest.python import Metafunc from fixtures.pg_version import PgVersion -from fixtures.utils import AuxFileStore + +if TYPE_CHECKING: + from typing import Any, Optional + + from fixtures.utils import AuxFileStore + + +if TYPE_CHECKING: + from typing import Any, Optional + """ Dynamically parametrize tests by different parameters @@ -35,8 +46,8 @@ def pageserver_virtual_file_io_engine() -> Optional[str]: @pytest.fixture(scope="function", autouse=True) -def pageserver_io_buffer_alignment() -> Optional[int]: - return None +def pageserver_virtual_file_io_mode() -> Optional[str]: + return os.getenv("PAGESERVER_VIRTUAL_FILE_IO_MODE") @pytest.fixture(scope="function", autouse=True) @@ -44,7 +55,7 @@ def pageserver_aux_file_policy() -> Optional[AuxFileStore]: return None -def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[Dict[str, Any]]: +def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict[str, Any]]: toml_table = os.getenv("PAGESERVER_DEFAULT_TENANT_CONFIG_COMPACTION_ALGORITHM") if toml_table is None: return None @@ -54,7 +65,7 @@ def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[Dict @pytest.fixture(scope="function", autouse=True) -def pageserver_default_tenant_config_compaction_algorithm() -> Optional[Dict[str, Any]]: +def pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict[str, Any]]: return get_pageserver_default_tenant_config_compaction_algorithm() diff --git a/test_runner/fixtures/paths.py b/test_runner/fixtures/paths.py new file mode 100644 index 0000000000..65f8e432b0 --- /dev/null +++ b/test_runner/fixtures/paths.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import os +import shutil +import subprocess +import threading +from fcntl import LOCK_EX, LOCK_UN, flock +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING + +import pytest +from pytest import FixtureRequest + +from fixtures import overlayfs +from fixtures.log_helper import log +from fixtures.utils import allure_attach_from_dir + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Optional + + +DEFAULT_OUTPUT_DIR: str = "test_output" + + +def get_test_dir( + request: FixtureRequest, top_output_dir: Path, prefix: Optional[str] = None +) -> Path: + """Compute the path to a working directory for an individual test.""" + test_name = request.node.name + test_dir = top_output_dir / f"{prefix or ''}{test_name.replace('/', '-')}" + + # We rerun flaky tests multiple times, use a separate directory for each run. + if (suffix := getattr(request.node, "execution_count", None)) is not None: + test_dir = test_dir.parent / f"{test_dir.name}-{suffix}" + + return test_dir + + +def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + """ + The working directory for a test. + """ + return get_test_dir(request, top_output_dir) + + +def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + """ + Directory that contains `upperdir` and `workdir` for overlayfs mounts + that a test creates. See `NeonEnvBuilder.overlay_mount`. + """ + return get_test_dir(request, top_output_dir, "overlay-") + + +def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path: + return top_output_dir / "shared-snapshots" / snapshot_name + + +def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + return get_test_output_dir(request, top_output_dir) / "repo" + + +@pytest.fixture(scope="session") +def base_dir() -> Iterator[Path]: + # find the base directory (currently this is the git root) + base_dir = Path(__file__).parents[2] + log.info(f"base_dir is {base_dir}") + + yield base_dir + + +@pytest.fixture(scope="session") +def compute_config_dir(base_dir: Path) -> Iterator[Path]: + """ + Retrieve the path to the compute configuration directory. + """ + yield base_dir / "compute" / "etc" + + +@pytest.fixture(scope="function") +def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]: + if os.getenv("REMOTE_ENV"): + # we are in remote env and do not have neon binaries locally + # this is the case for benchmarks run on self-hosted runner + return + + # Find the neon binaries. + if env_neon_bin := os.environ.get("NEON_BIN"): + binpath = Path(env_neon_bin) + else: + binpath = base_dir / "target" / build_type + log.info(f"neon_binpath is {binpath}") + + if not (binpath / "pageserver").exists(): + raise Exception(f"neon binaries not found at '{binpath}'") + + yield binpath.absolute() + + +@pytest.fixture(scope="session") +def compatibility_snapshot_dir() -> Iterator[Path]: + if os.getenv("REMOTE_ENV"): + return + compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR") + assert ( + compatibility_snapshot_dir_env is not None + ), "COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg(PG_VERSION)` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)" + compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve() + yield compatibility_snapshot_dir + + +@pytest.fixture(scope="session") +def compatibility_neon_binpath() -> Optional[Iterator[Path]]: + if os.getenv("REMOTE_ENV"): + return + comp_binpath = None + if env_compatibility_neon_binpath := os.environ.get("COMPATIBILITY_NEON_BIN"): + comp_binpath = Path(env_compatibility_neon_binpath).resolve().absolute() + yield comp_binpath + + +@pytest.fixture(scope="session") +def pg_distrib_dir(base_dir: Path) -> Iterator[Path]: + if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"): + distrib_dir = Path(env_postgres_bin).resolve() + else: + distrib_dir = base_dir / "pg_install" + + log.info(f"pg_distrib_dir is {distrib_dir}") + yield distrib_dir + + +@pytest.fixture(scope="session") +def compatibility_pg_distrib_dir() -> Optional[Iterator[Path]]: + compat_distrib_dir = None + if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"): + compat_distrib_dir = Path(env_compat_postgres_bin).resolve() + if not compat_distrib_dir.exists(): + raise Exception(f"compatibility postgres directory not found at {compat_distrib_dir}") + + if compat_distrib_dir: + log.info(f"compatibility_pg_distrib_dir is {compat_distrib_dir}") + yield compat_distrib_dir + + +@pytest.fixture(scope="session") +def top_output_dir(base_dir: Path) -> Iterator[Path]: + # Compute the top-level directory for all tests. + if env_test_output := os.environ.get("TEST_OUTPUT"): + output_dir = Path(env_test_output).resolve() + else: + output_dir = base_dir / DEFAULT_OUTPUT_DIR + output_dir.mkdir(exist_ok=True) + + log.info(f"top_output_dir is {output_dir}") + yield output_dir + + +# This is autouse, so the test output directory always gets created, even +# if a test doesn't put anything there. +# +# NB: we request the overlay dir fixture so the fixture does its cleanups +@pytest.fixture(scope="function", autouse=True) +def test_output_dir(request: pytest.FixtureRequest, top_output_dir: Path) -> Iterator[Path]: + """Create the working directory for an individual test.""" + + # one directory per test + test_dir = get_test_output_dir(request, top_output_dir) + log.info(f"test_output_dir is {test_dir}") + shutil.rmtree(test_dir, ignore_errors=True) + test_dir.mkdir() + + yield test_dir + + # Allure artifacts creation might involve the creation of `.tar.zst` archives, + # which aren't going to be used if Allure results collection is not enabled + # (i.e. --alluredir is not set). + # Skip `allure_attach_from_dir` in this case + if not request.config.getoption("--alluredir"): + return + + preserve_database_files = False + for k, v in request.node.user_properties: + # NB: the neon_env_builder fixture uses this fixture (test_output_dir). + # So, neon_env_builder's cleanup runs before here. + # The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property. + if k == "preserve_database_files": + assert isinstance(v, bool) + preserve_database_files = v + + allure_attach_from_dir(test_dir, preserve_database_files) + + +class FileAndThreadLock: + def __init__(self, path: Path): + self.path = path + self.thread_lock = threading.Lock() + self.fd: Optional[int] = None + + def __enter__(self): + self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY) + # lock thread lock before file lock so that there's no race + # around flocking / funlocking the file lock + self.thread_lock.acquire() + flock(self.fd, LOCK_EX) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ): + assert self.fd is not None + assert self.thread_lock.locked() # ... by us + flock(self.fd, LOCK_UN) + self.thread_lock.release() + os.close(self.fd) + self.fd = None + + +class SnapshotDirLocked: + def __init__(self, parent: SnapshotDir): + self._parent = parent + + def is_initialized(self): + # TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized. + # Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed. + return self._parent.marker_file_path.exists() + + def set_initialized(self): + self._parent.marker_file_path.write_text("") + + @property + def path(self) -> Path: + return self._parent.path / "snapshot" + + +class SnapshotDir: + _path: Path + + def __init__(self, path: Path): + self._path = path + assert self._path.is_dir() + self._lock = FileAndThreadLock(self.lock_file_path) + + @property + def path(self) -> Path: + return self._path + + @property + def lock_file_path(self) -> Path: + return self._path / "initializing.flock" + + @property + def marker_file_path(self) -> Path: + return self._path / "initialized.marker" + + def __enter__(self) -> SnapshotDirLocked: + self._lock.__enter__() + return SnapshotDirLocked(self) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ): + self._lock.__exit__(exc_type, exc_value, exc_traceback) + + +def shared_snapshot_dir(top_output_dir: Path, ident: str) -> SnapshotDir: + snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident) + snapshot_dir_path.mkdir(exist_ok=True, parents=True) + return SnapshotDir(snapshot_dir_path) + + +@pytest.fixture(scope="function") +def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]: + """ + Idempotently create a test's overlayfs mount state directory. + If the functionality isn't enabled via env var, returns None. + + The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc). + """ + + if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None: + return None + + overlay_dir = get_test_overlay_dir(request, top_output_dir) + log.info(f"test_overlay_dir is {overlay_dir}") + + overlay_dir.mkdir(exist_ok=True) + # unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir` + for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)): + cmd = ["sudo", "umount", str(mountpoint)] + log.info( + f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}" + ) + subprocess.run(cmd, capture_output=True, check=True) + # the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work. + cmd = ["sudo", "rm", "-rf", str(overlay_dir)] + subprocess.run(cmd, capture_output=True, check=True) + + overlay_dir.mkdir() + + return overlay_dir + + # no need to clean up anything: on clean shutdown, + # NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup + # and on unclean shutdown, this function will take care of it + # on the next test run diff --git a/test_runner/fixtures/pg_stats.py b/test_runner/fixtures/pg_stats.py index adb3a7730e..d334d07b2b 100644 --- a/test_runner/fixtures/pg_stats.py +++ b/test_runner/fixtures/pg_stats.py @@ -1,15 +1,16 @@ +from __future__ import annotations + from functools import cached_property -from typing import List import pytest class PgStatTable: table: str - columns: List[str] + columns: list[str] additional_query: str - def __init__(self, table: str, columns: List[str], filter_query: str = ""): + def __init__(self, table: str, columns: list[str], filter_query: str = ""): self.table = table self.columns = columns self.additional_query = filter_query @@ -20,7 +21,7 @@ class PgStatTable: @pytest.fixture(scope="function") -def pg_stats_rw() -> List[PgStatTable]: +def pg_stats_rw() -> list[PgStatTable]: return [ PgStatTable( "pg_stat_database", @@ -31,7 +32,7 @@ def pg_stats_rw() -> List[PgStatTable]: @pytest.fixture(scope="function") -def pg_stats_ro() -> List[PgStatTable]: +def pg_stats_ro() -> list[PgStatTable]: return [ PgStatTable( "pg_stat_database", ["tup_returned", "tup_fetched"], "WHERE datname='postgres'" @@ -40,7 +41,7 @@ def pg_stats_ro() -> List[PgStatTable]: @pytest.fixture(scope="function") -def pg_stats_wo() -> List[PgStatTable]: +def pg_stats_wo() -> list[PgStatTable]: return [ PgStatTable( "pg_stat_database", @@ -51,7 +52,7 @@ def pg_stats_wo() -> List[PgStatTable]: @pytest.fixture(scope="function") -def pg_stats_wal() -> List[PgStatTable]: +def pg_stats_wal() -> list[PgStatTable]: return [ PgStatTable( "pg_stat_wal", diff --git a/test_runner/fixtures/pg_version.py b/test_runner/fixtures/pg_version.py index 258935959b..01f0245665 100644 --- a/test_runner/fixtures/pg_version.py +++ b/test_runner/fixtures/pg_version.py @@ -1,8 +1,15 @@ +from __future__ import annotations + import enum import os -from typing import Optional +from typing import TYPE_CHECKING import pytest +from typing_extensions import override + +if TYPE_CHECKING: + from typing import Optional + """ This fixture is used to determine which version of Postgres to use for tests. @@ -22,10 +29,12 @@ class PgVersion(str, enum.Enum): NOT_SET = "<-POSTRGRES VERSION IS NOT SET->" # Make it less confusing in logs + @override def __repr__(self) -> str: return f"'{self.value}'" # Make this explicit for Python 3.11 compatibility, which changes the behavior of enums + @override def __str__(self) -> str: return self.value @@ -36,7 +45,8 @@ class PgVersion(str, enum.Enum): return f"v{self.value}" @classmethod - def _missing_(cls, value) -> Optional["PgVersion"]: + @override + def _missing_(cls, value: object) -> Optional[PgVersion]: known_values = {v.value for _, v in cls.__members__.items()} # Allow passing version as a string with "v" prefix (e.g. "v14") diff --git a/test_runner/fixtures/port_distributor.py b/test_runner/fixtures/port_distributor.py index fd808d7a5f..df0eb2a809 100644 --- a/test_runner/fixtures/port_distributor.py +++ b/test_runner/fixtures/port_distributor.py @@ -1,10 +1,15 @@ +from __future__ import annotations + import re import socket from contextlib import closing -from typing import Dict, Union +from typing import TYPE_CHECKING from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Union + def can_bind(host: str, port: int) -> bool: """ @@ -24,7 +29,7 @@ def can_bind(host: str, port: int) -> bool: sock.bind((host, port)) sock.listen() return True - except socket.error: + except OSError: log.info(f"Port {port} is in use, skipping") return False finally: @@ -34,7 +39,7 @@ def can_bind(host: str, port: int) -> bool: class PortDistributor: def __init__(self, base_port: int, port_number: int): self.iterator = iter(range(base_port, base_port + port_number)) - self.port_map: Dict[int, int] = {} + self.port_map: dict[int, int] = {} def get_port(self) -> int: for port in self.iterator: @@ -54,10 +59,7 @@ class PortDistributor: if isinstance(value, int): return self._replace_port_int(value) - if isinstance(value, str): - return self._replace_port_str(value) - - raise TypeError(f"unsupported type {type(value)} of {value=}") + return self._replace_port_str(value) def _replace_port_int(self, value: int) -> int: known_port = self.port_map.get(value) @@ -70,7 +72,7 @@ class PortDistributor: # Use regex to find port in a string # urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432" # See https://bugs.python.org/issue27657 - ports = re.findall(r":(\d+)(?:/|$)", value) + ports: list[str] = re.findall(r":(\d+)(?:/|$)", value) assert len(ports) == 1, f"can't find port in {value}" port_int = int(ports[0]) diff --git a/test_runner/fixtures/remote_storage.py b/test_runner/fixtures/remote_storage.py index 1b6c3c23ba..7024953661 100644 --- a/test_runner/fixtures/remote_storage.py +++ b/test_runner/fixtures/remote_storage.py @@ -1,21 +1,28 @@ +from __future__ import annotations + import enum import hashlib import json import os import re -import subprocess from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Union import boto3 import toml +from moto.server import ThreadedMotoServer from mypy_boto3_s3 import S3Client +from typing_extensions import override from fixtures.common_types import TenantId, TenantShardId, TimelineId from fixtures.log_helper import log from fixtures.pageserver.common_types import IndexPartDump +if TYPE_CHECKING: + from typing import Any, Optional + + TIMELINE_INDEX_PART_FILE_NAME = "index_part.json" TENANT_HEATMAP_FILE_NAME = "heatmap-v1.json" @@ -30,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum): EXTENSIONS = "ext" SAFEKEEPER = "safekeeper" + @override def __str__(self) -> str: return self.value @@ -37,7 +45,6 @@ class RemoteStorageUser(str, enum.Enum): class MockS3Server: """ Starts a mock S3 server for testing on a port given, errors if the server fails to start or exits prematurely. - Relies that `poetry` and `moto` server are installed, since it's the way the tests are run. Also provides a set of methods to derive the connection properties from and the method to kill the underlying server. """ @@ -47,22 +54,8 @@ class MockS3Server: port: int, ): self.port = port - - # XXX: do not use `shell=True` or add `exec ` to the command here otherwise. - # We use `self.subprocess.kill()` to shut down the server, which would not "just" work in Linux - # if a process is started from the shell process. - self.subprocess = subprocess.Popen(["poetry", "run", "moto_server", f"-p{port}"]) - error = None - try: - return_code = self.subprocess.poll() - if return_code is not None: - error = f"expected mock s3 server to run but it exited with code {return_code}. stdout: '{self.subprocess.stdout}', stderr: '{self.subprocess.stderr}'" - except Exception as e: - error = f"expected mock s3 server to start but it failed with exception: {e}. stdout: '{self.subprocess.stdout}', stderr: '{self.subprocess.stderr}'" - if error is not None: - log.error(error) - self.kill() - raise RuntimeError("failed to start s3 mock server") + self.server = ThreadedMotoServer(port=port) + self.server.start() def endpoint(self) -> str: return f"http://127.0.0.1:{self.port}" @@ -77,7 +70,7 @@ class MockS3Server: return "test" def kill(self): - self.subprocess.kill() + self.server.stop() @dataclass @@ -90,11 +83,13 @@ class LocalFsStorage: def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path: return self.tenant_path(tenant_id) / "timelines" / str(timeline_id) - def timeline_latest_generation(self, tenant_id, timeline_id): + def timeline_latest_generation( + self, tenant_id: TenantId, timeline_id: TimelineId + ) -> Optional[int]: timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id)) index_parts = [f for f in timeline_files if f.startswith("index_part")] - def parse_gen(filename): + def parse_gen(filename: str) -> Optional[int]: log.info(f"parsing index_part '{filename}'") parts = filename.split("-") if len(parts) == 2: @@ -102,7 +97,7 @@ class LocalFsStorage: else: return None - generations = sorted([parse_gen(f) for f in index_parts]) + generations = sorted([parse_gen(f) for f in index_parts]) # type: ignore if len(generations) == 0: raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}") return generations[-1] @@ -131,18 +126,18 @@ class LocalFsStorage: filename = f"{local_name}-{generation:08x}" return self.timeline_path(tenant_id, timeline_id) / filename - def index_content(self, tenant_id: TenantId, timeline_id: TimelineId): + def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any: with self.index_path(tenant_id, timeline_id).open("r") as f: return json.load(f) def heatmap_path(self, tenant_id: TenantId) -> Path: return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME - def heatmap_content(self, tenant_id): + def heatmap_content(self, tenant_id: TenantId) -> Any: with self.heatmap_path(tenant_id).open("r") as f: return json.load(f) - def to_toml_dict(self) -> Dict[str, Any]: + def to_toml_dict(self) -> dict[str, Any]: return { "local_path": str(self.root), } @@ -175,7 +170,7 @@ class S3Storage: """formatting deserialized with humantime crate, for example "1s".""" custom_timeout: Optional[str] = None - def access_env_vars(self) -> Dict[str, str]: + def access_env_vars(self) -> dict[str, str]: if self.aws_profile is not None: env = { "AWS_PROFILE": self.aws_profile, @@ -204,7 +199,7 @@ class S3Storage: } ) - def to_toml_dict(self) -> Dict[str, Any]: + def to_toml_dict(self) -> dict[str, Any]: rv = { "bucket_name": self.bucket_name, "bucket_region": self.bucket_region, @@ -279,7 +274,7 @@ class S3Storage: ) -> str: return f"{self.tenant_path(tenant_id)}/timelines/{timeline_id}" - def get_latest_index_key(self, index_keys: List[str]) -> str: + def get_latest_index_key(self, index_keys: list[str]) -> str: """ Gets the latest index file key. @@ -306,7 +301,7 @@ class S3Storage: def heatmap_key(self, tenant_id: TenantId) -> str: return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}" - def heatmap_content(self, tenant_id: TenantId): + def heatmap_content(self, tenant_id: TenantId) -> Any: r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id)) return json.loads(r["Body"].read().decode("utf-8")) @@ -326,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum): def configure( self, repo_dir: Path, - mock_s3_server, + mock_s3_server: MockS3Server, run_id: str, test_name: str, user: RemoteStorageUser, @@ -419,7 +414,7 @@ class RemoteStorageKind(str, enum.Enum): ) -def available_remote_storages() -> List[RemoteStorageKind]: +def available_remote_storages() -> list[RemoteStorageKind]: remote_storages = [RemoteStorageKind.LOCAL_FS, RemoteStorageKind.MOCK_S3] if os.getenv("ENABLE_REAL_S3_REMOTE_STORAGE") is not None: remote_storages.append(RemoteStorageKind.REAL_S3) @@ -429,7 +424,7 @@ def available_remote_storages() -> List[RemoteStorageKind]: return remote_storages -def available_s3_storages() -> List[RemoteStorageKind]: +def available_s3_storages() -> list[RemoteStorageKind]: remote_storages = [RemoteStorageKind.MOCK_S3] if os.getenv("ENABLE_REAL_S3_REMOTE_STORAGE") is not None: remote_storages.append(RemoteStorageKind.REAL_S3) @@ -459,16 +454,10 @@ def default_remote_storage() -> RemoteStorageKind: return RemoteStorageKind.LOCAL_FS -def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> Dict[str, Any]: - if not isinstance(remote_storage, (LocalFsStorage, S3Storage)): - raise Exception("invalid remote storage type") - +def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]: return remote_storage.to_toml_dict() # serialize as toml inline table def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str: - if not isinstance(remote_storage, (LocalFsStorage, S3Storage)): - raise Exception("invalid remote storage type") - return remote_storage.to_toml_inline_table() diff --git a/test_runner/fixtures/safekeeper/__init__.py b/test_runner/fixtures/safekeeper/__init__.py index e69de29bb2..9d48db4f9f 100644 --- a/test_runner/fixtures/safekeeper/__init__.py +++ b/test_runner/fixtures/safekeeper/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/test_runner/fixtures/safekeeper/http.py b/test_runner/fixtures/safekeeper/http.py index 7f170eeea3..5d9a3bd149 100644 --- a/test_runner/fixtures/safekeeper/http.py +++ b/test_runner/fixtures/safekeeper/http.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import json from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING import pytest import requests @@ -10,6 +12,9 @@ from fixtures.log_helper import log from fixtures.metrics import Metrics, MetricsGetter, parse_metrics from fixtures.utils import wait_until +if TYPE_CHECKING: + from typing import Any, Optional, Union + # Walreceiver as returned by sk's timeline status endpoint. @dataclass @@ -29,7 +34,7 @@ class SafekeeperTimelineStatus: backup_lsn: Lsn peer_horizon_lsn: Lsn remote_consistent_lsn: Lsn - walreceivers: List[Walreceiver] + walreceivers: list[Walreceiver] class SafekeeperMetrics(Metrics): @@ -57,7 +62,7 @@ class TermBumpResponse: current_term: int @classmethod - def from_json(cls, d: Dict[str, Any]) -> "TermBumpResponse": + def from_json(cls, d: dict[str, Any]) -> TermBumpResponse: return TermBumpResponse( previous_term=d["previous_term"], current_term=d["current_term"], @@ -93,7 +98,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): if not self.is_testing_enabled: pytest.skip("safekeeper was built without 'testing' feature") - def configure_failpoints(self, config_strings: Union[Tuple[str, str], List[Tuple[str, str]]]): + def configure_failpoints(self, config_strings: Union[tuple[str, str], list[tuple[str, str]]]): self.is_testing_enabled_or_skip() if isinstance(config_strings, tuple): @@ -113,14 +118,14 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): assert res_json is None return res_json - def tenant_delete_force(self, tenant_id: TenantId) -> Dict[Any, Any]: + def tenant_delete_force(self, tenant_id: TenantId) -> dict[Any, Any]: res = self.delete(f"http://localhost:{self.port}/v1/tenant/{tenant_id}") res.raise_for_status() res_json = res.json() assert isinstance(res_json, dict) return res_json - def timeline_list(self) -> List[TenantTimelineId]: + def timeline_list(self) -> list[TenantTimelineId]: res = self.get(f"http://localhost:{self.port}/v1/tenant/timeline") res.raise_for_status() resj = res.json() @@ -178,7 +183,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): # only_local doesn't remove segments in the remote storage. def timeline_delete( self, tenant_id: TenantId, timeline_id: TimelineId, only_local: bool = False - ) -> Dict[Any, Any]: + ) -> dict[Any, Any]: res = self.delete( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}", params={ @@ -190,7 +195,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): assert isinstance(res_json, dict) return res_json - def debug_dump(self, params: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + def debug_dump(self, params: Optional[dict[str, str]] = None) -> dict[str, Any]: params = params or {} res = self.get(f"http://localhost:{self.port}/v1/debug_dump", params=params) res.raise_for_status() @@ -199,7 +204,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): return res_json def debug_dump_timeline( - self, timeline_id: TimelineId, params: Optional[Dict[str, str]] = None + self, timeline_id: TimelineId, params: Optional[dict[str, str]] = None ) -> Any: params = params or {} params["timeline_id"] = str(timeline_id) @@ -214,14 +219,14 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): dump = self.debug_dump_timeline(timeline_id, {"dump_control_file": "true"}) return dump["control_file"]["eviction_state"] - def pull_timeline(self, body: Dict[str, Any]) -> Dict[str, Any]: + def pull_timeline(self, body: dict[str, Any]) -> dict[str, Any]: res = self.post(f"http://localhost:{self.port}/v1/pull_timeline", json=body) res.raise_for_status() res_json = res.json() assert isinstance(res_json, dict) return res_json - def copy_timeline(self, tenant_id: TenantId, timeline_id: TimelineId, body: Dict[str, Any]): + def copy_timeline(self, tenant_id: TenantId, timeline_id: TimelineId, body: dict[str, Any]): res = self.post( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/copy", json=body, @@ -232,8 +237,8 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): self, tenant_id: TenantId, timeline_id: TimelineId, - patch: Dict[str, Any], - ) -> Dict[str, Any]: + patch: dict[str, Any], + ) -> dict[str, Any]: res = self.patch( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/control_file", json={ @@ -255,7 +260,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): def timeline_digest( self, tenant_id: TenantId, timeline_id: TimelineId, from_lsn: Lsn, until_lsn: Lsn - ) -> Dict[str, Any]: + ) -> dict[str, Any]: res = self.get( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/digest", params={ diff --git a/test_runner/fixtures/safekeeper/utils.py b/test_runner/fixtures/safekeeper/utils.py index 2a081c6ccb..0246916470 100644 --- a/test_runner/fixtures/safekeeper/utils.py +++ b/test_runner/fixtures/safekeeper/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log from fixtures.safekeeper.http import SafekeeperHttpClient diff --git a/test_runner/fixtures/slow.py b/test_runner/fixtures/slow.py index ae0e87b553..4c6372d515 100644 --- a/test_runner/fixtures/slow.py +++ b/test_runner/fixtures/slow.py @@ -1,9 +1,15 @@ -from typing import Any, List +from __future__ import annotations + +from typing import TYPE_CHECKING import pytest from _pytest.config import Config from _pytest.config.argparsing import Parser +if TYPE_CHECKING: + from typing import Any + + """ This plugin allows tests to be marked as slow using pytest.mark.slow. By default slow tests are excluded. They need to be specifically requested with the --runslow flag in @@ -21,7 +27,7 @@ def pytest_configure(config: Config): config.addinivalue_line("markers", "slow: mark test as slow to run") -def pytest_collection_modifyitems(config: Config, items: List[Any]): +def pytest_collection_modifyitems(config: Config, items: list[Any]): if config.getoption("--runslow"): # --runslow given in cli: do not skip slow tests return diff --git a/test_runner/fixtures/storage_controller_proxy.py b/test_runner/fixtures/storage_controller_proxy.py index 3477f8b1f2..c174358ef5 100644 --- a/test_runner/fixtures/storage_controller_proxy.py +++ b/test_runner/fixtures/storage_controller_proxy.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import re -from typing import Any, Optional +from typing import TYPE_CHECKING import pytest import requests @@ -10,6 +12,9 @@ from werkzeug.wrappers.response import Response from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Any, Optional + class StorageControllerProxy: def __init__(self, server: HTTPServer): @@ -32,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response: @pytest.fixture(scope="function") -def storage_controller_proxy(make_httpserver): +def storage_controller_proxy(make_httpserver: HTTPServer): """ Proxies requests into the storage controller to the currently selected storage controller instance via `StorageControllerProxy.route_to`. @@ -46,7 +51,7 @@ def storage_controller_proxy(make_httpserver): log.info(f"Storage controller proxy listening on {self.listen}") - def handler(request: Request): + def handler(request: Request) -> Response: if self.route_to is None: log.info(f"Storage controller proxy has no routing configured for {request.url}") return Response("Routing not configured", status=503) diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index 10e8412b19..76575d330c 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import enum import json @@ -7,27 +9,16 @@ import subprocess import tarfile import threading import time +from collections.abc import Iterable from hashlib import sha256 from pathlib import Path -from typing import ( - IO, - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, TypeVar from urllib.parse import urlencode import allure import zstandard from psycopg2.extensions import cursor +from typing_extensions import override from fixtures.log_helper import log from fixtures.pageserver.common_types import ( @@ -36,29 +27,47 @@ from fixtures.pageserver.common_types import ( ) if TYPE_CHECKING: + from collections.abc import Iterable + from typing import IO, Optional + + from fixtures.common_types import TimelineId from fixtures.neon_fixtures import PgBin -from fixtures.common_types import TimelineId + + WaitUntilRet = TypeVar("WaitUntilRet") + Fn = TypeVar("Fn", bound=Callable[..., Any]) - - -def get_self_dir() -> Path: - """Get the path to the directory where this script lives.""" - return Path(__file__).resolve().parent +COMPONENT_BINARIES = { + "storage_controller": ("storage_controller",), + "storage_broker": ("storage_broker",), + "compute": ("compute_ctl",), + "safekeeper": ("safekeeper",), + "pageserver": ("pageserver", "pagectl"), +} +# Disable auto-formatting for better readability +# fmt: off +VERSIONS_COMBINATIONS = ( + {"storage_controller": "new", "storage_broker": "new", "compute": "new", "safekeeper": "new", "pageserver": "new"}, + {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "old"}, + {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "new"}, + {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "new", "pageserver": "new"}, + {"storage_controller": "old", "storage_broker": "old", "compute": "new", "safekeeper": "new", "pageserver": "new"}, +) +# fmt: on def subprocess_capture( capture_dir: Path, - cmd: List[str], + cmd: list[str], *, - check=False, - echo_stderr=False, - echo_stdout=False, - capture_stdout=False, - timeout=None, - with_command_header=True, + check: bool = False, + echo_stderr: bool = False, + echo_stdout: bool = False, + capture_stdout: bool = False, + timeout: Optional[float] = None, + with_command_header: bool = True, **popen_kwargs: Any, -) -> Tuple[str, Optional[str], int]: +) -> tuple[str, Optional[str], int]: """Run a process and bifurcate its output to files and the `log` logger stderr and stdout are always captured in files. They are also optionally @@ -93,6 +102,7 @@ def subprocess_capture( self.capture = capture self.captured = "" + @override def run(self): first = with_command_header for line in self.in_file: @@ -103,7 +113,7 @@ def subprocess_capture( first = False # prefix the files with the command line so that we can # later understand which file is for what command - self.out_file.write((f"# {' '.join(cmd)}\n\n").encode("utf-8")) + self.out_file.write((f"# {' '.join(cmd)}\n\n").encode()) # Only bother decoding if we are going to do something more than stream to a file if self.echo or self.capture: @@ -171,13 +181,13 @@ def global_counter() -> int: return _global_counter -def print_gc_result(row: Dict[str, Any]): +def print_gc_result(row: dict[str, Any]): log.info("GC duration {elapsed} ms".format_map(row)) log.info( - " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" - " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map( - row - ) + ( + " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" + " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}" + ).format_map(row) ) @@ -235,7 +245,7 @@ def get_scale_for_db(size_mb: int) -> int: return round(0.06689 * size_mb - 0.5) -ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] +ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile( r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)" ) @@ -298,7 +308,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz" def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int): """Add links to server logs in Grafana to Allure report""" - links = {} + links: dict[str, str] = {} # We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build endpoint_id, region_id, _ = host.split(".", 2) @@ -309,7 +319,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, "proxy logs": f'{{neon_service="proxy-scram", neon_region="{region_id}"}}', } - params: Dict[str, Any] = { + params: dict[str, Any] = { "datasource": LOGS_STAGING_DATASOURCE_ID, "queries": [ { @@ -350,7 +360,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, def start_in_background( - command: list[str], cwd: Path, log_file_name: str, is_started: Fn + command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet] ) -> subprocess.Popen[bytes]: """Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors.""" @@ -385,14 +395,11 @@ def start_in_background( return spawned_process -WaitUntilRet = TypeVar("WaitUntilRet") - - def wait_until( number_of_iterations: int, interval: float, func: Callable[[], WaitUntilRet], - show_intermediate_error=False, + show_intermediate_error: bool = False, ) -> WaitUntilRet: """ Wait until 'func' returns successfully, without exception. Returns the @@ -425,7 +432,7 @@ def assert_ge(a, b) -> None: assert a >= b -def run_pg_bench_small(pg_bin: "PgBin", connstr: str): +def run_pg_bench_small(pg_bin: PgBin, connstr: str): """ Fast way to populate data. For more layers consider combining with these tenant settings: @@ -470,10 +477,10 @@ def humantime_to_ms(humantime: str) -> float: return round(total_ms, 3) -def scan_log_for_errors(input: Iterable[str], allowed_errors: List[str]) -> List[Tuple[int, str]]: +def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]: # FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py error_or_warn = re.compile(r"\s(ERROR|WARN)") - errors = [] + errors: list[tuple[int, str]] = [] for lineno, line in enumerate(input, start=1): if len(line) == 0: continue @@ -493,7 +500,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: List[str]) -> List return errors -def assert_no_errors(log_file, service, allowed_errors): +def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]): if not log_file.exists(): log.warning(f"Skipping {service} log check: {log_file} does not exist") return @@ -513,14 +520,16 @@ class AuxFileStore(str, enum.Enum): V2 = "v2" CrossValidation = "cross-validation" + @override def __repr__(self) -> str: return f"'aux-{self.value}'" + @override def __str__(self) -> str: return f"'aux-{self.value}'" -def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: Set[str]): +def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str]): """ This is essentially: @@ -534,7 +543,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: Set[str """ started_at = time.time() - def hash_extracted(reader: Union[IO[bytes], None]) -> bytes: + def hash_extracted(reader: Optional[IO[bytes]]) -> bytes: assert reader is not None digest = sha256(usedforsecurity=False) while True: @@ -544,7 +553,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: Set[str digest.update(buf) return digest.digest() - def build_hash_list(p: Path) -> List[Tuple[str, bytes]]: + def build_hash_list(p: Path) -> list[tuple[str, bytes]]: with tarfile.open(p) as f: matching_files = (info for info in f if info.isreg() and info.name not in skip_files) ret = list( @@ -559,7 +568,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: Set[str right_list ), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}" - mismatching = set() + mismatching: set[str] = set() for left_tuple, right_tuple in zip(left_list, right_list): left_path, left_hash = left_tuple @@ -584,6 +593,7 @@ class PropagatingThread(threading.Thread): Simple Thread wrapper with join() propagating the possible exception in the thread. """ + @override def run(self): self.exc = None try: @@ -591,8 +601,9 @@ class PropagatingThread(threading.Thread): except BaseException as e: self.exc = e - def join(self, timeout=None): - super(PropagatingThread, self).join(timeout) + @override + def join(self, timeout: Optional[float] = None) -> Any: + super().join(timeout) if self.exc: raise self.exc return self.ret @@ -613,3 +624,19 @@ def human_bytes(amt: float) -> str: amt = amt / 1024 raise RuntimeError("unreachable") + + +def allpairs_versions(): + """ + Returns a dictionary with arguments for pytest parametrize + to test the compatibility with the previous version of Neon components + combinations were pre-computed to test all the pairs of the components with + the different versions. + """ + ids = [] + for pair in VERSIONS_COMBINATIONS: + cur_id = [] + for component in sorted(pair.keys()): + cur_id.append(pair[component][0]) + ids.append(f"combination_{''.join(cur_id)}") + return {"argnames": "combination", "argvalues": VERSIONS_COMBINATIONS, "ids": ids} diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index 1ea0267e87..e869c43185 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import threading -from typing import Any, Optional +from typing import TYPE_CHECKING from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log @@ -12,6 +14,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.utils import wait_for_last_record_lsn +if TYPE_CHECKING: + from typing import Any, Optional + # neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex # to ensure we don't do that: this enables running lots of Workloads in parallel safely. ENDPOINT_LOCK = threading.Lock() @@ -98,7 +103,7 @@ class Workload: self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id ) - def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True): + def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True): endpoint = self.endpoint(pageserver_id) start = self.expect_rows end = start + n - 1 @@ -119,7 +124,9 @@ class Workload: else: return False - def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True): + def churn_rows( + self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True + ): assert self.expect_rows >= n max_iters = 10 diff --git a/test_runner/logical_repl/test_clickhouse.py b/test_runner/logical_repl/test_clickhouse.py index c5ed9bc8af..8e03bbe5d4 100644 --- a/test_runner/logical_repl/test_clickhouse.py +++ b/test_runner/logical_repl/test_clickhouse.py @@ -2,6 +2,8 @@ Test the logical replication in Neon with ClickHouse as a consumer """ +from __future__ import annotations + import hashlib import os import time diff --git a/test_runner/logical_repl/test_debezium.py b/test_runner/logical_repl/test_debezium.py index 5426a06ca1..d2cb087c92 100644 --- a/test_runner/logical_repl/test_debezium.py +++ b/test_runner/logical_repl/test_debezium.py @@ -2,6 +2,8 @@ Test the logical replication in Neon with Debezium as a consumer """ +from __future__ import annotations + import json import os import time diff --git a/test_runner/performance/__init__.py b/test_runner/performance/__init__.py index e69de29bb2..9d48db4f9f 100644 --- a/test_runner/performance/__init__.py +++ b/test_runner/performance/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/test_runner/performance/pageserver/__init__.py b/test_runner/performance/pageserver/__init__.py index e69de29bb2..9d48db4f9f 100644 --- a/test_runner/performance/pageserver/__init__.py +++ b/test_runner/performance/pageserver/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/test_runner/performance/pageserver/interactive/__init__.py b/test_runner/performance/pageserver/interactive/__init__.py index 29644c240e..1133d116a5 100644 --- a/test_runner/performance/pageserver/interactive/__init__.py +++ b/test_runner/performance/pageserver/interactive/__init__.py @@ -6,3 +6,5 @@ but then debug a performance problem interactively. It's kind of an abuse of the test framework, but, it's our only tool right now to automate a complex test bench setup. """ + +from __future__ import annotations diff --git a/test_runner/performance/pageserver/interactive/test_many_small_tenants.py b/test_runner/performance/pageserver/interactive/test_many_small_tenants.py index 0a5a2c10d6..4931295beb 100644 --- a/test_runner/performance/pageserver/interactive/test_many_small_tenants.py +++ b/test_runner/performance/pageserver/interactive/test_many_small_tenants.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pdb diff --git a/test_runner/performance/pageserver/pagebench/__init__.py b/test_runner/performance/pageserver/pagebench/__init__.py index 9f5e45c0a0..4ed774cf2d 100644 --- a/test_runner/performance/pageserver/pagebench/__init__.py +++ b/test_runner/performance/pageserver/pagebench/__init__.py @@ -8,3 +8,5 @@ instead of benchmarking the full stack. See https://github.com/neondatabase/neon/issues/5771 for the context in which this was developed. """ + +from __future__ import annotations diff --git a/test_runner/performance/pageserver/pagebench/test_large_slru_basebackup.py b/test_runner/performance/pageserver/pagebench/test_large_slru_basebackup.py index c3ba5afc24..efd423104d 100644 --- a/test_runner/performance/pageserver/pagebench/test_large_slru_basebackup.py +++ b/test_runner/performance/pageserver/pagebench/test_large_slru_basebackup.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import asyncio import json from pathlib import Path -from typing import Any, Dict, Tuple +from typing import TYPE_CHECKING import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker @@ -13,6 +15,9 @@ from performance.pageserver.util import ( setup_pageserver_with_tenants, ) +if TYPE_CHECKING: + from typing import Any + @pytest.mark.parametrize("duration", [30]) @pytest.mark.parametrize("pgbench_scale", [get_scale_for_db(200)]) @@ -29,7 +34,7 @@ def test_basebackup_with_high_slru_count( def record(metric, **kwargs): zenbenchmark.record(metric_name=f"pageserver_basebackup.{metric}", **kwargs) - params: Dict[str, Tuple[Any, Dict[str, Any]]] = {} + params: dict[str, tuple[Any, dict[str, Any]]] = {} # params from fixtures params.update( @@ -157,7 +162,7 @@ def run_benchmark(env: NeonEnv, pg_bin: PgBin, record, duration_secs: int): results_path = Path(basepath + ".stdout") log.info(f"Benchmark results at: {results_path}") - with open(results_path, "r") as f: + with open(results_path) as f: results = json.load(f) log.info(f"Results:\n{json.dumps(results, sort_keys=True, indent=2)}") diff --git a/test_runner/performance/pageserver/pagebench/test_ondemand_download_churn.py b/test_runner/performance/pageserver/pagebench/test_ondemand_download_churn.py index 9ad6e7907c..8738f93a06 100644 --- a/test_runner/performance/pageserver/pagebench/test_ondemand_download_churn.py +++ b/test_runner/performance/pageserver/pagebench/test_ondemand_download_churn.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import json from pathlib import Path -from typing import Any, Dict, Tuple +from typing import TYPE_CHECKING import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker @@ -14,6 +16,9 @@ from fixtures.neon_fixtures import ( from fixtures.remote_storage import s3_storage from fixtures.utils import humantime_to_ms +if TYPE_CHECKING: + from typing import Any + @pytest.mark.parametrize("duration", [30]) @pytest.mark.parametrize("io_engine", ["tokio-epoll-uring", "std-fs"]) @@ -30,7 +35,7 @@ def test_download_churn( def record(metric, **kwargs): zenbenchmark.record(metric_name=f"pageserver_ondemand_download_churn.{metric}", **kwargs) - params: Dict[str, Tuple[Any, Dict[str, Any]]] = {} + params: dict[str, tuple[Any, dict[str, Any]]] = {} # params from fixtures params.update( @@ -134,7 +139,7 @@ def run_benchmark( results_path = Path(basepath + ".stdout") log.info(f"Benchmark results at: {results_path}") - with open(results_path, "r") as f: + with open(results_path) as f: results = json.load(f) log.info(f"Results:\n{json.dumps(results, sort_keys=True, indent=2)}") diff --git a/test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py b/test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py index 97eed88473..c038fc3fd2 100644 --- a/test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py +++ b/test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import json import os from pathlib import Path -from typing import Any, Dict, Tuple +from typing import TYPE_CHECKING import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker @@ -18,6 +20,10 @@ from performance.pageserver.util import ( setup_pageserver_with_tenants, ) +if TYPE_CHECKING: + from typing import Any + + # The following tests use pagebench "getpage at latest LSN" to characterize the throughput of the pageserver. # originally there was a single test named `test_pageserver_max_throughput_getpage_at_latest_lsn`` # so you still see some references to this name in the code. @@ -92,7 +98,7 @@ def setup_and_run_pagebench_benchmark( metric_name=f"pageserver_max_throughput_getpage_at_latest_lsn.{metric}", **kwargs ) - params: Dict[str, Tuple[Any, Dict[str, Any]]] = {} + params: dict[str, tuple[Any, dict[str, Any]]] = {} # params from fixtures params.update( @@ -225,7 +231,7 @@ def run_pagebench_benchmark( results_path = Path(basepath + ".stdout") log.info(f"Benchmark results at: {results_path}") - with open(results_path, "r") as f: + with open(results_path) as f: results = json.load(f) log.info(f"Results:\n{json.dumps(results, sort_keys=True, indent=2)}") diff --git a/test_runner/performance/pageserver/util.py b/test_runner/performance/pageserver/util.py index 88296a7fbd..227319c425 100644 --- a/test_runner/performance/pageserver/util.py +++ b/test_runner/performance/pageserver/util.py @@ -2,7 +2,9 @@ Utilities used by all code in this sub-directory """ -from typing import Any, Callable, Dict, Optional, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING import fixtures.pageserver.many_tenants as many_tenants from fixtures.common_types import TenantId, TimelineId @@ -13,6 +15,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.utils import wait_until_all_tenants_state +if TYPE_CHECKING: + from typing import Any, Callable, Optional + def ensure_pageserver_ready_for_benchmarking(env: NeonEnv, n_tenants: int): """ @@ -40,7 +45,7 @@ def setup_pageserver_with_tenants( neon_env_builder: NeonEnvBuilder, name: str, n_tenants: int, - setup: Callable[[NeonEnv], Tuple[TenantId, TimelineId, Dict[str, Any]]], + setup: Callable[[NeonEnv], tuple[TenantId, TimelineId, dict[str, Any]]], timeout_in_seconds: Optional[int] = None, ) -> NeonEnv: """ diff --git a/test_runner/performance/pgvector/loaddata.py b/test_runner/performance/pgvector/loaddata.py index 36c209aed3..207f5657fc 100644 --- a/test_runner/performance/pgvector/loaddata.py +++ b/test_runner/performance/pgvector/loaddata.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import sys from pathlib import Path -import numpy as np -import pandas as pd +import numpy as np # type: ignore [import] +import pandas as pd # type: ignore [import] import psycopg2 -from pgvector.psycopg2 import register_vector +from pgvector.psycopg2 import register_vector # type: ignore [import] from psycopg2.extras import execute_values diff --git a/test_runner/performance/test_branch_creation.py b/test_runner/performance/test_branch_creation.py index 1fdb06785b..c50c4ad432 100644 --- a/test_runner/performance/test_branch_creation.py +++ b/test_runner/performance/test_branch_creation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random import re import statistics @@ -5,7 +7,6 @@ import threading import time import timeit from contextlib import closing -from typing import List import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker @@ -18,7 +19,7 @@ from fixtures.utils import wait_until from prometheus_client.samples import Sample -def _record_branch_creation_durations(neon_compare: NeonCompare, durs: List[float]): +def _record_branch_creation_durations(neon_compare: NeonCompare, durs: list[float]): neon_compare.zenbenchmark.record( "branch_creation_duration_max", max(durs), "s", MetricReport.LOWER_IS_BETTER ) @@ -66,7 +67,7 @@ def test_branch_creation_heavy_write(neon_compare: NeonCompare, n_branches: int) env.create_branch("b0", tenant_id=tenant) - threads: List[threading.Thread] = [] + threads: list[threading.Thread] = [] threads.append(threading.Thread(target=run_pgbench, args=("b0",), daemon=True)) threads[-1].start() @@ -194,7 +195,7 @@ def wait_and_record_startup_metrics( ] ) - def metrics_are_filled() -> List[Sample]: + def metrics_are_filled() -> list[Sample]: m = client.get_metrics() samples = m.query_all("pageserver_startup_duration_seconds") # we should not have duplicate labels diff --git a/test_runner/performance/test_branching.py b/test_runner/performance/test_branching.py index 36c821795a..dbff116360 100644 --- a/test_runner/performance/test_branching.py +++ b/test_runner/performance/test_branching.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import timeit from pathlib import Path -from typing import List from fixtures.benchmark_fixture import PgBenchRunResult from fixtures.compare_fixtures import NeonCompare @@ -22,7 +23,7 @@ def test_compare_child_and_root_pgbench_perf(neon_compare: NeonCompare): env = neon_compare.env pg_bin = neon_compare.pg_bin - def run_pgbench_on_branch(branch: str, cmd: List[str]): + def run_pgbench_on_branch(branch: str, cmd: list[str]): run_start_timestamp = utc_now_timestamp() t0 = timeit.default_timer() out = pg_bin.run_capture( diff --git a/test_runner/performance/test_bulk_insert.py b/test_runner/performance/test_bulk_insert.py index 69df7974b9..36090dcad7 100644 --- a/test_runner/performance/test_bulk_insert.py +++ b/test_runner/performance/test_bulk_insert.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import closing from fixtures.benchmark_fixture import MetricReport diff --git a/test_runner/performance/test_bulk_tenant_create.py b/test_runner/performance/test_bulk_tenant_create.py index 188ff5e3ad..15a03ba456 100644 --- a/test_runner/performance/test_bulk_tenant_create.py +++ b/test_runner/performance/test_bulk_tenant_create.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import timeit import pytest diff --git a/test_runner/performance/test_bulk_update.py b/test_runner/performance/test_bulk_update.py index 13c48e1174..6946bc66f2 100644 --- a/test_runner/performance/test_bulk_update.py +++ b/test_runner/performance/test_bulk_update.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_last_flush_lsn diff --git a/test_runner/performance/test_compaction.py b/test_runner/performance/test_compaction.py index 54b17ebf8a..8868dddf39 100644 --- a/test_runner/performance/test_compaction.py +++ b/test_runner/performance/test_compaction.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import closing import pytest diff --git a/test_runner/performance/test_compare_pg_stats.py b/test_runner/performance/test_compare_pg_stats.py index d5dd1b4bd0..a86995d6d3 100644 --- a/test_runner/performance/test_compare_pg_stats.py +++ b/test_runner/performance/test_compare_pg_stats.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import os import threading import time -from typing import List import pytest from fixtures.compare_fixtures import PgCompare @@ -23,7 +24,7 @@ def test_compare_pg_stats_rw_with_pgbench_default( seed: int, scale: int, duration: int, - pg_stats_rw: List[PgStatTable], + pg_stats_rw: list[PgStatTable], ): env = neon_with_baseline # initialize pgbench @@ -45,7 +46,7 @@ def test_compare_pg_stats_wo_with_pgbench_simple_update( seed: int, scale: int, duration: int, - pg_stats_wo: List[PgStatTable], + pg_stats_wo: list[PgStatTable], ): env = neon_with_baseline # initialize pgbench @@ -67,7 +68,7 @@ def test_compare_pg_stats_ro_with_pgbench_select_only( seed: int, scale: int, duration: int, - pg_stats_ro: List[PgStatTable], + pg_stats_ro: list[PgStatTable], ): env = neon_with_baseline # initialize pgbench @@ -89,7 +90,7 @@ def test_compare_pg_stats_wal_with_pgbench_default( seed: int, scale: int, duration: int, - pg_stats_wal: List[PgStatTable], + pg_stats_wal: list[PgStatTable], ): env = neon_with_baseline # initialize pgbench @@ -106,7 +107,7 @@ def test_compare_pg_stats_wal_with_pgbench_default( @pytest.mark.parametrize("n_tables", [1, 10]) @pytest.mark.parametrize("duration", get_durations_matrix(10)) def test_compare_pg_stats_wo_with_heavy_write( - neon_with_baseline: PgCompare, n_tables: int, duration: int, pg_stats_wo: List[PgStatTable] + neon_with_baseline: PgCompare, n_tables: int, duration: int, pg_stats_wo: list[PgStatTable] ): env = neon_with_baseline with env.pg.connect().cursor() as cur: diff --git a/test_runner/performance/test_copy.py b/test_runner/performance/test_copy.py index a91c78e867..743604a381 100644 --- a/test_runner/performance/test_copy.py +++ b/test_runner/performance/test_copy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import closing from io import BufferedReader, RawIOBase from typing import Optional diff --git a/test_runner/performance/test_dup_key.py b/test_runner/performance/test_dup_key.py index 60a4d91313..f7e4a629d6 100644 --- a/test_runner/performance/test_dup_key.py +++ b/test_runner/performance/test_dup_key.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import closing import pytest diff --git a/test_runner/performance/test_gc_feedback.py b/test_runner/performance/test_gc_feedback.py index 2ba1018b33..07f244da0c 100644 --- a/test_runner/performance/test_gc_feedback.py +++ b/test_runner/performance/test_gc_feedback.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import pytest diff --git a/test_runner/performance/test_gist_build.py b/test_runner/performance/test_gist_build.py index 45900d0c7f..e8ef59722d 100644 --- a/test_runner/performance/test_gist_build.py +++ b/test_runner/performance/test_gist_build.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import closing from fixtures.compare_fixtures import PgCompare diff --git a/test_runner/performance/test_hot_page.py b/test_runner/performance/test_hot_page.py index 5e97c7cddf..d025566919 100644 --- a/test_runner/performance/test_hot_page.py +++ b/test_runner/performance/test_hot_page.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import closing import pytest diff --git a/test_runner/performance/test_hot_table.py b/test_runner/performance/test_hot_table.py index 9a78c92ec0..792d35321d 100644 --- a/test_runner/performance/test_hot_table.py +++ b/test_runner/performance/test_hot_table.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import closing import pytest diff --git a/test_runner/performance/test_latency.py b/test_runner/performance/test_latency.py index 6c94ecc482..133a2cfd8a 100644 --- a/test_runner/performance/test_latency.py +++ b/test_runner/performance/test_latency.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import threading import pytest diff --git a/test_runner/performance/test_layer_map.py b/test_runner/performance/test_layer_map.py index fb2ac14a92..8a4ad2d399 100644 --- a/test_runner/performance/test_layer_map.py +++ b/test_runner/performance/test_layer_map.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver diff --git a/test_runner/performance/test_lazy_startup.py b/test_runner/performance/test_lazy_startup.py index 5af10bc491..704073fe3b 100644 --- a/test_runner/performance/test_lazy_startup.py +++ b/test_runner/performance/test_lazy_startup.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest import requests from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker diff --git a/test_runner/performance/test_parallel_copy_to.py b/test_runner/performance/test_parallel_copy_to.py index 9a0b7723ac..ddee0ebcd1 100644 --- a/test_runner/performance/test_parallel_copy_to.py +++ b/test_runner/performance/test_parallel_copy_to.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from io import BytesIO diff --git a/test_runner/performance/test_perf_olap.py b/test_runner/performance/test_perf_olap.py index aaa2f8fec2..bc4ab64105 100644 --- a/test_runner/performance/test_perf_olap.py +++ b/test_runner/performance/test_perf_olap.py @@ -1,12 +1,13 @@ +from __future__ import annotations + import os from dataclasses import dataclass -from typing import Dict, List, Tuple +from pathlib import Path import pytest from _pytest.mark import ParameterSet from fixtures.compare_fixtures import RemoteCompare from fixtures.log_helper import log -from fixtures.utils import get_self_dir @dataclass @@ -45,7 +46,7 @@ def test_clickbench_create_pg_stat_statements(remote_compare: RemoteCompare): # # Disable auto formatting for the list of queries so that it's easier to read # fmt: off -QUERIES: Tuple[LabelledQuery, ...] = ( +QUERIES: tuple[LabelledQuery, ...] = ( ### ClickBench queries: LabelledQuery("Q0", r"SELECT COUNT(*) FROM hits;"), LabelledQuery("Q1", r"SELECT COUNT(*) FROM hits WHERE AdvEngineID <> 0;"), @@ -105,7 +106,7 @@ QUERIES: Tuple[LabelledQuery, ...] = ( # # Disable auto formatting for the list of queries so that it's easier to read # fmt: off -PGVECTOR_QUERIES: Tuple[LabelledQuery, ...] = ( +PGVECTOR_QUERIES: tuple[LabelledQuery, ...] = ( LabelledQuery("PGVPREP", r"ALTER EXTENSION VECTOR UPDATE;"), LabelledQuery("PGV0", r"DROP TABLE IF EXISTS hnsw_test_table;"), LabelledQuery("PGV1", r"CREATE TABLE hnsw_test_table AS TABLE documents WITH NO DATA;"), @@ -127,7 +128,7 @@ PGVECTOR_QUERIES: Tuple[LabelledQuery, ...] = ( EXPLAIN_STRING: str = "EXPLAIN (ANALYZE, VERBOSE, BUFFERS, COSTS, SETTINGS, FORMAT JSON)" -def get_scale() -> List[str]: +def get_scale() -> list[str]: # We parametrize each tpc-h and clickbench test with scale # to distinguish them from each other, but don't really use it inside. # Databases are pre-created and passed through BENCHMARK_CONNSTR env variable. @@ -147,7 +148,7 @@ def run_psql( options = f"-cstatement_timeout=0 {env.pg.default_options.get('options', '')}" connstr = env.pg.connstr(password=None, options=options) - environ: Dict[str, str] = {} + environ: dict[str, str] = {} if password is not None: environ["PGPASSWORD"] = password @@ -185,13 +186,13 @@ def test_clickbench(query: LabelledQuery, remote_compare: RemoteCompare, scale: run_psql(remote_compare, query, times=3, explain=explain) -def tpch_queuies() -> Tuple[ParameterSet, ...]: +def tpch_queuies() -> tuple[ParameterSet, ...]: """ A list of queries to run for the TPC-H benchmark. - querues in returning tuple are ordered by the query number - pytest parameters id is adjusted to match the query id (the numbering starts from 1) """ - queries_dir = get_self_dir().parent / "performance" / "tpc-h" / "queries" + queries_dir = Path(__file__).parent / "tpc-h" / "queries" assert queries_dir.exists(), f"TPC-H queries dir not found: {queries_dir}" return tuple( diff --git a/test_runner/performance/test_perf_pgbench.py b/test_runner/performance/test_perf_pgbench.py index 6eaa29e4f8..24ff3d23fa 100644 --- a/test_runner/performance/test_perf_pgbench.py +++ b/test_runner/performance/test_perf_pgbench.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import calendar import enum import os import timeit from datetime import datetime from pathlib import Path -from typing import Dict, List import pytest from fixtures.benchmark_fixture import MetricReport, PgBenchInitResult, PgBenchRunResult @@ -26,7 +27,7 @@ def utc_now_timestamp() -> int: def init_pgbench(env: PgCompare, cmdline, password: None): - environ: Dict[str, str] = {} + environ: dict[str, str] = {} if password is not None: environ["PGPASSWORD"] = password @@ -54,7 +55,7 @@ def init_pgbench(env: PgCompare, cmdline, password: None): def run_pgbench(env: PgCompare, prefix: str, cmdline, password: None): - environ: Dict[str, str] = {} + environ: dict[str, str] = {} if password is not None: environ["PGPASSWORD"] = password @@ -177,7 +178,7 @@ def run_test_pgbench(env: PgCompare, scale: int, duration: int, workload_type: P env.report_size() -def get_durations_matrix(default: int = 45) -> List[int]: +def get_durations_matrix(default: int = 45) -> list[int]: durations = os.getenv("TEST_PG_BENCH_DURATIONS_MATRIX", default=str(default)) rv = [] for d in durations.split(","): @@ -193,7 +194,7 @@ def get_durations_matrix(default: int = 45) -> List[int]: return rv -def get_scales_matrix(default: int = 10) -> List[int]: +def get_scales_matrix(default: int = 10) -> list[int]: scales = os.getenv("TEST_PG_BENCH_SCALES_MATRIX", default=str(default)) rv = [] for s in scales.split(","): diff --git a/test_runner/performance/test_perf_pgvector_queries.py b/test_runner/performance/test_perf_pgvector_queries.py index bb3db16305..4a5ea94c4b 100644 --- a/test_runner/performance/test_perf_pgvector_queries.py +++ b/test_runner/performance/test_perf_pgvector_queries.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.compare_fixtures import PgCompare diff --git a/test_runner/performance/test_physical_replication.py b/test_runner/performance/test_physical_replication.py index 49b1176d34..14b527acca 100644 --- a/test_runner/performance/test_physical_replication.py +++ b/test_runner/performance/test_physical_replication.py @@ -18,7 +18,7 @@ from fixtures.neon_api import connection_parameters_to_env from fixtures.pg_version import PgVersion if TYPE_CHECKING: - from typing import Any, List, Optional + from typing import Any, Optional from fixtures.benchmark_fixture import NeonBenchmarker from fixtures.neon_api import NeonAPI @@ -233,7 +233,7 @@ def test_replication_start_stop( ], env=master_env, ) - replica_pgbench: List[Optional[subprocess.Popen[Any]]] = [None for _ in range(num_replicas)] + replica_pgbench: list[Optional[subprocess.Popen[Any]]] = [None for _ in range(num_replicas)] # Use the bits of iconfig to tell us which configuration we are on. For example # a iconfig of 2 is 10 in binary, indicating replica 0 is suspended and replica 1 is diff --git a/test_runner/performance/test_random_writes.py b/test_runner/performance/test_random_writes.py index c1a59ebb31..46848a8af8 100644 --- a/test_runner/performance/test_random_writes.py +++ b/test_runner/performance/test_random_writes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random from contextlib import closing diff --git a/test_runner/performance/test_seqscans.py b/test_runner/performance/test_seqscans.py index 67d4f3ae9b..36ee4eb201 100644 --- a/test_runner/performance/test_seqscans.py +++ b/test_runner/performance/test_seqscans.py @@ -1,5 +1,8 @@ # Test sequential scan speed # + +from __future__ import annotations + from contextlib import closing import pytest diff --git a/test_runner/performance/test_sharding_autosplit.py b/test_runner/performance/test_sharding_autosplit.py index 35793e41d7..caa89955e3 100644 --- a/test_runner/performance/test_sharding_autosplit.py +++ b/test_runner/performance/test_sharding_autosplit.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import concurrent.futures import re from pathlib import Path diff --git a/test_runner/performance/test_startup.py b/test_runner/performance/test_startup.py index 514d8bae2a..d051717e92 100644 --- a/test_runner/performance/test_startup.py +++ b/test_runner/performance/test_startup.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import requests from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/performance/test_storage_controller_scale.py b/test_runner/performance/test_storage_controller_scale.py index a186bbaceb..452a856714 100644 --- a/test_runner/performance/test_storage_controller_scale.py +++ b/test_runner/performance/test_storage_controller_scale.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import concurrent.futures import random import time diff --git a/test_runner/performance/test_wal_backpressure.py b/test_runner/performance/test_wal_backpressure.py index c824e60c29..576a4f0467 100644 --- a/test_runner/performance/test_wal_backpressure.py +++ b/test_runner/performance/test_wal_backpressure.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import statistics import threading import time import timeit -from typing import Any, Callable, Generator, List +from collections.abc import Generator +from typing import TYPE_CHECKING import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker @@ -13,6 +16,9 @@ from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, flush_ep_to_pageserver from performance.test_perf_pgbench import get_durations_matrix, get_scales_matrix +if TYPE_CHECKING: + from typing import Any, Callable + @pytest.fixture(params=["vanilla", "neon_off", "neon_on"]) # This fixture constructs multiple `PgCompare` interfaces using a builder pattern. @@ -202,7 +208,7 @@ def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_inte if not isinstance(env, NeonCompare): return - lsn_write_lags: List[Any] = [] + lsn_write_lags: list[Any] = [] last_received_lsn = Lsn(0) last_pg_flush_lsn = Lsn(0) diff --git a/test_runner/performance/test_write_amplification.py b/test_runner/performance/test_write_amplification.py index 3e290b3996..87824604f8 100644 --- a/test_runner/performance/test_write_amplification.py +++ b/test_runner/performance/test_write_amplification.py @@ -10,6 +10,9 @@ # in LSN order, writing the oldest layer first. That creates a new 10 MB image # layer to be created for each of those small updates. This is the Write # Amplification problem at its finest. + +from __future__ import annotations + from contextlib import closing from fixtures.compare_fixtures import PgCompare diff --git a/test_runner/pg_clients/python/asyncpg/asyncpg_example.py b/test_runner/pg_clients/python/asyncpg/asyncpg_example.py index de86fe482d..9077a07444 100755 --- a/test_runner/pg_clients/python/asyncpg/asyncpg_example.py +++ b/test_runner/pg_clients/python/asyncpg/asyncpg_example.py @@ -1,5 +1,7 @@ #! /usr/bin/env python3 +from __future__ import annotations + import asyncio import os diff --git a/test_runner/pg_clients/python/pg8000/pg8000_example.py b/test_runner/pg_clients/python/pg8000/pg8000_example.py index 840ed97c97..2e92806602 100755 --- a/test_runner/pg_clients/python/pg8000/pg8000_example.py +++ b/test_runner/pg_clients/python/pg8000/pg8000_example.py @@ -1,5 +1,7 @@ #! /usr/bin/env python3 +from __future__ import annotations + import os import pg8000.dbapi diff --git a/test_runner/pg_clients/test_pg_clients.py b/test_runner/pg_clients/test_pg_clients.py index 3579c92b0c..ffa710da06 100644 --- a/test_runner/pg_clients/test_pg_clients.py +++ b/test_runner/pg_clients/test_pg_clients.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import shutil from pathlib import Path from tempfile import NamedTemporaryFile diff --git a/test_runner/regress/test_ancestor_branch.py b/test_runner/regress/test_ancestor_branch.py index 67a38ab471..8cd49d480f 100644 --- a/test_runner/regress/test_ancestor_branch.py +++ b/test_runner/regress/test_ancestor_branch.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.common_types import TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_attach_tenant_config.py b/test_runner/regress/test_attach_tenant_config.py index a4e557a863..4a7017994d 100644 --- a/test_runner/regress/test_attach_tenant_config.py +++ b/test_runner/regress/test_attach_tenant_config.py @@ -1,5 +1,8 @@ +from __future__ import annotations + +from collections.abc import Generator from dataclasses import dataclass -from typing import Generator, Optional +from typing import Optional import pytest from fixtures.common_types import TenantId diff --git a/test_runner/regress/test_auth.py b/test_runner/regress/test_auth.py index 6b06092183..eba8197116 100644 --- a/test_runner/regress/test_auth.py +++ b/test_runner/regress/test_auth.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from contextlib import closing from pathlib import Path diff --git a/test_runner/regress/test_aux_files.py b/test_runner/regress/test_aux_files.py index 5328aef156..91d674d0db 100644 --- a/test_runner/regress/test_aux_files.py +++ b/test_runner/regress/test_aux_files.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.log_helper import log from fixtures.neon_fixtures import ( AuxFileStore, diff --git a/test_runner/regress/test_backpressure.py b/test_runner/regress/test_backpressure.py index 3d7a52ca77..c75419b786 100644 --- a/test_runner/regress/test_backpressure.py +++ b/test_runner/regress/test_backpressure.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import threading import time from contextlib import closing, contextmanager diff --git a/test_runner/regress/test_bad_connection.py b/test_runner/regress/test_bad_connection.py index 98842e64f4..c0c9537421 100644 --- a/test_runner/regress/test_bad_connection.py +++ b/test_runner/regress/test_bad_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random import time diff --git a/test_runner/regress/test_basebackup_error.py b/test_runner/regress/test_basebackup_error.py index 13c080ea0e..2dd1a88ad7 100644 --- a/test_runner/regress/test_basebackup_error.py +++ b/test_runner/regress/test_basebackup_error.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.neon_fixtures import NeonEnv diff --git a/test_runner/regress/test_branch_and_gc.py b/test_runner/regress/test_branch_and_gc.py index afeea55fc2..6d1565c5e5 100644 --- a/test_runner/regress/test_branch_and_gc.py +++ b/test_runner/regress/test_branch_and_gc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import threading import time diff --git a/test_runner/regress/test_branch_behind.py b/test_runner/regress/test_branch_behind.py index cceb7b3d60..619fc15aa3 100644 --- a/test_runner/regress/test_branch_behind.py +++ b/test_runner/regress/test_branch_behind.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.common_types import Lsn, TimelineId from fixtures.log_helper import log diff --git a/test_runner/regress/test_branching.py b/test_runner/regress/test_branching.py index 8d07dfd511..34e4e994cb 100644 --- a/test_runner/regress/test_branching.py +++ b/test_runner/regress/test_branching.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import random import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import List import pytest from fixtures.common_types import Lsn, TimelineId @@ -56,10 +57,10 @@ def test_branching_with_pgbench( pg_bin.run_capture(["pgbench", "-T15", connstr]) env.create_branch("b0", tenant_id=tenant) - endpoints: List[Endpoint] = [] + endpoints: list[Endpoint] = [] endpoints.append(env.endpoints.create_start("b0", tenant_id=tenant)) - threads: List[threading.Thread] = [] + threads: list[threading.Thread] = [] threads.append( threading.Thread(target=run_pgbench, args=(endpoints[0].connstr(),), daemon=True) ) diff --git a/test_runner/regress/test_broken_timeline.py b/test_runner/regress/test_broken_timeline.py index 6b6af481aa..99e0e23b4a 100644 --- a/test_runner/regress/test_broken_timeline.py +++ b/test_runner/regress/test_broken_timeline.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import concurrent.futures import os -from typing import List, Tuple import pytest from fixtures.common_types import TenantId, TimelineId @@ -31,7 +32,7 @@ def test_local_corruption(neon_env_builder: NeonEnvBuilder): ] ) - tenant_timelines: List[Tuple[TenantId, TimelineId, Endpoint]] = [] + tenant_timelines: list[tuple[TenantId, TimelineId, Endpoint]] = [] for _ in range(3): tenant_id, timeline_id = env.create_tenant() diff --git a/test_runner/regress/test_build_info_metric.py b/test_runner/regress/test_build_info_metric.py index 8f714dae67..9a8744571a 100644 --- a/test_runner/regress/test_build_info_metric.py +++ b/test_runner/regress/test_build_info_metric.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.metrics import parse_metrics from fixtures.neon_fixtures import NeonEnvBuilder, NeonProxy diff --git a/test_runner/regress/test_change_pageserver.py b/test_runner/regress/test_change_pageserver.py index d3aa49f374..41aa5b47ca 100644 --- a/test_runner/regress/test_change_pageserver.py +++ b/test_runner/regress/test_change_pageserver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from fixtures.log_helper import log diff --git a/test_runner/regress/test_clog_truncate.py b/test_runner/regress/test_clog_truncate.py index bfce795d14..10027ce689 100644 --- a/test_runner/regress/test_clog_truncate.py +++ b/test_runner/regress/test_clog_truncate.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import time diff --git a/test_runner/regress/test_close_fds.py b/test_runner/regress/test_close_fds.py index 3957d0b3b0..c0bf7d2462 100644 --- a/test_runner/regress/test_close_fds.py +++ b/test_runner/regress/test_close_fds.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os.path import shutil import subprocess @@ -39,9 +41,8 @@ def test_lsof_pageserver_pid(neon_simple_env: NeonEnv): res = subprocess.run( [lsof, path], check=False, - universal_newlines=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + text=True, + capture_output=True, ) # parse the `lsof` command's output to get only the list of commands diff --git a/test_runner/regress/test_combocid.py b/test_runner/regress/test_combocid.py index 41907b1f20..57d5b2d8b3 100644 --- a/test_runner/regress/test_combocid.py +++ b/test_runner/regress/test_combocid.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 98bd3a6a5f..420055ac3a 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import enum import json import os import time -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log @@ -14,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + AGGRESIVE_COMPACTION_TENANT_CONF = { # Disable gc and compaction. The test runs compaction manually. "gc_period": "0s", diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index 1f960b6b75..96ba3dd5a4 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import re import shutil @@ -5,8 +7,9 @@ import subprocess import tempfile from dataclasses import dataclass from pathlib import Path -from typing import List, Optional +from typing import TYPE_CHECKING +import fixtures.utils import pytest import toml from fixtures.common_types import TenantId, TimelineId @@ -25,6 +28,10 @@ from fixtures.pg_version import PgVersion from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + # # A test suite that help to prevent unintentionally breaking backward or forward compatibility between Neon releases. # - `test_create_snapshot` a script wrapped in a test that creates a data snapshot. @@ -87,6 +94,34 @@ from fixtures.workload import Workload # # Run forward compatibility test # ./scripts/pytest -k test_forward_compatibility # +# +# How to run `test_version_mismatch` locally: +# +# export DEFAULT_PG_VERSION=16 +# export BUILD_TYPE=release +# export CHECK_ONDISK_DATA_COMPATIBILITY=true +# export COMPATIBILITY_NEON_BIN=neon_previous/target/${BUILD_TYPE} +# export COMPATIBILITY_POSTGRES_DISTRIB_DIR=neon_previous/pg_install +# export NEON_BIN=target/release +# export POSTGRES_DISTRIB_DIR=pg_install +# +# # Build previous version of binaries and store them somewhere: +# rm -rf pg_install target +# git checkout +# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc` +# mkdir -p neon_previous/target +# cp -a target/${BUILD_TYPE} ./neon_previous/target/${BUILD_TYPE} +# cp -a pg_install ./neon_previous/pg_install +# +# # Build current version of binaries and create a data snapshot: +# rm -rf pg_install target +# git checkout +# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc` +# ./scripts/pytest -k test_create_snapshot +# +# # Run the version mismatch test +# ./scripts/pytest -k test_version_mismatch + check_ondisk_data_compatibility_if_enabled = pytest.mark.skipif( os.environ.get("CHECK_ONDISK_DATA_COMPATIBILITY") is None, @@ -160,16 +195,11 @@ def test_backward_compatibility( neon_env_builder: NeonEnvBuilder, test_output_dir: Path, pg_version: PgVersion, + compatibility_snapshot_dir: Path, ): """ Test that the new binaries can read old data """ - compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR") - assert ( - compatibility_snapshot_dir_env is not None - ), f"COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg{pg_version.v_prefixed}` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)" - compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve() - breaking_changes_allowed = ( os.environ.get("ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true" ) @@ -208,27 +238,11 @@ def test_forward_compatibility( test_output_dir: Path, top_output_dir: Path, pg_version: PgVersion, + compatibility_snapshot_dir: Path, ): """ Test that the old binaries can read new data """ - compatibility_neon_bin_env = os.environ.get("COMPATIBILITY_NEON_BIN") - assert compatibility_neon_bin_env is not None, ( - "COMPATIBILITY_NEON_BIN is not set. It should be set to a path with Neon binaries " - "(ideally generated by the previous version of Neon)" - ) - compatibility_neon_bin = Path(compatibility_neon_bin_env).resolve() - - compatibility_postgres_distrib_dir_env = os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR") - assert ( - compatibility_postgres_distrib_dir_env is not None - ), "COMPATIBILITY_POSTGRES_DISTRIB_DIR is not set. It should be set to a pg_install directrory (ideally generated by the previous version of Neon)" - compatibility_postgres_distrib_dir = Path(compatibility_postgres_distrib_dir_env).resolve() - - compatibility_snapshot_dir = ( - top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}" - ) - breaking_changes_allowed = ( os.environ.get("ALLOW_FORWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true" ) @@ -239,9 +253,14 @@ def test_forward_compatibility( # Use previous version's production binaries (pageserver, safekeeper, pg_distrib_dir, etc.). # But always use the current version's neon_local binary. # This is because we want to test the compatibility of the data format, not the compatibility of the neon_local CLI. - neon_env_builder.neon_binpath = compatibility_neon_bin - neon_env_builder.pg_distrib_dir = compatibility_postgres_distrib_dir - neon_env_builder.neon_local_binpath = neon_env_builder.neon_local_binpath + assert ( + neon_env_builder.compatibility_neon_binpath is not None + ), "the environment variable COMPATIBILITY_NEON_BIN is required" + assert ( + neon_env_builder.compatibility_pg_distrib_dir is not None + ), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required" + neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath + neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir env = neon_env_builder.from_repo_dir( compatibility_snapshot_dir / "repo", @@ -366,7 +385,7 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r def dump_differs( - first: Path, second: Path, output: Path, allowed_diffs: Optional[List[str]] = None + first: Path, second: Path, output: Path, allowed_diffs: Optional[list[str]] = None ) -> bool: """ Runs diff(1) command on two SQL dumps and write the output to the given output file. @@ -552,3 +571,29 @@ def test_historic_storage_formats( env.pageserver.http_client().timeline_compact( dataset.tenant_id, existing_timeline_id, force_image_layer_creation=True ) + + +@check_ondisk_data_compatibility_if_enabled +@pytest.mark.xdist_group("compatibility") +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_versions_mismatch( + neon_env_builder: NeonEnvBuilder, + test_output_dir: Path, + pg_version: PgVersion, + compatibility_snapshot_dir, + combination, +): + """ + Checks compatibility of different combinations of versions of the components + """ + neon_env_builder.num_safekeepers = 3 + env = neon_env_builder.from_repo_dir( + compatibility_snapshot_dir / "repo", + ) + env.pageserver.allowed_errors.extend( + [".*ingesting record with timestamp lagging more than wait_lsn_timeout.+"] + ) + env.start() + check_neon_works( + env, test_output_dir, compatibility_snapshot_dir / "dump.sql", test_output_dir / "repo" + ) diff --git a/test_runner/regress/test_compute_catalog.py b/test_runner/regress/test_compute_catalog.py index 8b8c970357..d43c71ceac 100644 --- a/test_runner/regress/test_compute_catalog.py +++ b/test_runner/regress/test_compute_catalog.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import requests from fixtures.neon_fixtures import NeonEnv diff --git a/test_runner/regress/test_compute_metrics.py b/test_runner/regress/test_compute_metrics.py index 6138c322d7..6c75765632 100644 --- a/test_runner/regress/test_compute_metrics.py +++ b/test_runner/regress/test_compute_metrics.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnv diff --git a/test_runner/regress/test_config.py b/test_runner/regress/test_config.py index 5aba1f265f..d48fd01fcb 100644 --- a/test_runner/regress/test_config.py +++ b/test_runner/regress/test_config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from contextlib import closing @@ -66,7 +68,7 @@ def test_safekeepers_reconfigure_reorder( assert new_sks != old_sks, "GUC changes were applied" log_path = os.path.join(endpoint.endpoint_path(), "compute.log") - with open(log_path, "r") as log_file: + with open(log_path) as log_file: logs = log_file.read() # Check that walproposer was not restarted assert "restarting walproposer" not in logs diff --git a/test_runner/regress/test_crafted_wal_end.py b/test_runner/regress/test_crafted_wal_end.py index 71369ab131..23c6fa3a5a 100644 --- a/test_runner/regress/test_crafted_wal_end.py +++ b/test_runner/regress/test_crafted_wal_end.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.log_helper import log from fixtures.neon_cli import WalCraft diff --git a/test_runner/regress/test_createdropdb.py b/test_runner/regress/test_createdropdb.py index cdf048ac26..97e185ceb5 100644 --- a/test_runner/regress/test_createdropdb.py +++ b/test_runner/regress/test_createdropdb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pathlib diff --git a/test_runner/regress/test_createuser.py b/test_runner/regress/test_createuser.py index 96b38f8fb0..236f4eb2fe 100644 --- a/test_runner/regress/test_createuser.py +++ b/test_runner/regress/test_createuser.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnv from fixtures.utils import query_scalar diff --git a/test_runner/regress/test_ddl_forwarding.py b/test_runner/regress/test_ddl_forwarding.py index 65f310c27a..96657b3ce4 100644 --- a/test_runner/regress/test_ddl_forwarding.py +++ b/test_runner/regress/test_ddl_forwarding.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from types import TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING import psycopg2 import pytest @@ -9,6 +11,9 @@ from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response +if TYPE_CHECKING: + from typing import Any, Optional + def handle_db(dbs, roles, operation): if operation["op"] == "set": @@ -43,7 +48,7 @@ def handle_role(dbs, roles, operation): def ddl_forward_handler( - request: Request, dbs: Dict[str, str], roles: Dict[str, str], ddl: "DdlForwardingContext" + request: Request, dbs: dict[str, str], roles: dict[str, str], ddl: DdlForwardingContext ) -> Response: log.info(f"Received request with data {request.get_data(as_text=True)}") if ddl.fail: @@ -69,8 +74,8 @@ class DdlForwardingContext: self.pg = vanilla_pg self.host = host self.port = port - self.dbs: Dict[str, str] = {} - self.roles: Dict[str, str] = {} + self.dbs: dict[str, str] = {} + self.roles: dict[str, str] = {} self.fail = False endpoint = "/test/roles_and_databases" ddl_url = f"http://{host}:{port}{endpoint}" @@ -91,13 +96,13 @@ class DdlForwardingContext: def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ): self.pg.stop() - def send(self, query: str) -> List[Tuple[Any, ...]]: + def send(self, query: str) -> list[tuple[Any, ...]]: return self.pg.safe_psql(query) def wait(self, timeout=3): @@ -106,7 +111,7 @@ class DdlForwardingContext: def failures(self, bool): self.fail = bool - def send_and_wait(self, query: str, timeout=3) -> List[Tuple[Any, ...]]: + def send_and_wait(self, query: str, timeout=3) -> list[tuple[Any, ...]]: res = self.send(query) self.wait(timeout=timeout) return res diff --git a/test_runner/regress/test_disk_usage_eviction.py b/test_runner/regress/test_disk_usage_eviction.py index 4fcdef0ca3..72866766de 100644 --- a/test_runner/regress/test_disk_usage_eviction.py +++ b/test_runner/regress/test_disk_usage_eviction.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import enum import time from collections import Counter +from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Dict, Iterable, Tuple +from typing import TYPE_CHECKING import pytest from fixtures.common_types import Lsn, TenantId, TimelineId @@ -19,6 +22,10 @@ from fixtures.pageserver.utils import wait_for_upload_queue_empty from fixtures.remote_storage import RemoteStorageKind from fixtures.utils import human_bytes, wait_until +if TYPE_CHECKING: + from typing import Any + + GLOBAL_LRU_LOG_LINE = "tenant_min_resident_size-respecting LRU would not relieve pressure, evicting more following global LRU policy" # access times in the pageserver are stored at a very low resolution: to generate meaningfully different @@ -74,7 +81,7 @@ class EvictionOrder(str, enum.Enum): RELATIVE_ORDER_EQUAL = "relative_equal" RELATIVE_ORDER_SPARE = "relative_spare" - def config(self) -> Dict[str, Any]: + def config(self) -> dict[str, Any]: if self == EvictionOrder.RELATIVE_ORDER_EQUAL: return { "type": "RelativeAccessed", @@ -91,12 +98,12 @@ class EvictionOrder(str, enum.Enum): @dataclass class EvictionEnv: - timelines: list[Tuple[TenantId, TimelineId]] + timelines: list[tuple[TenantId, TimelineId]] neon_env: NeonEnv pg_bin: PgBin pageserver_http: PageserverHttpClient layer_size: int - pgbench_init_lsns: Dict[TenantId, Lsn] + pgbench_init_lsns: dict[TenantId, Lsn] @property def pageserver(self): @@ -105,7 +112,7 @@ class EvictionEnv: """ return self.neon_env.pageserver - def timelines_du(self, pageserver: NeonPageserver) -> Tuple[int, int, int]: + def timelines_du(self, pageserver: NeonPageserver) -> tuple[int, int, int]: return poor_mans_du( self.neon_env, [(tid, tlid) for tid, tlid in self.timelines], @@ -113,13 +120,13 @@ class EvictionEnv: verbose=False, ) - def du_by_timeline(self, pageserver: NeonPageserver) -> Dict[Tuple[TenantId, TimelineId], int]: + def du_by_timeline(self, pageserver: NeonPageserver) -> dict[tuple[TenantId, TimelineId], int]: return { (tid, tlid): poor_mans_du(self.neon_env, [(tid, tlid)], pageserver, verbose=True)[0] for tid, tlid in self.timelines } - def count_layers_per_tenant(self, pageserver: NeonPageserver) -> Dict[TenantId, int]: + def count_layers_per_tenant(self, pageserver: NeonPageserver) -> dict[TenantId, int]: return count_layers_per_tenant(pageserver, self.timelines) def warm_up_tenant(self, tenant_id: TenantId): @@ -204,8 +211,8 @@ class EvictionEnv: def count_layers_per_tenant( - pageserver: NeonPageserver, timelines: Iterable[Tuple[TenantId, TimelineId]] -) -> Dict[TenantId, int]: + pageserver: NeonPageserver, timelines: Iterable[tuple[TenantId, TimelineId]] +) -> dict[TenantId, int]: ret: Counter[TenantId] = Counter() for tenant_id, timeline_id in timelines: @@ -279,7 +286,7 @@ def _eviction_env( def pgbench_init_tenant( layer_size: int, scale: int, env: NeonEnv, pg_bin: PgBin -) -> Tuple[TenantId, TimelineId]: +) -> tuple[TenantId, TimelineId]: tenant_id, timeline_id = env.create_tenant( conf={ "gc_period": "0s", @@ -672,10 +679,10 @@ def test_fast_growing_tenant(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin, or def poor_mans_du( env: NeonEnv, - timelines: Iterable[Tuple[TenantId, TimelineId]], + timelines: Iterable[tuple[TenantId, TimelineId]], pageserver: NeonPageserver, verbose: bool = False, -) -> Tuple[int, int, int]: +) -> tuple[int, int, int]: """ Disk usage, largest, smallest layer for layer files over the given (tenant, timeline) tuples; this could be done over layers endpoint just as well. diff --git a/test_runner/regress/test_download_extensions.py b/test_runner/regress/test_download_extensions.py index c89a82965e..04916a6b6f 100644 --- a/test_runner/regress/test_download_extensions.py +++ b/test_runner/regress/test_download_extensions.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import os import shutil from contextlib import closing from pathlib import Path -from typing import Any, Dict +from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log @@ -14,6 +16,9 @@ from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response +if TYPE_CHECKING: + from typing import Any + # use neon_env_builder_local fixture to override the default neon_env_builder fixture # and use a test-specific pg_install instead of shared one @@ -88,7 +93,7 @@ def test_remote_extensions( ) # mock remote_extensions spec - spec: Dict[str, Any] = { + spec: dict[str, Any] = { "library_index": { "anon": "anon", }, diff --git a/test_runner/regress/test_endpoint_crash.py b/test_runner/regress/test_endpoint_crash.py index e34dfab6c4..0217cd0d03 100644 --- a/test_runner/regress/test_endpoint_crash.py +++ b/test_runner/regress/test_endpoint_crash.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_explain_with_lfc_stats.py b/test_runner/regress/test_explain_with_lfc_stats.py index 0217c9ac7b..2128bd93dd 100644 --- a/test_runner/regress/test_explain_with_lfc_stats.py +++ b/test_runner/regress/test_explain_with_lfc_stats.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path from fixtures.log_helper import log diff --git a/test_runner/regress/test_fsm_truncate.py b/test_runner/regress/test_fsm_truncate.py index 691f96ab0a..55a010f26a 100644 --- a/test_runner/regress/test_fsm_truncate.py +++ b/test_runner/regress/test_fsm_truncate.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_fullbackup.py b/test_runner/regress/test_fullbackup.py index e6d51a77a6..62d59528cf 100644 --- a/test_runner/regress/test_fullbackup.py +++ b/test_runner/regress/test_fullbackup.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from pathlib import Path diff --git a/test_runner/regress/test_gc_aggressive.py b/test_runner/regress/test_gc_aggressive.py index 3d472f9720..97c38cf658 100644 --- a/test_runner/regress/test_gc_aggressive.py +++ b/test_runner/regress/test_gc_aggressive.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import concurrent.futures import random diff --git a/test_runner/regress/test_gin_redo.py b/test_runner/regress/test_gin_redo.py index 9205882239..71382990dc 100644 --- a/test_runner/regress/test_gin_redo.py +++ b/test_runner/regress/test_gin_redo.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from fixtures.neon_fixtures import NeonEnv, wait_replica_caughtup diff --git a/test_runner/regress/test_hot_standby.py b/test_runner/regress/test_hot_standby.py index be8f70bb70..a906e7a243 100644 --- a/test_runner/regress/test_hot_standby.py +++ b/test_runner/regress/test_hot_standby.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import os import threading diff --git a/test_runner/regress/test_import.py b/test_runner/regress/test_import.py index 87b44e4e3e..e367db33ff 100644 --- a/test_runner/regress/test_import.py +++ b/test_runner/regress/test_import.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import re diff --git a/test_runner/regress/test_ingestion_layer_size.py b/test_runner/regress/test_ingestion_layer_size.py index 44c77b3410..2edbf4d6d3 100644 --- a/test_runner/regress/test_ingestion_layer_size.py +++ b/test_runner/regress/test_ingestion_layer_size.py @@ -1,5 +1,8 @@ +from __future__ import annotations + +from collections.abc import Iterable from dataclasses import dataclass -from typing import Iterable, List, Union +from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log @@ -7,6 +10,9 @@ from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_last_flush_lsn from fixtures.pageserver.http import HistoricLayerInfo, LayerMapInfo from fixtures.utils import human_bytes +if TYPE_CHECKING: + from typing import Union + def test_ingesting_large_batches_of_images(neon_env_builder: NeonEnvBuilder, build_type: str): """ @@ -106,13 +112,13 @@ def test_ingesting_large_batches_of_images(neon_env_builder: NeonEnvBuilder, bui @dataclass class Histogram: - buckets: List[Union[int, float]] - counts: List[int] - sums: List[int] + buckets: list[Union[int, float]] + counts: list[int] + sums: list[int] def histogram_historic_layers( - infos: LayerMapInfo, minimum_sizes: List[Union[int, float]] + infos: LayerMapInfo, minimum_sizes: list[Union[int, float]] ) -> Histogram: def log_layer(layer: HistoricLayerInfo) -> HistoricLayerInfo: log.info( @@ -125,7 +131,7 @@ def histogram_historic_layers( return histogram(sizes, minimum_sizes) -def histogram(sizes: Iterable[int], minimum_sizes: List[Union[int, float]]) -> Histogram: +def histogram(sizes: Iterable[int], minimum_sizes: list[Union[int, float]]) -> Histogram: assert all(minimum_sizes[i] < minimum_sizes[i + 1] for i in range(len(minimum_sizes) - 1)) buckets = list(enumerate(minimum_sizes)) counts = [0 for _ in buckets] diff --git a/test_runner/regress/test_installed_extensions.py b/test_runner/regress/test_installed_extensions.py new file mode 100644 index 0000000000..4700db85ee --- /dev/null +++ b/test_runner/regress/test_installed_extensions.py @@ -0,0 +1,87 @@ +from logging import info + +from fixtures.neon_fixtures import NeonEnv + + +def test_installed_extensions(neon_simple_env: NeonEnv): + """basic test for the endpoint that returns the list of installed extensions""" + + env = neon_simple_env + + env.create_branch("test_installed_extensions") + + endpoint = env.endpoints.create_start("test_installed_extensions") + + endpoint.safe_psql("CREATE DATABASE test_installed_extensions") + endpoint.safe_psql("CREATE DATABASE test_installed_extensions_2") + + client = endpoint.http_client() + res = client.installed_extensions() + + info("Extensions list: %s", res) + info("Extensions: %s", res["extensions"]) + # 'plpgsql' is a default extension that is always installed. + assert any( + ext["extname"] == "plpgsql" and ext["versions"] == ["1.0"] for ext in res["extensions"] + ), "The 'plpgsql' extension is missing" + + # check that the neon_test_utils extension is not installed + assert not any( + ext["extname"] == "neon_test_utils" for ext in res["extensions"] + ), "The 'neon_test_utils' extension is installed" + + pg_conn = endpoint.connect(dbname="test_installed_extensions") + with pg_conn.cursor() as cur: + cur.execute("CREATE EXTENSION neon_test_utils") + cur.execute( + "SELECT default_version FROM pg_available_extensions WHERE name = 'neon_test_utils'" + ) + res = cur.fetchone() + neon_test_utils_version = res[0] + + with pg_conn.cursor() as cur: + cur.execute("CREATE EXTENSION neon version '1.1'") + + pg_conn_2 = endpoint.connect(dbname="test_installed_extensions_2") + with pg_conn_2.cursor() as cur: + cur.execute("CREATE EXTENSION neon version '1.2'") + + res = client.installed_extensions() + + info("Extensions list: %s", res) + info("Extensions: %s", res["extensions"]) + + # check that the neon_test_utils extension is installed only in 1 database + # and has the expected version + assert any( + ext["extname"] == "neon_test_utils" + and ext["versions"] == [neon_test_utils_version] + and ext["n_databases"] == 1 + for ext in res["extensions"] + ) + + # check that the plpgsql extension is installed in all databases + # this is a default extension that is always installed + assert any(ext["extname"] == "plpgsql" and ext["n_databases"] == 4 for ext in res["extensions"]) + + # check that the neon extension is installed and has expected versions + for ext in res["extensions"]: + if ext["extname"] == "neon": + assert ext["n_databases"] == 2 + ext["versions"].sort() + assert ext["versions"] == ["1.1", "1.2"] + + with pg_conn.cursor() as cur: + cur.execute("ALTER EXTENSION neon UPDATE TO '1.3'") + + res = client.installed_extensions() + + info("Extensions list: %s", res) + info("Extensions: %s", res["extensions"]) + + # check that the neon_test_utils extension is updated + for ext in res["extensions"]: + if ext["extname"] == "neon": + assert ext["n_databases"] == 2 + ext["versions"].sort() + assert ext["versions"] == ["1.2", "1.3"] diff --git a/test_runner/regress/test_large_schema.py b/test_runner/regress/test_large_schema.py index c5d5b5fe64..ae5113ed45 100644 --- a/test_runner/regress/test_large_schema.py +++ b/test_runner/regress/test_large_schema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import time diff --git a/test_runner/regress/test_layer_bloating.py b/test_runner/regress/test_layer_bloating.py index b8126395fd..a08d522fc2 100644 --- a/test_runner/regress/test_layer_bloating.py +++ b/test_runner/regress/test_layer_bloating.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pytest diff --git a/test_runner/regress/test_layer_eviction.py b/test_runner/regress/test_layer_eviction.py index 82cfe08bc0..c49ac6893e 100644 --- a/test_runner/regress/test_layer_eviction.py +++ b/test_runner/regress/test_layer_eviction.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time import pytest diff --git a/test_runner/regress/test_layer_writers_fail.py b/test_runner/regress/test_layer_writers_fail.py index 1711cc1414..dd31e2725b 100644 --- a/test_runner/regress/test_layer_writers_fail.py +++ b/test_runner/regress/test_layer_writers_fail.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.neon_fixtures import NeonEnv, NeonPageserver from fixtures.pageserver.http import PageserverApiException diff --git a/test_runner/regress/test_layers_from_future.py b/test_runner/regress/test_layers_from_future.py index 2857df8ef7..2536ec1b3c 100644 --- a/test_runner/regress/test_layers_from_future.py +++ b/test_runner/regress/test_layers_from_future.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from fixtures.common_types import Lsn diff --git a/test_runner/regress/test_lfc_resize.py b/test_runner/regress/test_lfc_resize.py index 0f791e9247..3083128d87 100644 --- a/test_runner/regress/test_lfc_resize.py +++ b/test_runner/regress/test_lfc_resize.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import random import re diff --git a/test_runner/regress/test_lfc_working_set_approximation.py b/test_runner/regress/test_lfc_working_set_approximation.py index 4a3a949d1a..36dfec969f 100644 --- a/test_runner/regress/test_lfc_working_set_approximation.py +++ b/test_runner/regress/test_lfc_working_set_approximation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from pathlib import Path diff --git a/test_runner/regress/test_local_file_cache.py b/test_runner/regress/test_local_file_cache.py index 9c38200937..fbf018a167 100644 --- a/test_runner/regress/test_local_file_cache.py +++ b/test_runner/regress/test_local_file_cache.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import os import queue import random import threading import time -from typing import List from fixtures.neon_fixtures import NeonEnvBuilder from fixtures.utils import query_scalar @@ -57,7 +58,7 @@ def test_local_file_cache_unlink(neon_env_builder: NeonEnvBuilder): n_updates_performed_q.put(n_updates_performed) n_updates_performed_q: queue.Queue[int] = queue.Queue() - threads: List[threading.Thread] = [] + threads: list[threading.Thread] = [] for _i in range(n_threads): thread = threading.Thread(target=run_updates, args=(n_updates_performed_q,), daemon=True) thread.start() diff --git a/test_runner/regress/test_logging.py b/test_runner/regress/test_logging.py index bfffad7572..9a3fdd835d 100644 --- a/test_runner/regress/test_logging.py +++ b/test_runner/regress/test_logging.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import uuid import pytest diff --git a/test_runner/regress/test_logical_replication.py b/test_runner/regress/test_logical_replication.py index 1aa1bdf366..87991eadf1 100644 --- a/test_runner/regress/test_logical_replication.py +++ b/test_runner/regress/test_logical_replication.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from functools import partial from random import choice @@ -336,7 +338,7 @@ FROM generate_series(1, 16384) AS seq; -- Inserts enough rows to exceed 16MB of assert [r[0] for r in vanilla_pg.safe_psql("select * from t")] == [1, 2, 3] log_path = vanilla_pg.pgdatadir / "pg.log" - with open(log_path, "r") as log_file: + with open(log_path) as log_file: logs = log_file.read() assert "could not receive data from WAL stream" not in logs diff --git a/test_runner/regress/test_lsn_mapping.py b/test_runner/regress/test_lsn_mapping.py index ab43e32146..8b41d0cb1c 100644 --- a/test_runner/regress/test_lsn_mapping.py +++ b/test_runner/regress/test_lsn_mapping.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import time from concurrent.futures import ThreadPoolExecutor diff --git a/test_runner/regress/test_multixact.py b/test_runner/regress/test_multixact.py index 742d03e464..e8bbe5aa97 100644 --- a/test_runner/regress/test_multixact.py +++ b/test_runner/regress/test_multixact.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnv, check_restored_datadir_content from fixtures.utils import query_scalar diff --git a/test_runner/regress/test_neon_cli.py b/test_runner/regress/test_neon_cli.py index 04780ebcf1..783fb813cf 100644 --- a/test_runner/regress/test_neon_cli.py +++ b/test_runner/regress/test_neon_cli.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess from pathlib import Path @@ -160,6 +162,11 @@ def test_cli_start_stop_multi(neon_env_builder: NeonEnvBuilder): env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID) env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID + 1) + # We will stop the storage controller while it may have requests in + # flight, and the pageserver complains when requests are abandoned. + for ps in env.pageservers: + ps.allowed_errors.append(".*request was dropped before completing.*") + # Keep NeonEnv state up to date, it usually owns starting/stopping services env.pageservers[0].running = False env.pageservers[1].running = False diff --git a/test_runner/regress/test_neon_extension.py b/test_runner/regress/test_neon_extension.py index a99e9e15af..4035398a5f 100644 --- a/test_runner/regress/test_neon_extension.py +++ b/test_runner/regress/test_neon_extension.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from contextlib import closing diff --git a/test_runner/regress/test_neon_local_cli.py b/test_runner/regress/test_neon_local_cli.py index 0fdc5960e3..80e26d9432 100644 --- a/test_runner/regress/test_neon_local_cli.py +++ b/test_runner/regress/test_neon_local_cli.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.common_types import TimelineId from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_neon_superuser.py b/test_runner/regress/test_neon_superuser.py index dc1c9d3fd9..7118127a1f 100644 --- a/test_runner/regress/test_neon_superuser.py +++ b/test_runner/regress/test_neon_superuser.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv from fixtures.pg_version import PgVersion diff --git a/test_runner/regress/test_next_xid.py b/test_runner/regress/test_next_xid.py index cac74492d7..980f6b5694 100644 --- a/test_runner/regress/test_next_xid.py +++ b/test_runner/regress/test_next_xid.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import time from pathlib import Path @@ -189,7 +191,7 @@ def test_import_at_2bil( # calculate the SLRU segments that a particular multixid or multixid-offsets falls into. BLCKSZ = 8192 MULTIXACT_OFFSETS_PER_PAGE = int(BLCKSZ / 4) -SLRU_PAGES_PER_SEGMENT = int(32) +SLRU_PAGES_PER_SEGMENT = 32 MXACT_MEMBER_BITS_PER_XACT = 8 MXACT_MEMBER_FLAGS_PER_BYTE = 1 MULTIXACT_FLAGBYTES_PER_GROUP = 4 diff --git a/test_runner/regress/test_normal_work.py b/test_runner/regress/test_normal_work.py index 54433769fd..ae2d171058 100644 --- a/test_runner/regress/test_normal_work.py +++ b/test_runner/regress/test_normal_work.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder diff --git a/test_runner/regress/test_oid_overflow.py b/test_runner/regress/test_oid_overflow.py index e8eefc2414..f69c1112c7 100644 --- a/test_runner/regress/test_oid_overflow.py +++ b/test_runner/regress/test_oid_overflow.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_old_request_lsn.py b/test_runner/regress/test_old_request_lsn.py index dfd0271c10..a615464582 100644 --- a/test_runner/regress/test_old_request_lsn.py +++ b/test_runner/regress/test_old_request_lsn.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.common_types import TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_ondemand_download.py b/test_runner/regress/test_ondemand_download.py index 0d712d06f1..e1caaeb6c1 100644 --- a/test_runner/regress/test_ondemand_download.py +++ b/test_runner/regress/test_ondemand_download.py @@ -1,10 +1,12 @@ # It's possible to run any regular test with the local fs remote storage via # env ZENITH_PAGESERVER_OVERRIDES="remote_storage={local_path='/tmp/neon_zzz/'}" poetry ...... +from __future__ import annotations + import time from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -from typing import Any, DefaultDict, Dict, Tuple +from typing import TYPE_CHECKING import pytest from fixtures.common_types import Lsn @@ -26,6 +28,9 @@ from fixtures.pageserver.utils import ( from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage from fixtures.utils import query_scalar, wait_until +if TYPE_CHECKING: + from typing import Any + def get_num_downloaded_layers(client: PageserverHttpClient): """ @@ -505,7 +510,7 @@ def test_compaction_downloads_on_demand_without_image_creation(neon_env_builder: env = neon_env_builder.init_start(initial_tenant_conf=stringify(conf)) - def downloaded_bytes_and_count(pageserver_http: PageserverHttpClient) -> Tuple[int, int]: + def downloaded_bytes_and_count(pageserver_http: PageserverHttpClient) -> tuple[int, int]: m = pageserver_http.get_metrics() # these are global counters total_bytes = m.query_one("pageserver_remote_ondemand_downloaded_bytes_total").value @@ -634,7 +639,7 @@ def test_compaction_downloads_on_demand_with_image_creation(neon_env_builder: Ne layers = pageserver_http.layer_map_info(tenant_id, timeline_id) assert not layers.in_memory_layers, "no inmemory layers expected after post-commit checkpoint" - kinds_before: DefaultDict[str, int] = defaultdict(int) + kinds_before: defaultdict[str, int] = defaultdict(int) for layer in layers.historic_layers: kinds_before[layer.kind] += 1 @@ -651,7 +656,7 @@ def test_compaction_downloads_on_demand_with_image_creation(neon_env_builder: Ne pageserver_http.timeline_compact(tenant_id, timeline_id) layers = pageserver_http.layer_map_info(tenant_id, timeline_id) - kinds_after: DefaultDict[str, int] = defaultdict(int) + kinds_after: defaultdict[str, int] = defaultdict(int) for layer in layers.historic_layers: kinds_after[layer.kind] += 1 @@ -855,5 +860,5 @@ def test_layer_download_timeouted(neon_env_builder: NeonEnvBuilder): assert elapsed < 30, "too long passed: {elapsed=}" -def stringify(conf: Dict[str, Any]) -> Dict[str, str]: +def stringify(conf: dict[str, Any]) -> dict[str, str]: return dict(map(lambda x: (x[0], str(x[1])), conf.items())) diff --git a/test_runner/regress/test_ondemand_slru_download.py b/test_runner/regress/test_ondemand_slru_download.py index d6babe4393..5eaba78331 100644 --- a/test_runner/regress/test_ondemand_slru_download.py +++ b/test_runner/regress/test_ondemand_slru_download.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Optional import pytest diff --git a/test_runner/regress/test_pageserver_api.py b/test_runner/regress/test_pageserver_api.py index a19bc785f8..d1b70b9ee6 100644 --- a/test_runner/regress/test_pageserver_api.py +++ b/test_runner/regress/test_pageserver_api.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Optional from fixtures.common_types import Lsn, TenantId, TimelineId diff --git a/test_runner/regress/test_pageserver_catchup.py b/test_runner/regress/test_pageserver_catchup.py index d020104431..3567e05f81 100644 --- a/test_runner/regress/test_pageserver_catchup.py +++ b/test_runner/regress/test_pageserver_catchup.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_pageserver_crash_consistency.py b/test_runner/regress/test_pageserver_crash_consistency.py index 2d6b50490e..ac46d3e62a 100644 --- a/test_runner/regress/test_pageserver_crash_consistency.py +++ b/test_runner/regress/test_pageserver_crash_consistency.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, wait_for_last_flush_lsn from fixtures.pageserver.common_types import ImageLayerName, parse_layer_file_name diff --git a/test_runner/regress/test_pageserver_generations.py b/test_runner/regress/test_pageserver_generations.py index a135b3da1a..11ebb81023 100644 --- a/test_runner/regress/test_pageserver_generations.py +++ b/test_runner/regress/test_pageserver_generations.py @@ -9,11 +9,13 @@ of the pageserver are: - Updates to remote_consistent_lsn may only be made visible after validating generation """ +from __future__ import annotations + import enum import os import re import time -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TimelineId @@ -38,6 +40,10 @@ from fixtures.remote_storage import ( from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + # A tenant configuration that is convenient for generating uploads and deletions # without a large amount of postgres traffic. TENANT_CONF = { @@ -664,14 +670,17 @@ def test_upgrade_generationless_local_file_paths( pageserver.stop() timeline_dir = pageserver.timeline_dir(tenant_id, timeline_id) files_renamed = 0 + log.info(f"Renaming files in {timeline_dir}") for filename in os.listdir(timeline_dir): - path = os.path.join(timeline_dir, filename) - log.info(f"Found file {path}") - if path.endswith("-v1-00000001"): - new_path = path[:-12] - os.rename(path, new_path) - log.info(f"Renamed {path} -> {new_path}") + if filename.endswith("-v1-00000001"): + new_filename = filename[:-12] + os.rename( + os.path.join(timeline_dir, filename), os.path.join(timeline_dir, new_filename) + ) + log.info(f"Renamed {filename} -> {new_filename}") files_renamed += 1 + else: + log.info(f"Keeping {filename}") assert files_renamed > 0 diff --git a/test_runner/regress/test_pageserver_getpage_throttle.py b/test_runner/regress/test_pageserver_getpage_throttle.py index 4c9eac5cd7..6811d09cff 100644 --- a/test_runner/regress/test_pageserver_getpage_throttle.py +++ b/test_runner/regress/test_pageserver_getpage_throttle.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import json import uuid @@ -61,7 +63,7 @@ def test_pageserver_getpage_throttle(neon_env_builder: NeonEnvBuilder, pg_bin: P results_path = Path(basepath + ".stdout") log.info(f"Benchmark results at: {results_path}") - with open(results_path, "r") as f: + with open(results_path) as f: results = json.load(f) log.info(f"Results:\n{json.dumps(results, sort_keys=True, indent=2)}") return int(results["total"]["request_count"]) diff --git a/test_runner/regress/test_pageserver_layer_rolling.py b/test_runner/regress/test_pageserver_layer_rolling.py index 8c6e563357..c0eb598891 100644 --- a/test_runner/regress/test_pageserver_layer_rolling.py +++ b/test_runner/regress/test_pageserver_layer_rolling.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import asyncio import os import time -from typing import Optional, Tuple +from typing import TYPE_CHECKING import psutil import pytest @@ -16,6 +18,10 @@ from fixtures.pageserver.http import PageserverHttpClient from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_for_upload from fixtures.utils import wait_until +if TYPE_CHECKING: + from typing import Optional + + TIMELINE_COUNT = 10 ENTRIES_PER_TIMELINE = 10_000 CHECKPOINT_TIMEOUT_SECONDS = 60 @@ -41,7 +47,7 @@ async def run_worker_for_tenant( return last_flush_lsn -async def run_worker(env: NeonEnv, tenant_conf, entries: int) -> Tuple[TenantId, TimelineId, Lsn]: +async def run_worker(env: NeonEnv, tenant_conf, entries: int) -> tuple[TenantId, TimelineId, Lsn]: tenant, timeline = env.create_tenant(conf=tenant_conf) last_flush_lsn = await run_worker_for_tenant(env, entries, tenant) return tenant, timeline, last_flush_lsn @@ -49,13 +55,13 @@ async def run_worker(env: NeonEnv, tenant_conf, entries: int) -> Tuple[TenantId, async def workload( env: NeonEnv, tenant_conf, timelines: int, entries: int -) -> list[Tuple[TenantId, TimelineId, Lsn]]: +) -> list[tuple[TenantId, TimelineId, Lsn]]: workers = [asyncio.create_task(run_worker(env, tenant_conf, entries)) for _ in range(timelines)] return await asyncio.gather(*workers) def wait_until_pageserver_is_caught_up( - env: NeonEnv, last_flush_lsns: list[Tuple[TenantId, TimelineId, Lsn]] + env: NeonEnv, last_flush_lsns: list[tuple[TenantId, TimelineId, Lsn]] ): for tenant, timeline, last_flush_lsn in last_flush_lsns: shards = tenant_get_shards(env, tenant) @@ -67,7 +73,7 @@ def wait_until_pageserver_is_caught_up( def wait_until_pageserver_has_uploaded( - env: NeonEnv, last_flush_lsns: list[Tuple[TenantId, TimelineId, Lsn]] + env: NeonEnv, last_flush_lsns: list[tuple[TenantId, TimelineId, Lsn]] ): for tenant, timeline, last_flush_lsn in last_flush_lsns: shards = tenant_get_shards(env, tenant) diff --git a/test_runner/regress/test_pageserver_metric_collection.py b/test_runner/regress/test_pageserver_metric_collection.py index 37ab51f9fb..5ec8357597 100644 --- a/test_runner/regress/test_pageserver_metric_collection.py +++ b/test_runner/regress/test_pageserver_metric_collection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import gzip import json import os @@ -5,7 +7,7 @@ import time from dataclasses import dataclass from pathlib import Path from queue import SimpleQueue -from typing import Any, Dict, Set +from typing import TYPE_CHECKING from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log @@ -22,6 +24,10 @@ from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response +if TYPE_CHECKING: + from typing import Any + + # TODO: collect all of the env setup *AFTER* removal of RemoteStorageKind.NOOP @@ -308,8 +314,8 @@ def test_metric_collection_cleans_up_tempfile( @dataclass class PrefixPartitionedFiles: - matching: Set[str] - other: Set[str] + matching: set[str] + other: set[str] def iterate_pageserver_workdir(path: Path, prefix: str) -> PrefixPartitionedFiles: @@ -340,7 +346,7 @@ class MetricsVerifier: """ def __init__(self): - self.tenants: Dict[TenantId, TenantMetricsVerifier] = {} + self.tenants: dict[TenantId, TenantMetricsVerifier] = {} pass def ingest(self, events, is_last): @@ -357,8 +363,8 @@ class MetricsVerifier: for t in self.tenants.values(): t.post_batch() - def accepted_event_names(self) -> Set[str]: - names: Set[str] = set() + def accepted_event_names(self) -> set[str]: + names: set[str] = set() for t in self.tenants.values(): names = names.union(t.accepted_event_names()) return names @@ -367,8 +373,8 @@ class MetricsVerifier: class TenantMetricsVerifier: def __init__(self, id: TenantId): self.id = id - self.timelines: Dict[TimelineId, TimelineMetricsVerifier] = {} - self.state: Dict[str, Any] = {} + self.timelines: dict[TimelineId, TimelineMetricsVerifier] = {} + self.state: dict[str, Any] = {} def ingest(self, event): assert TenantId(event["tenant_id"]) == self.id @@ -392,7 +398,7 @@ class TenantMetricsVerifier: for tl in self.timelines.values(): tl.post_batch(self) - def accepted_event_names(self) -> Set[str]: + def accepted_event_names(self) -> set[str]: names = set(self.state.keys()) for t in self.timelines.values(): names = names.union(t.accepted_event_names()) @@ -402,7 +408,7 @@ class TenantMetricsVerifier: class TimelineMetricsVerifier: def __init__(self, tenant_id: TenantId, timeline_id: TimelineId): self.id = timeline_id - self.state: Dict[str, Any] = {} + self.state: dict[str, Any] = {} def ingest(self, event): name = event["metric"] @@ -414,7 +420,7 @@ class TimelineMetricsVerifier: for v in self.state.values(): v.post_batch(self) - def accepted_event_names(self) -> Set[str]: + def accepted_event_names(self) -> set[str]: return set(self.state.keys()) diff --git a/test_runner/regress/test_pageserver_reconnect.py b/test_runner/regress/test_pageserver_reconnect.py index 7f10c36db8..be63208428 100644 --- a/test_runner/regress/test_pageserver_reconnect.py +++ b/test_runner/regress/test_pageserver_reconnect.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import threading import time from contextlib import closing diff --git a/test_runner/regress/test_pageserver_restart.py b/test_runner/regress/test_pageserver_restart.py index 86313ca91e..f7c42fc893 100644 --- a/test_runner/regress/test_pageserver_restart.py +++ b/test_runner/regress/test_pageserver_restart.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random from contextlib import closing from typing import Optional diff --git a/test_runner/regress/test_pageserver_restarts_under_workload.py b/test_runner/regress/test_pageserver_restarts_under_workload.py index 637e1a87d3..ec74e03f89 100644 --- a/test_runner/regress/test_pageserver_restarts_under_workload.py +++ b/test_runner/regress/test_pageserver_restarts_under_workload.py @@ -1,6 +1,9 @@ # This test spawns pgbench in a thread in the background and concurrently restarts pageserver, # checking how client is able to transparently restore connection to pageserver # + +from __future__ import annotations + import threading import time diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index cd772beace..705b4ff054 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import os import random import time from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -20,6 +22,10 @@ from fixtures.workload import Workload from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response +if TYPE_CHECKING: + from typing import Any, Optional, Union + + # A tenant configuration that is convenient for generating uploads and deletions # without a large amount of postgres traffic. TENANT_CONF = { @@ -193,11 +199,11 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, # state if it was running attached with a stale generation last_state[pageserver.id] = ("Detached", None) else: - secondary_conf: Optional[Dict[str, Any]] = None + secondary_conf: Optional[dict[str, Any]] = None if mode == "Secondary": secondary_conf = {"warm": rng.choice([True, False])} - location_conf: Dict[str, Any] = { + location_conf: dict[str, Any] = { "mode": mode, "secondary_conf": secondary_conf, "tenant_conf": {}, diff --git a/test_runner/regress/test_parallel_copy.py b/test_runner/regress/test_parallel_copy.py index a5037e8694..1689755b6f 100644 --- a/test_runner/regress/test_parallel_copy.py +++ b/test_runner/regress/test_parallel_copy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from io import BytesIO diff --git a/test_runner/regress/test_pg_query_cancellation.py b/test_runner/regress/test_pg_query_cancellation.py index c6b4eff516..d4ed7230fa 100644 --- a/test_runner/regress/test_pg_query_cancellation.py +++ b/test_runner/regress/test_pg_query_cancellation.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from contextlib import closing -from typing import Set import pytest from fixtures.log_helper import log @@ -7,7 +8,7 @@ from fixtures.neon_fixtures import Endpoint, NeonEnv, NeonPageserver from fixtures.pageserver.http import PageserverHttpClient from psycopg2.errors import QueryCanceled -CRITICAL_PG_PS_WAIT_FAILPOINTS: Set[str] = { +CRITICAL_PG_PS_WAIT_FAILPOINTS: set[str] = { "ps::connection-start::pre-login", "ps::connection-start::startup-packet", "ps::connection-start::process-query", @@ -92,7 +93,7 @@ def test_cancellations(neon_simple_env: NeonEnv): connect_works_correctly(failpoint, ep, ps, ps_http) -ENABLED_FAILPOINTS: Set[str] = set() +ENABLED_FAILPOINTS: set[str] = set() def connect_works_correctly( diff --git a/test_runner/regress/test_pg_waldump.py b/test_runner/regress/test_pg_waldump.py index 1990d69b6a..c98d395451 100644 --- a/test_runner/regress/test_pg_waldump.py +++ b/test_runner/regress/test_pg_waldump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil @@ -12,7 +14,7 @@ def check_wal_segment(pg_waldump_path: str, segment_path: str, test_output_dir): test_output_dir, [pg_waldump_path, "--ignore", segment_path] ) - with open(f"{output_path}.stdout", "r") as f: + with open(f"{output_path}.stdout") as f: stdout = f.read() assert "ABORT" in stdout assert "COMMIT" in stdout diff --git a/test_runner/regress/test_pitr_gc.py b/test_runner/regress/test_pitr_gc.py index 871a31b9ba..d983d77e72 100644 --- a/test_runner/regress/test_pitr_gc.py +++ b/test_runner/regress/test_pitr_gc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.common_types import TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_postgres_version.py b/test_runner/regress/test_postgres_version.py index d8626c15a5..5eb743809f 100644 --- a/test_runner/regress/test_postgres_version.py +++ b/test_runner/regress/test_postgres_version.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re from pathlib import Path diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index d2b8c2ed8b..f598900af9 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import asyncio import json import subprocess import time import urllib.parse -from typing import Any, List, Optional, Tuple +from typing import TYPE_CHECKING import psycopg2 import pytest import requests from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres +if TYPE_CHECKING: + from typing import Any, Optional + + GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'" @@ -222,7 +228,7 @@ def test_sql_over_http_serverless_driver(static_proxy: NeonProxy): def test_sql_over_http(static_proxy: NeonProxy): static_proxy.safe_psql("create role http with login password 'http' superuser") - def q(sql: str, params: Optional[List[Any]] = None) -> Any: + def q(sql: str, params: Optional[list[Any]] = None) -> Any: params = params or [] connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" response = requests.post( @@ -285,7 +291,7 @@ def test_sql_over_http_db_name_with_space(static_proxy: NeonProxy): ) ) - def q(sql: str, params: Optional[List[Any]] = None) -> Any: + def q(sql: str, params: Optional[list[Any]] = None) -> Any: params = params or [] connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/{urllib.parse.quote(db)}" response = requests.post( @@ -304,7 +310,7 @@ def test_sql_over_http_db_name_with_space(static_proxy: NeonProxy): def test_sql_over_http_output_options(static_proxy: NeonProxy): static_proxy.safe_psql("create role http2 with login password 'http2' superuser") - def q(sql: str, raw_text: bool, array_mode: bool, params: Optional[List[Any]] = None) -> Any: + def q(sql: str, raw_text: bool, array_mode: bool, params: Optional[list[Any]] = None) -> Any: params = params or [] connstr = ( f"postgresql://http2:http2@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" @@ -340,7 +346,7 @@ def test_sql_over_http_batch(static_proxy: NeonProxy): static_proxy.safe_psql("create role http with login password 'http' superuser") def qq( - queries: List[Tuple[str, Optional[List[Any]]]], + queries: list[tuple[str, Optional[list[Any]]]], read_only: bool = False, deferrable: bool = False, ) -> Any: diff --git a/test_runner/regress/test_proxy_allowed_ips.py b/test_runner/regress/test_proxy_allowed_ips.py index 7a804114ba..902da1942e 100644 --- a/test_runner/regress/test_proxy_allowed_ips.py +++ b/test_runner/regress/test_proxy_allowed_ips.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import psycopg2 import pytest from fixtures.neon_fixtures import ( diff --git a/test_runner/regress/test_proxy_metric_collection.py b/test_runner/regress/test_proxy_metric_collection.py index f57b47f4da..dd63256388 100644 --- a/test_runner/regress/test_proxy_metric_collection.py +++ b/test_runner/regress/test_proxy_metric_collection.py @@ -1,5 +1,7 @@ +from __future__ import annotations + +from collections.abc import Iterator from pathlib import Path -from typing import Iterator import pytest from fixtures.log_helper import log diff --git a/test_runner/regress/test_proxy_websockets.py b/test_runner/regress/test_proxy_websockets.py index 6211446a40..071ca7c54e 100644 --- a/test_runner/regress/test_proxy_websockets.py +++ b/test_runner/regress/test_proxy_websockets.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ssl import pytest @@ -53,7 +55,7 @@ async def test_websockets(static_proxy: NeonProxy): assert auth_response[1:5] == b"\x00\x00\x00\x08", "should be 8 bytes long message" assert auth_response[5:9] == b"\x00\x00\x00\x00", "should be authenticated" - query_message = "SELECT 1".encode("utf-8") + b"\0" + query_message = b"SELECT 1" + b"\0" length = (4 + len(query_message)).to_bytes(4, byteorder="big") await websocket.send([b"Q", length, query_message]) @@ -132,7 +134,7 @@ async def test_websockets_pipelined(static_proxy: NeonProxy): auth_message = password.encode("utf-8") + b"\0" length1 = (4 + len(auth_message)).to_bytes(4, byteorder="big") - query_message = "SELECT 1".encode("utf-8") + b"\0" + query_message = b"SELECT 1" + b"\0" length2 = (4 + len(query_message)).to_bytes(4, byteorder="big") await websocket.send( length0 diff --git a/test_runner/regress/test_read_validation.py b/test_runner/regress/test_read_validation.py index 78798c5abf..471a3b406a 100644 --- a/test_runner/regress/test_read_validation.py +++ b/test_runner/regress/test_read_validation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import closing from fixtures.log_helper import log diff --git a/test_runner/regress/test_readonly_node.py b/test_runner/regress/test_readonly_node.py index b08fcc0da1..30c69cb883 100644 --- a/test_runner/regress/test_readonly_node.py +++ b/test_runner/regress/test_readonly_node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time import pytest diff --git a/test_runner/regress/test_recovery.py b/test_runner/regress/test_recovery.py index 8556103458..b43a443149 100644 --- a/test_runner/regress/test_recovery.py +++ b/test_runner/regress/test_recovery.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from contextlib import closing diff --git a/test_runner/regress/test_remote_storage.py b/test_runner/regress/test_remote_storage.py index c955dce4dc..79b5ebe39a 100644 --- a/test_runner/regress/test_remote_storage.py +++ b/test_runner/regress/test_remote_storage.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import os import queue import shutil import threading import time -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING import pytest from fixtures.common_types import Lsn, TenantId, TimelineId @@ -35,6 +37,9 @@ from fixtures.utils import ( ) from requests import ReadTimeout +if TYPE_CHECKING: + from typing import Optional + # # Tests that a piece of data is backed up and restored correctly: @@ -423,7 +428,7 @@ def test_remote_timeline_client_calls_started_metric( assert timeline_id is not None wait_for_last_flush_lsn(env, endpoint, tenant_id, timeline_id) - calls_started: Dict[Tuple[str, str], List[int]] = { + calls_started: dict[tuple[str, str], list[int]] = { ("layer", "upload"): [0], ("index", "upload"): [0], ("layer", "delete"): [0], diff --git a/test_runner/regress/test_replica_start.py b/test_runner/regress/test_replica_start.py index d5e92b92d1..e81e7dad76 100644 --- a/test_runner/regress/test_replica_start.py +++ b/test_runner/regress/test_replica_start.py @@ -20,6 +20,8 @@ from shutdown checkpoint, using the CLOG scanning mechanism, waiting for running-xacts record and for in-progress transactions to finish etc. """ +from __future__ import annotations + import threading from contextlib import closing diff --git a/test_runner/regress/test_s3_restore.py b/test_runner/regress/test_s3_restore.py index 721c391544..bedc9b5865 100644 --- a/test_runner/regress/test_s3_restore.py +++ b/test_runner/regress/test_s3_restore.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from datetime import datetime, timezone diff --git a/test_runner/regress/test_setup.py b/test_runner/regress/test_setup.py index 02710fc807..dfbbd575b7 100644 --- a/test_runner/regress/test_setup.py +++ b/test_runner/regress/test_setup.py @@ -1,5 +1,7 @@ """Tests for the code in test fixtures""" +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index a3d4b5baca..b1abcaa763 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import os import time from collections import defaultdict -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING import pytest import requests @@ -21,9 +23,13 @@ from fixtures.remote_storage import s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload from pytest_httpserver import HTTPServer +from typing_extensions import override from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response +if TYPE_CHECKING: + from typing import Optional, Union + def test_sharding_smoke( neon_env_builder: NeonEnvBuilder, @@ -635,7 +641,7 @@ def test_sharding_split_stripe_size( tenant_id = env.initial_tenant assert len(notifications) == 1 - expect: Dict[str, Union[List[Dict[str, int]], str, None, int]] = { + expect: dict[str, Union[list[dict[str, int]], str, None, int]] = { "tenant_id": str(env.initial_tenant), "stripe_size": None, "shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}], @@ -651,7 +657,7 @@ def test_sharding_split_stripe_size( # Check that we ended up with the stripe size that we expected, both on the pageserver # and in the notifications to compute assert len(notifications) == 2 - expect_after: Dict[str, Union[List[Dict[str, int]], str, None, int]] = { + expect_after: dict[str, Union[list[dict[str, int]], str, None, int]] = { "tenant_id": str(env.initial_tenant), "stripe_size": new_stripe_size, "shards": [ @@ -949,6 +955,7 @@ class PageserverFailpoint(Failure): self.pageserver_id = pageserver_id self._mitigate = mitigate + @override def apply(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.allowed_errors.extend( @@ -956,19 +963,23 @@ class PageserverFailpoint(Failure): ) pageserver.http_client().configure_failpoints((self.failpoint, "return(1)")) + @override def clear(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.http_client().configure_failpoints((self.failpoint, "off")) if self._mitigate: env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"}) + @override def expect_available(self): return True + @override def can_mitigate(self): return self._mitigate - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"}) @@ -978,9 +989,11 @@ class StorageControllerFailpoint(Failure): self.pageserver_id = None self.action = action + @override def apply(self, env: NeonEnv): env.storage_controller.configure_failpoints((self.failpoint, self.action)) + @override def clear(self, env: NeonEnv): if "panic" in self.action: log.info("Restarting storage controller after panic") @@ -989,16 +1002,19 @@ class StorageControllerFailpoint(Failure): else: env.storage_controller.configure_failpoints((self.failpoint, "off")) + @override def expect_available(self): # Controller panics _do_ leave pageservers available, but our test code relies # on using the locate API to update configurations in Workload, so we must skip # these actions when the controller has been panicked. return "panic" not in self.action + @override def can_mitigate(self): return False - def fails_forward(self, env): + @override + def fails_forward(self, env: NeonEnv): # Edge case: the very last failpoint that simulates a DB connection error, where # the abort path will fail-forward and result in a complete split. fail_forward = self.failpoint == "shard-split-post-complete" @@ -1012,6 +1028,7 @@ class StorageControllerFailpoint(Failure): return fail_forward + @override def expect_exception(self): if "panic" in self.action: return requests.exceptions.ConnectionError @@ -1024,18 +1041,22 @@ class NodeKill(Failure): self.pageserver_id = pageserver_id self._mitigate = mitigate + @override def apply(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.stop(immediate=True) + @override def clear(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.start() + @override def expect_available(self): return False - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"}) @@ -1054,21 +1075,26 @@ class CompositeFailure(Failure): self.pageserver_id = f.pageserver_id break + @override def apply(self, env: NeonEnv): for f in self.failures: f.apply(env) - def clear(self, env): + @override + def clear(self, env: NeonEnv): for f in self.failures: f.clear(env) + @override def expect_available(self): return all(f.expect_available() for f in self.failures) - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): for f in self.failures: f.mitigate(env) + @override def expect_exception(self): expect = set(f.expect_exception() for f in self.failures) @@ -1206,7 +1232,7 @@ def test_sharding_split_failures( assert attached_count == initial_shard_count - def assert_split_done(exclude_ps_id=None) -> None: + def assert_split_done(exclude_ps_id: Optional[int] = None) -> None: secondary_count = 0 attached_count = 0 for ps in env.pageservers: diff --git a/test_runner/regress/test_sni_router.py b/test_runner/regress/test_sni_router.py index 4336e6551d..402f27b384 100644 --- a/test_runner/regress/test_sni_router.py +++ b/test_runner/regress/test_sni_router.py @@ -1,14 +1,19 @@ +from __future__ import annotations + import socket import subprocess from pathlib import Path from types import TracebackType -from typing import Optional, Type +from typing import TYPE_CHECKING import backoff from fixtures.log_helper import log from fixtures.neon_fixtures import PgProtocol, VanillaPostgres from fixtures.port_distributor import PortDistributor +if TYPE_CHECKING: + from typing import Optional + def generate_tls_cert(cn, certout, keyout): subprocess.run( @@ -53,7 +58,7 @@ class PgSniRouter(PgProtocol): self._popen: Optional[subprocess.Popen[bytes]] = None self.test_output_dir = test_output_dir - def start(self) -> "PgSniRouter": + def start(self) -> PgSniRouter: assert self._popen is None args = [ str(self.neon_binpath / "pg_sni_router"), @@ -86,12 +91,12 @@ class PgSniRouter(PgProtocol): if self._popen: self._popen.wait(timeout=2) - def __enter__(self) -> "PgSniRouter": + def __enter__(self) -> PgSniRouter: return self def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ): diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 016d36301b..1dcc37c407 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import concurrent.futures import json import threading @@ -5,8 +7,9 @@ import time from collections import defaultdict from datetime import datetime, timezone from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING +import fixtures.utils import pytest from fixtures.auth_tokens import TokenScope from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -36,7 +39,11 @@ from fixtures.pg_version import PgVersion, run_only_on_default_postgres from fixtures.port_distributor import PortDistributor from fixtures.remote_storage import RemoteStorageKind, s3_storage from fixtures.storage_controller_proxy import StorageControllerProxy -from fixtures.utils import run_pg_bench_small, subprocess_capture, wait_until +from fixtures.utils import ( + run_pg_bench_small, + subprocess_capture, + wait_until, +) from fixtures.workload import Workload from mypy_boto3_s3.type_defs import ( ObjectTypeDef, @@ -46,6 +53,9 @@ from urllib3 import Retry from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response +if TYPE_CHECKING: + from typing import Any, Optional, Union + def get_node_shard_counts(env: NeonEnv, tenant_ids): counts: defaultdict[int, int] = defaultdict(int) @@ -55,9 +65,8 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids): return counts -def test_storage_controller_smoke( - neon_env_builder: NeonEnvBuilder, -): +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_storage_controller_smoke(neon_env_builder: NeonEnvBuilder, combination): """ Test the basic lifecycle of a storage controller: - Restarting @@ -490,7 +499,7 @@ def test_storage_controller_compute_hook( # Initial notification from tenant creation assert len(notifications) == 1 - expect: Dict[str, Union[List[Dict[str, int]], str, None, int]] = { + expect: dict[str, Union[list[dict[str, int]], str, None, int]] = { "tenant_id": str(env.initial_tenant), "stripe_size": None, "shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}], @@ -597,7 +606,7 @@ def test_storage_controller_stuck_compute_hook( # Initial notification from tenant creation assert len(notifications) == 1 - expect: Dict[str, Union[List[Dict[str, int]], str, None, int]] = { + expect: dict[str, Union[list[dict[str, int]], str, None, int]] = { "tenant_id": str(env.initial_tenant), "stripe_size": None, "shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}], @@ -834,7 +843,7 @@ def test_storage_controller_s3_time_travel_recovery( # Simulate a "disaster": delete some random files from remote storage for one of the shards assert env.pageserver_remote_storage shard_id_for_list = "0002" - objects: List[ObjectTypeDef] = list_prefix( + objects: list[ObjectTypeDef] = list_prefix( env.pageserver_remote_storage, f"tenants/{tenant_id}-{shard_id_for_list}/timelines/{timeline_id}/", ).get("Contents", []) @@ -885,7 +894,7 @@ def test_storage_controller_auth(neon_env_builder: NeonEnvBuilder): api = env.storage_controller_api tenant_id = TenantId.generate() - body: Dict[str, Any] = {"new_tenant_id": str(tenant_id)} + body: dict[str, Any] = {"new_tenant_id": str(tenant_id)} env.storage_controller.allowed_errors.append(".*Unauthorized.*") env.storage_controller.allowed_errors.append(".*Forbidden.*") @@ -1033,7 +1042,7 @@ def test_storage_controller_tenant_deletion( ) # Break the compute hook: we are checking that deletion does not depend on the compute hook being available - def break_hook(): + def break_hook(_body: Any): raise RuntimeError("Unexpected call to compute hook") compute_reconfigure_listener.register_on_notify(break_hook) @@ -1228,9 +1237,9 @@ def test_storage_controller_heartbeats( log.info(f"{node_to_tenants=}") # Check that all the tenants have been attached - assert sum((len(ts) for ts in node_to_tenants.values())) == len(tenant_ids) + assert sum(len(ts) for ts in node_to_tenants.values()) == len(tenant_ids) # Check that each node got one tenant - assert all((len(ts) == 1 for ts in node_to_tenants.values())) + assert all(len(ts) == 1 for ts in node_to_tenants.values()) wait_until(10, 1, tenants_placed) @@ -1295,11 +1304,11 @@ def test_storage_controller_heartbeats( node_to_tenants = build_node_to_tenants_map(env) log.info(f"Back online: {node_to_tenants=}") - # ... expecting the storage controller to reach a consistent state - def storage_controller_consistent(): - env.storage_controller.consistency_check() + # ... background reconciliation may need to run to clean up the location on the node that was offline + env.storage_controller.reconcile_until_idle() - wait_until(30, 1, storage_controller_consistent) + # ... expecting the storage controller to reach a consistent state + env.storage_controller.consistency_check() def test_storage_controller_re_attach(neon_env_builder: NeonEnvBuilder): @@ -2071,10 +2080,10 @@ def test_storage_controller_metadata_health( def update_and_query_metadata_health( env: NeonEnv, - healthy: List[TenantShardId], - unhealthy: List[TenantShardId], + healthy: list[TenantShardId], + unhealthy: list[TenantShardId], outdated_duration: str = "1h", - ) -> Tuple[Set[str], Set[str]]: + ) -> tuple[set[str], set[str]]: """ Update metadata health. Then list tenant shards with unhealthy and outdated metadata health status. @@ -2389,7 +2398,7 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB env.storage_controller.reconcile_until_idle() attached_id = int(env.storage_controller.locate(env.initial_tenant)[0]["node_id"]) - attached = next((ps for ps in env.pageservers if ps.id == attached_id)) + attached = next(ps for ps in env.pageservers if ps.id == attached_id) def attached_is_draining(): details = env.storage_controller.node_status(attached.id) diff --git a/test_runner/regress/test_storage_scrubber.py b/test_runner/regress/test_storage_scrubber.py index 7ecd0cf748..05db0fe977 100644 --- a/test_runner/regress/test_storage_scrubber.py +++ b/test_runner/regress/test_storage_scrubber.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import os import pprint import shutil import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -18,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + @pytest.mark.parametrize("shard_count", [None, 4]) def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]): diff --git a/test_runner/regress/test_subscriber_restart.py b/test_runner/regress/test_subscriber_restart.py index e67001ef41..d37eeb1e6e 100644 --- a/test_runner/regress/test_subscriber_restart.py +++ b/test_runner/regress/test_subscriber_restart.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import threading import time diff --git a/test_runner/regress/test_subxacts.py b/test_runner/regress/test_subxacts.py index 82075bd723..7a46f0140c 100644 --- a/test_runner/regress/test_subxacts.py +++ b/test_runner/regress/test_subxacts.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnv, check_restored_datadir_content diff --git a/test_runner/regress/test_tenant_conf.py b/test_runner/regress/test_tenant_conf.py index d13cbe45e9..1dd46ec3d1 100644 --- a/test_runner/regress/test_tenant_conf.py +++ b/test_runner/regress/test_tenant_conf.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import json -from typing import Any, Dict +from typing import TYPE_CHECKING from fixtures.common_types import Lsn from fixtures.neon_fixtures import ( @@ -10,11 +12,14 @@ from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Any + def test_tenant_config(neon_env_builder: NeonEnvBuilder): """Test per tenant configuration""" - def set_some_nondefault_global_config(ps_cfg: Dict[str, Any]): + def set_some_nondefault_global_config(ps_cfg: dict[str, Any]): ps_cfg["page_cache_size"] = 444 ps_cfg["wait_lsn_timeout"] = "111 s" diff --git a/test_runner/regress/test_tenant_delete.py b/test_runner/regress/test_tenant_delete.py index eafd159ac0..294c1248c5 100644 --- a/test_runner/regress/test_tenant_delete.py +++ b/test_runner/regress/test_tenant_delete.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from threading import Thread diff --git a/test_runner/regress/test_tenant_detach.py b/test_runner/regress/test_tenant_detach.py index 6de22f262d..59c14b3263 100644 --- a/test_runner/regress/test_tenant_detach.py +++ b/test_runner/regress/test_tenant_detach.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import asyncio import enum import random import time from threading import Thread -from typing import List, Optional +from typing import TYPE_CHECKING import asyncpg import pytest @@ -26,6 +28,10 @@ from fixtures.remote_storage import ( from fixtures.utils import query_scalar, wait_until from prometheus_client.samples import Sample +if TYPE_CHECKING: + from typing import Optional + + # In tests that overlap endpoint activity with tenant attach/detach, there are # a variety of warnings that the page service may emit when it cannot acquire # an active tenant to serve a request @@ -492,7 +498,7 @@ def test_metrics_while_ignoring_broken_tenant_and_reloading( r".* Changing Active tenant to Broken state, reason: broken from test" ) - def only_int(samples: List[Sample]) -> Optional[int]: + def only_int(samples: list[Sample]) -> Optional[int]: if len(samples) == 1: return int(samples[0].value) assert len(samples) == 0 diff --git a/test_runner/regress/test_tenant_relocation.py b/test_runner/regress/test_tenant_relocation.py index 645e22af1f..5561a128b7 100644 --- a/test_runner/regress/test_tenant_relocation.py +++ b/test_runner/regress/test_tenant_relocation.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import os import shutil import threading import time from contextlib import closing, contextmanager from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import TYPE_CHECKING import pytest from fixtures.common_types import Lsn, TenantId, TimelineId @@ -25,6 +27,9 @@ from fixtures.utils import ( wait_until, ) +if TYPE_CHECKING: + from typing import Any, Optional + def assert_abs_margin_ratio(a: float, b: float, margin_ratio: float): assert abs(a - b) / a < margin_ratio, abs(a - b) / a @@ -74,7 +79,7 @@ def populate_branch( ps_http: PageserverHttpClient, create_table: bool, expected_sum: Optional[int], -) -> Tuple[TimelineId, Lsn]: +) -> tuple[TimelineId, Lsn]: # insert some data with pg_cur(endpoint) as cur: cur.execute("SHOW neon.timeline_id") @@ -120,7 +125,7 @@ def check_timeline_attached( new_pageserver_http_client: PageserverHttpClient, tenant_id: TenantId, timeline_id: TimelineId, - old_timeline_detail: Dict[str, Any], + old_timeline_detail: dict[str, Any], old_current_lsn: Lsn, ): # new pageserver should be in sync (modulo wal tail or vacuum activity) with the old one because there was no new writes since checkpoint diff --git a/test_runner/regress/test_tenant_size.py b/test_runner/regress/test_tenant_size.py index 867c0021cd..9ea09d10d7 100644 --- a/test_runner/regress/test_tenant_size.py +++ b/test_runner/regress/test_tenant_size.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import os from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import List, Tuple import pytest from fixtures.common_types import Lsn, TenantId, TimelineId @@ -302,7 +303,7 @@ def test_single_branch_get_tenant_size_grows( http_client = env.pageserver.http_client() - collected_responses: List[Tuple[str, Lsn, int]] = [] + collected_responses: list[tuple[str, Lsn, int]] = [] size_debug_file = open(test_output_dir / "size_debug.html", "w") @@ -313,7 +314,7 @@ def test_single_branch_get_tenant_size_grows( http_client: PageserverHttpClient, tenant_id: TenantId, timeline_id: TimelineId, - ) -> Tuple[Lsn, int]: + ) -> tuple[Lsn, int]: consistent = False size_debug = None diff --git a/test_runner/regress/test_tenant_tasks.py b/test_runner/regress/test_tenant_tasks.py index 2bf930d767..72183f5778 100644 --- a/test_runner/regress/test_tenant_tasks.py +++ b/test_runner/regress/test_tenant_tasks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_tenants.py b/test_runner/regress/test_tenants.py index 7b194d40dd..95dc0fec78 100644 --- a/test_runner/regress/test_tenants.py +++ b/test_runner/regress/test_tenants.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import concurrent.futures import os import time @@ -5,7 +7,6 @@ from contextlib import closing from datetime import datetime from itertools import chain from pathlib import Path -from typing import List import pytest import requests @@ -272,7 +273,7 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde assert cur.fetchone() == (5000050000,) endpoint.stop() - def get_ps_metric_samples_for_tenant(tenant_id: TenantId) -> List[Sample]: + def get_ps_metric_samples_for_tenant(tenant_id: TenantId) -> list[Sample]: ps_metrics = env.pageserver.http_client().get_metrics() samples = [] for metric_name in ps_metrics.metrics: @@ -459,7 +460,7 @@ def test_pageserver_metrics_many_relations(neon_env_builder: NeonEnvBuilder): "pageserver_directory_entries_count", {"tenant_id": str(env.initial_tenant)} ) - def only_int(samples: List[Sample]) -> int: + def only_int(samples: list[Sample]) -> int: assert len(samples) == 1 return int(samples[0].value) diff --git a/test_runner/regress/test_tenants_with_remote_storage.py b/test_runner/regress/test_tenants_with_remote_storage.py index 9310786da7..8d3ddf7e54 100644 --- a/test_runner/regress/test_tenants_with_remote_storage.py +++ b/test_runner/regress/test_tenants_with_remote_storage.py @@ -6,10 +6,11 @@ # checkpoint_distance setting so that a lot of layer files are created. # +from __future__ import annotations + import asyncio import os from pathlib import Path -from typing import List, Tuple from fixtures.common_types import Lsn, TenantId, TimelineId from fixtures.log_helper import log @@ -62,7 +63,7 @@ async def all_tenants_workload(env: NeonEnv, tenants_endpoints): def test_tenants_many(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() - tenants_endpoints: List[Tuple[TenantId, Endpoint]] = [] + tenants_endpoints: list[tuple[TenantId, Endpoint]] = [] for _ in range(1, 5): # Use a tiny checkpoint distance, to create a lot of layers quickly diff --git a/test_runner/regress/test_threshold_based_eviction.py b/test_runner/regress/test_threshold_based_eviction.py index 094dd20529..5f211ec4d4 100644 --- a/test_runner/regress/test_threshold_based_eviction.py +++ b/test_runner/regress/test_threshold_based_eviction.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import time from dataclasses import dataclass -from typing import List, Set, Tuple from fixtures.log_helper import log from fixtures.neon_fixtures import ( @@ -116,8 +117,8 @@ def test_threshold_based_eviction( # wait for evictions and assert that they stabilize @dataclass class ByLocalAndRemote: - remote_layers: Set[str] - local_layers: Set[str] + remote_layers: set[str] + local_layers: set[str] class MapInfoProjection: def __init__(self, info: LayerMapInfo): @@ -149,7 +150,7 @@ def test_threshold_based_eviction( consider_stable_when_no_change_for_seconds = 3 * eviction_threshold poll_interval = eviction_threshold / 3 started_waiting_at = time.time() - map_info_changes: List[Tuple[float, MapInfoProjection]] = [] + map_info_changes: list[tuple[float, MapInfoProjection]] = [] while time.time() - started_waiting_at < observation_window: current = ( time.time(), diff --git a/test_runner/regress/test_timeline_archive.py b/test_runner/regress/test_timeline_archive.py index 16e0521890..841707d32e 100644 --- a/test_runner/regress/test_timeline_archive.py +++ b/test_runner/regress/test_timeline_archive.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fixtures.common_types import TenantId, TimelineArchivalState, TimelineId from fixtures.neon_fixtures import ( diff --git a/test_runner/regress/test_timeline_delete.py b/test_runner/regress/test_timeline_delete.py index 7b6f6ac3c6..306f22acf9 100644 --- a/test_runner/regress/test_timeline_delete.py +++ b/test_runner/regress/test_timeline_delete.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum import os import queue diff --git a/test_runner/regress/test_timeline_detach_ancestor.py b/test_runner/regress/test_timeline_detach_ancestor.py index 7f148a4b9b..0c8554bb54 100644 --- a/test_runner/regress/test_timeline_detach_ancestor.py +++ b/test_runner/regress/test_timeline_detach_ancestor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import enum import threading @@ -5,7 +7,6 @@ import time from concurrent.futures import ThreadPoolExecutor from queue import Empty, Queue from threading import Barrier -from typing import List, Set, Tuple import pytest from fixtures.common_types import Lsn, TimelineId @@ -48,7 +49,7 @@ class Branchpoint(str, enum.Enum): return self.value @staticmethod - def all() -> List["Branchpoint"]: + def all() -> list[Branchpoint]: return [ Branchpoint.EARLIER, Branchpoint.AT_L0, @@ -473,7 +474,7 @@ def test_compaction_induced_by_detaches_in_history( more_good_numbers = range(0, 3) - branches: List[Tuple[str, TimelineId]] = [("main", env.initial_timeline)] + branches: list[tuple[str, TimelineId]] = [("main", env.initial_timeline)] for num in more_good_numbers: branch_name = f"br-{len(branches)}" @@ -1270,7 +1271,7 @@ def test_retried_detach_ancestor_after_failed_reparenting(neon_env_builder: Neon {"request_type": "copy_object", "result": "ok"}, ) - def reparenting_progress(timelines: List[TimelineId]) -> Tuple[int, Set[TimelineId]]: + def reparenting_progress(timelines: list[TimelineId]) -> tuple[int, set[TimelineId]]: reparented = 0 not_reparented = set() for timeline in timelines: @@ -1306,7 +1307,7 @@ def test_retried_detach_ancestor_after_failed_reparenting(neon_env_builder: Neon http.configure_failpoints(("timeline-detach-ancestor::allow_one_reparented", "return")) - not_reparented: Set[TimelineId] = set() + not_reparented: set[TimelineId] = set() # tracked offset in the pageserver log which is at least at the most recent activation offset = None diff --git a/test_runner/regress/test_timeline_gc_blocking.py b/test_runner/regress/test_timeline_gc_blocking.py index 1540cbbcee..c19c78e251 100644 --- a/test_runner/regress/test_timeline_gc_blocking.py +++ b/test_runner/regress/test_timeline_gc_blocking.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import List, Optional +from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log @@ -12,6 +14,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.utils import wait_timeline_detail_404 +if TYPE_CHECKING: + from typing import Optional + @pytest.mark.parametrize("sharded", [True, False]) def test_gc_blocking_by_timeline(neon_env_builder: NeonEnvBuilder, sharded: bool): @@ -98,7 +103,7 @@ class ScrollableLog: @dataclass(frozen=True) class ManyPageservers: - many: List[ScrollableLog] + many: list[ScrollableLog] def assert_log_contains(self, what: str): for one in self.many: diff --git a/test_runner/regress/test_timeline_size.py b/test_runner/regress/test_timeline_size.py index aa77474097..85c6d17142 100644 --- a/test_runner/regress/test_timeline_size.py +++ b/test_runner/regress/test_timeline_size.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import concurrent.futures import math import random diff --git a/test_runner/regress/test_truncate.py b/test_runner/regress/test_truncate.py index 4fc0601a18..946dab2676 100644 --- a/test_runner/regress/test_truncate.py +++ b/test_runner/regress/test_truncate.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from fixtures.neon_fixtures import NeonEnvBuilder diff --git a/test_runner/regress/test_twophase.py b/test_runner/regress/test_twophase.py index 1d9fe9d21d..e37e8dd3e8 100644 --- a/test_runner/regress/test_twophase.py +++ b/test_runner/regress/test_twophase.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from pathlib import Path diff --git a/test_runner/regress/test_unlogged.py b/test_runner/regress/test_unlogged.py index 4431ccd959..a89391425e 100644 --- a/test_runner/regress/test_unlogged.py +++ b/test_runner/regress/test_unlogged.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from fixtures.neon_fixtures import NeonEnv, fork_at_current_lsn from fixtures.pg_version import PgVersion diff --git a/test_runner/regress/test_vm_bits.py b/test_runner/regress/test_vm_bits.py index ae1b6fdab3..d4c2ca7e07 100644 --- a/test_runner/regress/test_vm_bits.py +++ b/test_runner/regress/test_vm_bits.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from contextlib import closing diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 44ca9f90a4..d803cd7c78 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import filecmp import logging import os @@ -12,7 +14,7 @@ from contextlib import closing from dataclasses import dataclass, field from functools import partial from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING import psycopg2 import psycopg2.errors @@ -56,6 +58,9 @@ from fixtures.utils import ( wait_until, ) +if TYPE_CHECKING: + from typing import Any, Optional + def wait_lsn_force_checkpoint( tenant_id: TenantId, @@ -124,8 +129,8 @@ class TimelineMetrics: timeline_id: TimelineId last_record_lsn: Lsn # One entry per each Safekeeper, order is the same - flush_lsns: List[Lsn] = field(default_factory=list) - commit_lsns: List[Lsn] = field(default_factory=list) + flush_lsns: list[Lsn] = field(default_factory=list) + commit_lsns: list[Lsn] = field(default_factory=list) # Run page server and multiple acceptors, and multiple compute nodes running @@ -152,7 +157,7 @@ def test_many_timelines(neon_env_builder: NeonEnvBuilder): tenant_id = env.initial_tenant - def collect_metrics(message: str) -> List[TimelineMetrics]: + def collect_metrics(message: str) -> list[TimelineMetrics]: with env.pageserver.http_client() as pageserver_http: timeline_details = [ pageserver_http.timeline_detail( @@ -765,7 +770,7 @@ class ProposerPostgres(PgProtocol): stdout_filename = basepath + ".stdout" - with open(stdout_filename, "r") as stdout_f: + with open(stdout_filename) as stdout_f: stdout = stdout_f.read() return Lsn(stdout.strip("\n ")) @@ -934,7 +939,7 @@ def test_timeline_status(neon_env_builder: NeonEnvBuilder, auth_enabled: bool): assert debug_dump_1["config"]["id"] == env.safekeepers[0].id -class DummyConsumer(object): +class DummyConsumer: def __call__(self, msg): pass @@ -1162,7 +1167,7 @@ def is_flush_lsn_aligned(sk_http_clis, tenant_id, timeline_id): # Assert by xxd that WAL on given safekeepers is identical. No compute must be # running for this to be reliable. -def cmp_sk_wal(sks: List[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId): +def cmp_sk_wal(sks: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId): assert len(sks) >= 2, "cmp_sk_wal makes sense with >= 2 safekeepers passed" sk_http_clis = [sk.http_client() for sk in sks] @@ -1448,12 +1453,12 @@ class SafekeeperEnv: self.pg_bin = pg_bin self.num_safekeepers = num_safekeepers self.bin_safekeeper = str(neon_binpath / "safekeeper") - self.safekeepers: Optional[List[subprocess.CompletedProcess[Any]]] = None + self.safekeepers: Optional[list[subprocess.CompletedProcess[Any]]] = None self.postgres: Optional[ProposerPostgres] = None self.tenant_id: Optional[TenantId] = None self.timeline_id: Optional[TimelineId] = None - def init(self) -> "SafekeeperEnv": + def init(self) -> SafekeeperEnv: assert self.postgres is None, "postgres is already initialized" assert self.safekeepers is None, "safekeepers are already initialized" @@ -1534,7 +1539,7 @@ class SafekeeperEnv: def kill_safekeeper(self, sk_dir): """Read pid file and kill process""" pid_file = os.path.join(sk_dir, "safekeeper.pid") - with open(pid_file, "r") as f: + with open(pid_file) as f: pid = int(f.read()) log.info(f"Killing safekeeper with pid {pid}") os.kill(pid, signal.SIGKILL) @@ -1593,7 +1598,7 @@ def test_replace_safekeeper(neon_env_builder: NeonEnvBuilder): sum_after = query_scalar(cur, "SELECT SUM(key) FROM t") assert sum_after == sum_before + 5000050000 - def show_statuses(safekeepers: List[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId): + def show_statuses(safekeepers: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId): for sk in safekeepers: http_cli = sk.http_client() try: @@ -1802,7 +1807,7 @@ def test_pull_timeline(neon_env_builder: NeonEnvBuilder, live_sk_change: bool): sum_after = query_scalar(cur, "SELECT SUM(key) FROM t") assert sum_after == sum_before + 5000050000 - def show_statuses(safekeepers: List[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId): + def show_statuses(safekeepers: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId): for sk in safekeepers: http_cli = sk.http_client(auth_token=env.auth_keys.generate_tenant_token(tenant_id)) try: @@ -2011,14 +2016,14 @@ def test_idle_reconnections(neon_env_builder: NeonEnvBuilder): tenant_id = env.initial_tenant timeline_id = env.create_branch("test_idle_reconnections") - def collect_stats() -> Dict[str, float]: + def collect_stats() -> dict[str, float]: # we need to collect safekeeper_pg_queries_received_total metric from all safekeepers sk_metrics = [ parse_metrics(sk.http_client().get_metrics_str(), f"safekeeper_{sk.id}") for sk in env.safekeepers ] - total: Dict[str, float] = {} + total: dict[str, float] = {} for sk in sk_metrics: queries_received = sk.query_all("safekeeper_pg_queries_received_total") @@ -2309,12 +2314,12 @@ def test_s3_eviction( ] if delete_offloaded_wal: neon_env_builder.safekeeper_extra_opts.append("--delete-offloaded-wal") - - env = neon_env_builder.init_start( - initial_tenant_conf={ - "checkpoint_timeout": "100ms", - } - ) + # make lagging_wal_timeout small to force pageserver quickly forget about + # safekeeper after it stops sending updates (timeline is deactivated) to + # make test faster. Won't be needed with + # https://github.com/neondatabase/neon/issues/8148 fixed. + initial_tenant_conf = {"lagging_wal_timeout": "1s", "checkpoint_timeout": "100ms"} + env = neon_env_builder.init_start(initial_tenant_conf=initial_tenant_conf) n_timelines = 5 @@ -2402,9 +2407,37 @@ def test_s3_eviction( and sk.log_contains("successfully restored evicted timeline") for sk in env.safekeepers ) - assert event_metrics_seen + # test safekeeper_evicted_timelines metric + log.info("testing safekeeper_evicted_timelines metric") + # checkpoint pageserver to force remote_consistent_lsn update + for i in range(n_timelines): + ps_client.timeline_checkpoint(env.initial_tenant, timelines[i], wait_until_uploaded=True) + for ep in endpoints: + log.info(ep.is_running()) + sk = env.safekeepers[0] + + # all timelines must be evicted eventually + def all_evicted(): + n_evicted = sk.http_client().get_metric_value("safekeeper_evicted_timelines") + assert n_evicted # make mypy happy + assert int(n_evicted) == n_timelines + + wait_until(60, 0.5, all_evicted) + # restart should preserve the metric value + sk.stop().start() + wait_until(60, 0.5, all_evicted) + # and endpoint start should reduce is + endpoints[0].start() + + def one_unevicted(): + n_evicted = sk.http_client().get_metric_value("safekeeper_evicted_timelines") + assert n_evicted # make mypy happy + assert int(n_evicted) < n_timelines + + wait_until(60, 0.5, one_unevicted) + # Test resetting uploaded partial segment state. def test_backup_partial_reset(neon_env_builder: NeonEnvBuilder): diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index 74d114e976..92306469f8 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import asyncio import random import time from dataclasses import dataclass from pathlib import Path -from typing import List, Optional +from typing import TYPE_CHECKING import asyncpg import pytest @@ -13,10 +15,14 @@ from fixtures.log_helper import getLogger from fixtures.neon_fixtures import Endpoint, NeonEnv, NeonEnvBuilder, Safekeeper from fixtures.remote_storage import RemoteStorageKind +if TYPE_CHECKING: + from typing import Optional + + log = getLogger("root.safekeeper_async") -class BankClient(object): +class BankClient: def __init__(self, conn: asyncpg.Connection, n_accounts, init_amount): self.conn: asyncpg.Connection = conn self.n_accounts = n_accounts @@ -65,7 +71,7 @@ async def bank_transfer(conn: asyncpg.Connection, from_uid, to_uid, amount): ) -class WorkerStats(object): +class WorkerStats: def __init__(self, n_workers): self.counters = [0] * n_workers self.running = True @@ -148,7 +154,7 @@ async def wait_for_lsn( async def run_restarts_under_load( env: NeonEnv, endpoint: Endpoint, - acceptors: List[Safekeeper], + acceptors: list[Safekeeper], n_workers=10, n_accounts=100, init_amount=100000, @@ -329,7 +335,7 @@ def test_compute_restarts(neon_env_builder: NeonEnvBuilder): asyncio.run(run_compute_restarts(env)) -class BackgroundCompute(object): +class BackgroundCompute: MAX_QUERY_GAP_SECONDS = 2 def __init__(self, index: int, env: NeonEnv, branch: str): @@ -339,7 +345,7 @@ class BackgroundCompute(object): self.running = False self.stopped = False self.total_tries = 0 - self.successful_queries: List[int] = [] + self.successful_queries: list[int] = [] async def run(self): if self.running: @@ -634,7 +640,7 @@ class RaceConditionTest: # shut down random subset of safekeeper, sleep, wake them up, rinse, repeat -async def xmas_garland(safekeepers: List[Safekeeper], data: RaceConditionTest): +async def xmas_garland(safekeepers: list[Safekeeper], data: RaceConditionTest): while not data.is_stopped: data.iteration += 1 victims = [] @@ -693,7 +699,7 @@ def test_race_conditions(neon_env_builder: NeonEnvBuilder): # Check that pageserver can select safekeeper with largest commit_lsn # and switch if LSN is not updated for some time (NoWalTimeout). async def run_wal_lagging(env: NeonEnv, endpoint: Endpoint, test_output_dir: Path): - def adjust_safekeepers(env: NeonEnv, active_sk: List[bool]): + def adjust_safekeepers(env: NeonEnv, active_sk: list[bool]): # Change the pg ports of the inactive safekeepers in the config file to be # invalid, to make them unavailable to the endpoint. We use # ports 10, 11 and 12 to simulate unavailable safekeepers. diff --git a/test_runner/regress/test_wal_receiver.py b/test_runner/regress/test_wal_receiver.py index 3c73df68e0..be2aa2b346 100644 --- a/test_runner/regress/test_wal_receiver.py +++ b/test_runner/regress/test_wal_receiver.py @@ -1,10 +1,15 @@ +from __future__ import annotations + import time -from typing import Any, Dict +from typing import TYPE_CHECKING from fixtures.common_types import Lsn, TenantId from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder +if TYPE_CHECKING: + from typing import Any + # Checks that pageserver's walreceiver state is printed in the logs during WAL wait timeout. # Ensures that walreceiver does not run without any data inserted and only starts after the insertion. @@ -43,7 +48,7 @@ def test_pageserver_lsn_wait_error_start(neon_env_builder: NeonEnvBuilder): # Kills one of the safekeepers and ensures that only the active ones are printed in the state. def test_pageserver_lsn_wait_error_safekeeper_stop(neon_env_builder: NeonEnvBuilder): # Trigger WAL wait timeout faster - def customize_pageserver_toml(ps_cfg: Dict[str, Any]): + def customize_pageserver_toml(ps_cfg: dict[str, Any]): ps_cfg["wait_lsn_timeout"] = "1s" tenant_config = ps_cfg.setdefault("tenant_config", {}) tenant_config["walreceiver_connect_timeout"] = "2s" diff --git a/test_runner/regress/test_wal_restore.py b/test_runner/regress/test_wal_restore.py index 46366f0e2c..05b6ad8a9b 100644 --- a/test_runner/regress/test_wal_restore.py +++ b/test_runner/regress/test_wal_restore.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import sys import tarfile import tempfile from pathlib import Path -from typing import List import pytest import zstandard @@ -165,7 +166,7 @@ def test_wal_restore_http(neon_env_builder: NeonEnvBuilder, broken_tenant: bool) if broken_tenant: ps_client.tenant_detach(tenant_id) - objects: List[ObjectTypeDef] = list_prefix( + objects: list[ObjectTypeDef] = list_prefix( env.pageserver_remote_storage, f"tenants/{tenant_id}/timelines/{timeline_id}/" ).get("Contents", []) for obj in objects: diff --git a/test_runner/regress/test_walredo_not_left_behind_on_detach.py b/test_runner/regress/test_walredo_not_left_behind_on_detach.py index ae8e276a1a..182e57b8a4 100644 --- a/test_runner/regress/test_walredo_not_left_behind_on_detach.py +++ b/test_runner/regress/test_walredo_not_left_behind_on_detach.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time import psutil diff --git a/test_runner/test_broken.py b/test_runner/test_broken.py index d710b53528..112e699395 100644 --- a/test_runner/test_broken.py +++ b/test_runner/test_broken.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pytest diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index ddfc14ce1b..0a90b6b6f7 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -17,9 +17,10 @@ license.workspace = true [dependencies] ahash = { version = "0.8" } anyhow = { version = "1", features = ["backtrace"] } +axum = { version = "0.7", features = ["ws"] } +axum-core = { version = "0.4", default-features = false, features = ["tracing"] } base64 = { version = "0.21", features = ["alloc"] } base64ct = { version = "1", default-features = false, features = ["std"] } -bitflags = { version = "2", default-features = false, features = ["std"] } bytes = { version = "1", features = ["serde"] } camino = { version = "1", default-features = false, features = ["serde1"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } @@ -43,11 +44,10 @@ hashbrown = { version = "0.14", features = ["raw"] } hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper-582f2526e08bb6a0 = { package = "hyper", version = "0.14", features = ["full"] } -hyper-dff4ba8e3ae991db = { package = "hyper", version = "1", features = ["http1", "http2", "server"] } -hyper-util = { version = "0.1", features = ["http1", "http2", "server", "tokio"] } +hyper-dff4ba8e3ae991db = { package = "hyper", version = "1", features = ["full"] } +hyper-util = { version = "0.1", features = ["client-legacy", "server-auto", "service"] } indexmap = { version = "1", default-features = false, features = ["std"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -itertools-93f6ce9d446188ac = { package = "itertools", version = "0.10" } +itertools = { version = "0.12" } lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] } @@ -58,16 +58,16 @@ num-integer = { version = "0.1", features = ["i128"] } num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } parquet = { version = "53", default-features = false, features = ["zstd"] } -prost = { version = "0.11" } +postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", default-features = false, features = ["with-serde_json-1"] } +prost = { version = "0.13", features = ["prost-derive"] } rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] } regex-syntax = { version = "0.8" } reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] } -rustls = { version = "0.21", features = ["dangerous_configuration"] } scopeguard = { version = "1" } serde = { version = "1", features = ["alloc", "derive"] } -serde_json = { version = "1", features = ["raw_value"] } +serde_json = { version = "1", features = ["alloc", "raw_value"] } sha2 = { version = "0.10", features = ["asm", "oid"] } signature = { version = "2", default-features = false, features = ["digest", "rand_core", "std"] } smallvec = { version = "1", default-features = false, features = ["const_new", "write"] } @@ -77,10 +77,12 @@ sync_wrapper = { version = "0.1", default-features = false, features = ["futures tikv-jemalloc-sys = { version = "0.5" } time = { version = "0.3", features = ["macros", "serde-well-known"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] } -tokio-rustls = { version = "0.24" } +tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", features = ["with-serde_json-1"] } +tokio-stream = { version = "0.1", features = ["net"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] } toml_edit = { version = "0.22", features = ["serde"] } -tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] } +tonic = { version = "0.12", features = ["tls-roots"] } +tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "util"] } tracing = { version = "0.1", features = ["log"] } tracing-core = { version = "0.1" } url = { version = "2", features = ["serde"] } @@ -92,7 +94,6 @@ zstd-sys = { version = "2", default-features = false, features = ["legacy", "std [build-dependencies] ahash = { version = "0.8" } anyhow = { version = "1", features = ["backtrace"] } -bitflags = { version = "2", default-features = false, features = ["std"] } bytes = { version = "1", features = ["serde"] } cc = { version = "1", default-features = false, features = ["parallel"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } @@ -101,9 +102,7 @@ getrandom = { version = "0.2", default-features = false, features = ["std"] } half = { version = "2", default-features = false, features = ["num-traits"] } hashbrown = { version = "0.14", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } -itertools-93f6ce9d446188ac = { package = "itertools", version = "0.10" } -lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } +itertools = { version = "0.12" } libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] } memchr = { version = "2" } @@ -113,8 +112,9 @@ num-integer = { version = "0.1", features = ["i128"] } num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } parquet = { version = "53", default-features = false, features = ["zstd"] } +prettyplease = { version = "0.2", default-features = false, features = ["verbatim"] } proc-macro2 = { version = "1" } -prost = { version = "0.11" } +prost = { version = "0.13", features = ["prost-derive"] } quote = { version = "1" } regex = { version = "1" } regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }