Compare commits

36 Commits

Author SHA1 Message Date
Heikki Linnakangas
7e175400ab Reduce noise from moto GET/PUT operations
Moto prints messages like this:

    127.0.0.1 - - [07/Oct/2024 12:35:16] "PUT /bucket-name/path?x-id=PutObject HTTP/1.1" 200 -

After the root logger adds its context information, this is what
actually gets printed to the log:

    2024-10-07 22:35:16.371 INFO [_internal.py:97] 127.0.0.1 - - [07/Oct/2024 22:35:16] "PUT /bucket-name/path?x-id=PutObject HTTP/1.1" 200 -

That's very verbose. Remove the hostname and the extra timestamp, to
make it a little less verbose. With this PR, the final output looks
like this:

    2024-10-07 22:35:16.371 INFO [_internal.py:97] "PUT /bucket-name/path?x-id=PutObject HTTP/1.1" 200 -
2024-10-09 18:53:08 +03:00
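A minimal sketch of how the prefix stripping described in 7e175400ab could be done with a standard `logging.Filter`. The regular expression and the names `StripAccessPrefix` and `strip_access_prefix` are assumptions made for illustration; the commit's actual implementation is not shown on this page.

```python
import logging
import re

# Matches the leading 'host - - [timestamp] ' part of an access-log line.
# The exact format is an assumption based on the sample lines in the message.
_ACCESS_PREFIX = re.compile(r"^\S+ - - \[[^\]]+\] ")


def strip_access_prefix(line: str) -> str:
    """Return the line without its host and bracketed timestamp prefix."""
    return _ACCESS_PREFIX.sub("", line, count=1)


class StripAccessPrefix(logging.Filter):
    """Hypothetical filter that rewrites a record's message in place."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Format the message first, then store the shortened text back.
        record.msg = strip_access_prefix(record.getMessage())
        record.args = ()
        return True  # keep the record, only its text changes
```

Such a filter would be attached with `logger.addFilter(StripAccessPrefix())` on whichever logger emits the access lines; which logger that is depends on the server library and is not stated in the commit message.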
Fedor Dikarev
108a211917 added workflow Report Workflow Stats (#9330)
## Summary of changes
CI: Collect stats for GitHub Workflow Runs
2024-10-09 17:27:41 +02:00
Heikki Linnakangas
72ef0e0fa1 tests: Remove redundant log lines when stopping storage nodes (#9317)
The neon_cli functions print the command that gets executed, which
contains the same information.

Before:

    2024-10-07 22:32:28.884 INFO [neon_fixtures.py:3927] Stopping safekeeper 1
    2024-10-07 22:32:28.884 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local safekeeper stop 1"
    2024-10-07 22:32:28.989 INFO [neon_fixtures.py:3927] Stopping safekeeper 2
    2024-10-07 22:32:28.989 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local safekeeper stop 2"
    2024-10-07 22:32:29.93 INFO [neon_fixtures.py:3927] Stopping safekeeper 3
    2024-10-07 22:32:29.94 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local safekeeper stop 3"
    2024-10-07 22:32:29.251 INFO [neon_cli.py:450] Stopping pageserver with ['pageserver', 'stop', '--id=1']
    2024-10-07 22:32:29.251 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local pageserver stop --id=1"

After:

    2024-10-07 22:32:28.884 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local safekeeper stop 1"
    2024-10-07 22:32:28.989 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local safekeeper stop 2"
    2024-10-07 22:32:29.94 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local safekeeper stop 3"
    2024-10-07 22:32:29.251 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local pageserver stop --id=1"
2024-10-09 15:51:34 +03:00
Heikki Linnakangas
eb23d355a9 tests: Use ThreadedMotoServer python class to launch mock S3 server (#9313)
This is simpler than using subprocess.

One difference is in how moto's log output is now collected. Previously,
moto's logs went to stderr, and were collected and printed at the end of
the test by pytest, like this:

    2024-10-07T22:45:12.3705222Z ----------------------------- Captured stderr call -----------------------------
    2024-10-07T22:45:12.3705577Z 127.0.0.1 - - [07/Oct/2024 22:35:14] "PUT /pageserver-test-deletion-queue-2e6efa8245ec92a37a07004569c29eb7 HTTP/1.1" 200 -
    2024-10-07T22:45:12.3706181Z 127.0.0.1 - - [07/Oct/2024 22:35:15] "GET /pageserver-test-deletion-queue-2e6efa8245ec92a37a07004569c29eb7/?list-type=2&delimiter=/&prefix=/tenants/43da25eac0f41412696dd31b94dbb83c/timelines/ HTTP/1.1" 200 -
    2024-10-07T22:45:12.3706894Z 127.0.0.1 - - [07/Oct/2024 22:35:16] "PUT /pageserver-test-deletion-queue-2e6efa8245ec92a37a07004569c29eb7//tenants/43da25eac0f41412696dd31b94dbb83c/timelines/eabba5f0c1c72c8656d3ef1d85b98c1d/initdb.tar.zst?x-id=PutObject HTTP/1.1" 200 -

Note the timestamps: the timestamp at the beginning of the line is the
time that the stderr was dumped, i.e. the end of the test, which makes
those timestamps rather useless. The timestamp in the middle of the line
is when the operation actually happened, but it has only 1 s
granularity.

With this change, moto's log lines are printed in the "live log call"
section, as they happen, which makes the timestamps more useful:

    2024-10-08 12:12:31.129 INFO [_internal.py:97] 127.0.0.1 - - [08/Oct/2024 12:12:31] "GET /pageserver-test-deletion-queue-e24e7525d437e1874d8a52030dcabb4f/?list-type=2&delimiter=/&prefix=/tenants/7b6a16b1460eda5204083fba78bc360f/timelines/ HTTP/1.1" 200 -
    2024-10-08 12:12:32.612 INFO [_internal.py:97] 127.0.0.1 - - [08/Oct/2024 12:12:32] "PUT /pageserver-test-deletion-queue-e24e7525d437e1874d8a52030dcabb4f//tenants/7b6a16b1460eda5204083fba78bc360f/timelines/7ab4c2b67fa8c712cada207675139877/initdb.tar.zst?x-id=PutObject HTTP/1.1" 200 -
2024-10-09 15:34:51 +03:00
Yuchen Liang
bee04b8a69 pageserver: add direct io config to virtual file (#9214)
## Problem
We need a way to incrementally switch to direct IO. During the rollout
we might want to switch to O_DIRECT on image and delta layer read path
first before others.

## Summary of changes
- Revisited and simplified direct io config in `PageserverConf`. 
- We could add a fallback mode for open, but for read there isn't a
reasonable alternative (without creating another buffered virtual file).
- Added a wrapper around `VirtualFile`; the current implementation becomes
`VirtualFileInner`
- Use `open_v2`, `create_v2`, `open_with_options_v2` when we want to use
the IO mode specified in PS config.
- Once we onboard all IO through VirtualFile using this new API, we will
delete the old code path.
- Make io mode live configurable for benchmarking.
- Only guaranteed for files opened after the config change, so do it
before the experiment.

As an example, we are using `open_v2` with
`virtual_file::IoMode::Direct` in
https://github.com/neondatabase/neon/pull/9169

We also remove `io_buffer_alignment` config in
a04cfd754b and use it as a compile time
constant. This way we don't have to carry the alignment around or make
frequent calls to retrieve this information from the static variable.

Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-10-09 08:33:07 -04:00
Anastasia Lubennikova
63e7fab990 Add /installed_extensions endpoint to collect statistics about extension usage. (#8917)
Add /installed_extensions endpoint to collect
statistics about extension usage.
It returns a list of installed extensions in the format:

```json
{
  "extensions": [
    {
      "extname": "extension_name",
      "versions": ["1.0", "1.1"],
      "n_databases": 5
    }
  ]
}
```

---------

Co-authored-by: Heikki Linnakangas <heikki@neon.tech>
2024-10-09 13:32:13 +01:00
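A sketch of how per-database extension rows could be folded into the response shape shown in 63e7fab990. The input layout (`(database, extname, extversion)` tuples), the function name `summarize_extensions`, and the reading of `n_databases` as "number of databases that have the extension installed" are assumptions, not taken from the commit.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set, Tuple


def summarize_extensions(rows: Iterable[Tuple[str, str, str]]) -> dict:
    """Aggregate (database, extname, extversion) rows by extension name."""
    versions: Dict[str, Set[str]] = defaultdict(set)
    databases: Dict[str, Set[str]] = defaultdict(set)
    for database, extname, extversion in rows:
        versions[extname].add(extversion)
        databases[extname].add(database)
    return {
        "extensions": [
            {
                "extname": name,
                "versions": sorted(versions[name]),
                "n_databases": len(databases[name]),
            }
            for name in sorted(versions)
        ]
    }
```

Passing the result to `json.dumps` yields a document with the same keys as the example in the commit message.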
Arseny Sher
a181392738 safekeeper: add evicted_timelines gauge. (#9318)
showing total number of evicted timelines.
2024-10-09 14:40:30 +03:00
Alexander Bayandin
fc7397122c test_runner: fix path to tpc-h queries (#9327)
## Problem

The path to TPC-H queries was incorrectly changed in #9306.
This path is used for `test_tpch` parameterization, so all perf tests
started to fail:

```
==================================== ERRORS ====================================
__________ ERROR collecting test_runner/performance/test_perf_olap.py __________
test_runner/performance/test_perf_olap.py:205: in <module>
    @pytest.mark.parametrize("query", tpch_queuies())
test_runner/performance/test_perf_olap.py:196: in tpch_queuies
    assert queries_dir.exists(), f"TPC-H queries dir not found: {queries_dir}"
E   AssertionError: TPC-H queries dir not found: /__w/neon/neon/test_runner/performance/performance/tpc-h/queries
E   assert False
E    +  where False = <bound method Path.exists of PosixPath('/__w/neon/neon/test_runner/performance/performance/tpc-h/queries')>()
E    +    where <bound method Path.exists of PosixPath('/__w/neon/neon/test_runner/performance/performance/tpc-h/queries')> = PosixPath('/__w/neon/neon/test_runner/performance/performance/tpc-h/queries').exists
```

## Summary of changes
- Fix the path to tpc-h queries
2024-10-09 12:11:06 +01:00
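The doubled `performance/performance` component in the error from fc7397122c is what one gets when a path that already starts with `performance/` is joined onto a directory that is itself named `performance`. The sketch below only illustrates that mechanism; the helper `queries_dir` and the relative paths passed to it are assumptions, not the code from the repository.

```python
from pathlib import PurePosixPath


def queries_dir(test_file: PurePosixPath, relative: str) -> PurePosixPath:
    """Resolve `relative` against the directory that contains the test module."""
    return test_file.parent / relative


test_file = PurePosixPath("/__w/neon/neon/test_runner/performance/test_perf_olap.py")

# The relative path repeats the parent directory name, so it is doubled.
broken = queries_dir(test_file, "performance/tpc-h/queries")

# The relative path starts below the module's own directory.
fixed = queries_dir(test_file, "tpc-h/queries")
```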
Vlad Lazar
cc599e23c1 storcon: make observed state updates more granular (#9276)
## Problem

Previously, observed state updates from the reconciler may have
clobbered inline changes made to the observed state by other code paths.

## Summary of changes

Model observed state changes from reconcilers as deltas. This means that
we only update what has changed. Handling for a node going offline
concurrently with the reconcile is also added: the observed state is set to
None in such cases, to respect the convention.

Closes https://github.com/neondatabase/neon/issues/9124
2024-10-09 11:53:29 +01:00
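A simplified sketch of the delta idea from cc599e23c1: instead of replacing the whole observed state, the reconciler reports only the entries it touched, so changes made inline by other code paths survive. The types `Upsert` and `Delete`, the string-valued configuration, and `apply_deltas` are hypothetical names for illustration and do not mirror the storage controller's Rust types.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass(frozen=True)
class Upsert:
    node_id: int
    conf: Optional[str]  # None stands for "state unknown", e.g. node went offline


@dataclass(frozen=True)
class Delete:
    node_id: int


Delta = Union[Upsert, Delete]


def apply_deltas(observed: Dict[int, Optional[str]], deltas: List[Delta]) -> None:
    """Apply only the reported changes; untouched entries are left alone."""
    for delta in deltas:
        if isinstance(delta, Upsert):
            observed[delta.node_id] = delta.conf
        else:
            observed.pop(delta.node_id, None)
```

With a whole-state replacement, an inline update to node 2 made while the reconciler was working on node 1 would be overwritten; with deltas, only node 1's entry changes.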
Folke Behrens
54d1185789 proxy: Unalias hyper1 and replace one use of hyper0 in test (#9324)
Leaves one final use of hyper0 in proxy for the health service,
which requires some coordinated effort with other services.
2024-10-09 12:44:17 +02:00
Heikki Linnakangas
8a138db8b7 tests: Reduce noise from logging renamed files (#9315)
Instead of printing the full absolute path for every file, print just
the filenames.

Before:

    2024-10-08 13:19:39.98 INFO [test_pageserver_generations.py:669] Found file /home/heikki/git-sandbox/neon/test_output/test_upgrade_generationless_local_file_paths[debug-pg16]/repo/pageserver_1/tenants/0c04a8df7691a367ad0bb1cc1373ba4d/timelines/f41022551e5f96ce8dbefb9b5d35ab45/000000067F0000000100000A8D0100000000-000000067F0000000100000AC10000000002__00000000014F16F0-v1-00000001
    2024-10-08 13:19:39.99 INFO [test_pageserver_generations.py:673] Renamed /home/heikki/git-sandbox/neon/test_output/test_upgrade_generationless_local_file_paths[debug-pg16]/repo/pageserver_1/tenants/0c04a8df7691a367ad0bb1cc1373ba4d/timelines/f41022551e5f96ce8dbefb9b5d35ab45/000000067F0000000100000A8D0100000000-000000067F0000000100000AC10000000002__00000000014F16F0-v1-00000001 -> /home/heikki/git-sandbox/neon/test_output/test_upgrade_generationless_local_file_paths[debug-pg16]/repo/pageserver_1/tenants/0c04a8df7691a367ad0bb1cc1373ba4d/timelines/f41022551e5f96ce8dbefb9b5d35ab45/000000067F0000000100000A8D0100000000-000000067F0000000100000AC10000000002__00000000014F16F0

After:

    2024-10-08 13:24:39.726 INFO [test_pageserver_generations.py:667] Renaming files in /home/heikki/git-sandbox/neon/test_output/test_upgrade_generationless_local_file_paths[debug-pg16]/repo/pageserver_1/tenants/3439538816c520adecc541cc8b1de21c/timelines/6a7be8ee707b355de48dd91b326d6ae1
    2024-10-08 13:24:39.728 INFO [test_pageserver_generations.py:673] Renamed 000000067F0000000100000A8D0100000000-000000067F0000000100000AC10000000002__00000000014F16F0-v1-00000001 -> 000000067F0000000100000A8D0100000000-000000067F0000000100000AC10000000002__00000000014F16F0
2024-10-09 10:55:56 +01:00
Erik Grinaker
211970f0e0 remote_storage: add DownloadOpts::byte_(start|end) (#9293)
`download_byte_range()` is basically a copy of `download()` with an
additional option passed to the backend SDKs. This can cause these code
paths to diverge, and prevents combining various options.

This patch adds `DownloadOpts::byte_(start|end)` and moves byte range
handling into `download()`.
2024-10-09 10:29:06 +01:00
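A sketch of why folding the byte range into the options struct (211970f0e0) keeps a single download path: the range becomes one more option that the shared code translates for the backend. The Python `DownloadOpts` below is a stand-in, and treating `byte_end` as exclusive is an assumption; only the inclusive form of the HTTP `Range` header is standard.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DownloadOpts:
    byte_start: int = 0
    byte_end: Optional[int] = None  # exclusive end (assumption); None = to the end


def range_header(opts: DownloadOpts) -> Optional[str]:
    """Translate the options into an HTTP Range header value, if one is needed."""
    if opts.byte_start == 0 and opts.byte_end is None:
        return None  # whole object, no Range header
    if opts.byte_end is None:
        return f"bytes={opts.byte_start}-"
    if opts.byte_end <= opts.byte_start:
        raise ValueError("byte_end must be greater than byte_start")
    # HTTP ranges are inclusive on both ends, so subtract one from the end.
    return f"bytes={opts.byte_start}-{opts.byte_end - 1}"
```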
Heikki Linnakangas
f87f5a383e tests: Remove redundant log lines when starting an endpoint (#9316)
The "Starting postgres endpoint <name>" message is not needed, because
the neon_cli.py prints the neon_local command line used to start the
endpoint. That contains the same information. The "Postgres startup took
XX seconds" message is not very useful because no one pays attention to
those in the python test logs when things are going smoothly, and if you
do wonder about the startup speed, the same information and more can be
found in the compute log.

Before:

    2024-10-07 22:32:27.794 INFO [neon_fixtures.py:3492] Starting postgres endpoint ep-1
    2024-10-07 22:32:27.794 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local endpoint start --safekeepers 1 ep-1"
    2024-10-07 22:32:27.901 INFO [neon_fixtures.py:3690] Postgres startup took 0.11398935317993164 seconds

After:

    2024-10-07 22:32:27.794 INFO [neon_cli.py:73] Running command "/tmp/neon/bin/neon_local endpoint start --safekeepers 1 ep-1"
2024-10-09 09:58:50 +01:00
Arpad Müller
e8ae37652b Add timeline offload mechanism (#8907)
Implements an initial mechanism for offloading of archived timelines.

Offloading is implemented as specified in the RFC.

For now, there is no persistence, so a restart of the pageserver will
retrigger downloads until the timeline is offloaded again.

We trigger offloading in the compaction loop because we need the signal
for whether compaction is done and everything has been uploaded or not.

Part of #8088
2024-10-09 01:33:39 +02:00
Tristan Partin
5bd8e2363a Enable all pyupgrade checks in ruff
This will help to keep us from using deprecated Python features going
forward.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-08 14:32:26 -05:00
Vlad Lazar
618680c299 storcon: apply all node status changes before handling transitions (#9281)
## Problem

When a node goes offline, we trigger reconciles to migrate shards away
from it. If multiple nodes go offline at the same time, we handled them in
sequence. Hence, we might migrate shards from the first offline node to the second
offline node and increase the unavailability period.

## Summary of changes

Refactor heartbeat delta handling to:
1. Update in memory state for all nodes first
2. Handle availability transitions one by one (we have the full picture for each node after (1))

Closes https://github.com/neondatabase/neon/issues/9126
2024-10-08 17:55:25 +01:00
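The two-phase handling described in 618680c299 can be sketched as follows. This is a simplified model, assuming node availability is a plain string and that "handling a transition" means picking migration targets; `handle_heartbeat_deltas` and its return shape are hypothetical.

```python
from typing import Dict, List, Tuple


def handle_heartbeat_deltas(
    nodes: Dict[int, str], deltas: Dict[int, str]
) -> List[Tuple[int, List[int]]]:
    """Return (offline_node, candidate_targets) pairs for nodes that went offline."""
    # Phase 1: record every availability change before acting on any of them.
    transitions = []
    for node_id, new_state in deltas.items():
        old_state = nodes[node_id]
        nodes[node_id] = new_state
        if old_state != new_state:
            transitions.append((node_id, new_state))

    # Phase 2: handle transitions one by one against the fully updated view.
    migrations = []
    for node_id, new_state in transitions:
        if new_state == "offline":
            targets = sorted(n for n, state in nodes.items() if state == "active")
            migrations.append((node_id, targets))
    return migrations
```

If nodes 1 and 2 go offline in the same heartbeat round, node 2 is never offered as a target for node 1's shards, because its state was already updated in phase 1.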
Alexander Bayandin
baf27ba6a3 Fix compiler warnings on macOS (#9319)
## Problem

On macOS:
```
/Users/runner/work/neon/neon//pgxn/neon/file_cache.c:623:19: error: variable 'has_remaining_pages' is used uninitialized whenever 'for' loop exits because its condition is false [-Werror,-Wsometimes-uninitialized]
```

## Summary of changes
- Initialise `has_remaining_pages` with `false`
2024-10-08 17:34:35 +01:00
Tristan Partin
16417d919d Remove get_self_dir()
It didn't serve much value, and was only used twice.
Path(__file__).parent is a pretty easy invocation to use.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-08 08:57:11 -05:00
Heikki Linnakangas
18b97150b2 Remove non-existent entries from .dockerignore (#9209)
2024-10-08 14:55:24 +03:00
Heikki Linnakangas
17c59ed786 Don't override CFLAGS when building neon extension
If you override CFLAGS, you also override any flags that PostgreSQL
configure script had picked. That includes many options that enable
extra compiler warnings, like '-Wall', '-Wmissing-prototypes', and so
forth. The override was added in commit 171385ac14, but the intention
of that was to be *more* strict, by enabling '-Werror', not less
strict. The proper way of setting '-Werror', as documented in the docs
and mentioned in PR #2405, is to set COPT='-Werror', but leave CFLAGS
alone.

All the compiler warnings with the standard PostgreSQL flags have now
been fixed, so we can do this without adding noise.

Part of the cleanup issue #9217.
2024-10-07 23:49:33 +03:00
Heikki Linnakangas
d7b960c9b5 Silence compiler warning about using variable uninitialized
It's not a bug, the variable is initialized when it's used, but the
compiler isn't smart enough to see that through all the conditions.

Part of the cleanup issue #9217.
2024-10-07 23:49:31 +03:00
Heikki Linnakangas
2ff6d2b6b5 Silence compiler warning about variable only used in assertions
Part of the cleanup issue #9217.
2024-10-07 23:49:29 +03:00
Heikki Linnakangas
30f7fbc88d Add pg_attribute_printf to WalProposerLibLog, per gcc's suggestion
/pgxn/neon/walproposer_compat.c:192:9: warning: function ‘WalProposerLibLog’ might be a candidate for ‘gnu_printf’ format attribute [-Wsuggest-attribute=format]
      192 |         vsnprintf(buf, sizeof(buf), fmt, args);
          |         ^~~~~~~~~
2024-10-07 23:49:27 +03:00
Heikki Linnakangas
09f2000f91 Silence warnings about shadowed local variables
Part of the cleanup issue #9217.
2024-10-07 23:49:24 +03:00
Heikki Linnakangas
e553ca9e4f Silence warnings about mixed declarations and code
The warning:

    warning: ISO C90 forbids mixed declarations and code [-Wdeclaration-after-statement]

It's PostgreSQL project style to stick to the old C90 style.
(Alternatively, we could disable it for our extension.)

Part of the cleanup issue #9217.
2024-10-07 23:49:22 +03:00
Heikki Linnakangas
0a80dbce83 neon_write() function is not used on v17
ifdef it out on v17, to silence compiler warning.

Part of the cleanup issue #9217.
2024-10-07 23:49:20 +03:00
Heikki Linnakangas
e763256448 Fix warnings about missing function prototypes
Prototypes for neon_writev(), neon_readv(), and neon_regisersync()
were missing. But instead of adding the missing prototypes, mark all
the smgr functions 'static'.

Part of the cleanup issue #9217.
2024-10-07 23:49:18 +03:00
Heikki Linnakangas
129d4480bb Move "/* fallthrough */" comments so that GCC recognizes them
This silences warnings about implicit fallthroughs.

Part of the cleanup issue #9217.
2024-10-07 23:49:16 +03:00
Heikki Linnakangas
776df963ba Fix function prototypes
Silences these compiler warnings:

    /pgxn/neon_walredo/walredoproc.c:452:1: warning: ‘CreateFakeSharedMemoryAndSemaphores’ was used with no prototype before its definition [-Wmissing-prototypes]
      452 | CreateFakeSharedMemoryAndSemaphores()
          | ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    /pgxn/neon/walproposer_pg.c:541:1: warning: no previous prototype for ‘GetWalpropShmemState’ [-Wmissing-prototypes]
      541 | GetWalpropShmemState()
          | ^~~~~~~~~~~~~~~~~~~~

Part of the cleanup issue #9217.
2024-10-07 23:49:13 +03:00
Heikki Linnakangas
11dc5feb36 Remove unused static function
In v16 merge, we copied much of heap RMGR, to distinguish vanilla
Postgres heap records from records generated with neon patches, with
the additional CID fields. This function is only used by the
HEAP_TRUNCATE records, however, which we didn't need to copy.

Part of the cleanup issue #9217.
2024-10-07 23:49:11 +03:00
Heikki Linnakangas
dbbe57a837 Remove unused local vars and a prototype for non-existent function
Per compiler warnings. Part of the cleanup issue #9217.
2024-10-07 23:49:09 +03:00
Em Sharnoff
cc29def544 vm-monitor: Ignore LFC in postgres cgroup memory threshold (#8668)
In short: Currently we reserve 75% of memory to the LFC, meaning that
we scale up to keep postgres using less than 25% of the compute's
memory.

This means that for certain memory-heavy workloads, we end up scaling
much higher than is actually needed — in the worst case, up to 4x,
although in practice it tends not to be quite so bad.

Part of neondatabase/autoscaling#1030.
2024-10-07 21:25:34 +01:00
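The "up to 4x" figure in cc29def544 follows from simple arithmetic: if postgres may use only the fraction of memory not reserved for the LFC, the compute must be sized to the postgres usage divided by that fraction. The function below is a simplified model with a hypothetical name, not the vm-monitor's actual threshold logic.

```python
def required_compute_memory(postgres_usage_gib: float, lfc_fraction: float) -> float:
    """Memory the compute needs so postgres stays within its allowed share."""
    if not 0.0 <= lfc_fraction < 1.0:
        raise ValueError("lfc_fraction must be in [0, 1)")
    return postgres_usage_gib / (1.0 - lfc_fraction)


# With 75% reserved for the LFC, 8 GiB of postgres usage needs a 32 GiB compute.
with_lfc_reserved = required_compute_memory(8.0, 0.75)

# Ignoring the LFC in the threshold, the same usage fits in 8 GiB in this model.
lfc_ignored = required_compute_memory(8.0, 0.0)
```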
Arpad Müller
912d47ec02 storage_broker: update hyper and tonic again (#9299)
Update hyper and tonic again in the storage broker, this time with a fix
for the issue that made us revert the update last time.

The first commit is a revert of #9268, the second a fix for the issue.

fixes #9231.
2024-10-07 21:12:13 +02:00
Tristan Partin
6eba29c732 Improve logging on changes in a compute's status
I'm trying to debug a situation with the LR benchmark publisher not
being in the correct state. This should aid in debugging, while just
being generally useful.

PR: https://github.com/neondatabase/neon/pull/9265
Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-07 13:19:48 -04:00
Heikki Linnakangas
99d4c1877b Replace BUFFERTAGS_EQUAL compatibility macro with new-style function (#9294)
In PostgreSQL v16, the BUFFERTAGS_EQUAL macro was replaced with a static
inline function, BufferTagsEqual. Let's use the new name going forward, and have
backwards-compatibility glue to allow using the new name on v14 and v15,
rather than the other way round. This also makes BufferTagsEqual
consistent with InitBufferTag, for which we were already using the new
name.
2024-10-07 19:49:27 +03:00
Jere Vaara
2272dc8a48 feat(compute_tools): Create JWKS Postgres roles without attributes (#9031)
Requires https://github.com/neondatabase/neon/pull/9086 first to have
`local_proxy_config`. This logic can still be reviewed
implementation-wise.

Create JWT Auth functionality related roles without attributes and
`neon_superuser` group.

Read the JWT related roles from `local_proxy_config` `JWKS` settings and
handle them differently than other console created roles.
2024-10-07 19:37:32 +03:00
301 changed files with 3995 additions and 2809 deletions

View File

@@ -5,9 +5,7 @@
!Cargo.toml
!Makefile
!rust-toolchain.toml
!scripts/combine_control_files.py
!scripts/ninstall.sh
!vm-cgconfig.conf
!docker-compose/run-tests.sh
# Directories
@@ -17,15 +15,12 @@
!compute_tools/
!control_plane/
!libs/
!neon_local/
!pageserver/
!patches/
!pgxn/
!proxy/
!storage_scrubber/
!safekeeper/
!storage_broker/
!storage_controller/
!trace/
!vendor/postgres-*/
!workspace_hack/

View File

@@ -0,0 +1,41 @@
name: Report Workflow Stats
on:
  workflow_run:
    workflows:
      - Add `external` label to issues and PRs created by external users
      - Benchmarking
      - Build and Test
      - Build and Test Locally
      - Build build-tools image
      - Check Permissions
      - Check build-tools image
      - Check neon with extra platform builds
      - Cloud Regression Test
      - Create Release Branch
      - Handle `approved-for-ci-run` label
      - Lint GitHub Workflows
      - Notify Slack channel about upcoming release
      - Periodic pagebench performance test on dedicated EC2 machine in eu-central-1 region
      - Pin build-tools image
      - Prepare benchmarking databases by restoring dumps
      - Push images to ACR
      - Test Postgres client libraries
      - Trigger E2E Tests
      - cleanup caches by a branch
    types: [completed]
jobs:
  gh-workflow-stats:
    name: Github Workflow Stats
    runs-on: ubuntu-22.04
    permissions:
      actions: read
    steps:
      - name: Export GH Workflow Stats
        uses: fedordikarev/gh-workflow-stats-action@v0.1.2
        with:
          DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }}
          DB_TABLE: "gh_workflow_stats_neon"
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_RUN_ID: ${{ github.event.workflow_run.id }}

264
Cargo.lock generated
View File

@@ -666,34 +666,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [
"async-trait",
"axum-core 0.3.4",
"bitflags 1.3.2",
"bytes",
"futures-util",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.30",
"itoa",
"matchit 0.7.0",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper 0.1.2",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum"
version = "0.7.5"
@@ -701,7 +673,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
dependencies = [
"async-trait",
"axum-core 0.4.5",
"axum-core",
"base64 0.21.1",
"bytes",
"futures-util",
@@ -731,23 +703,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-core"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 0.2.9",
"http-body 0.4.5",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.4.5"
@@ -971,7 +926,7 @@ dependencies = [
"clang-sys",
"itertools 0.12.1",
"log",
"prettyplease 0.2.17",
"prettyplease",
"proc-macro2",
"quote",
"regex",
@@ -2454,15 +2409,6 @@ dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "hostname"
version = "0.4.0"
@@ -2657,14 +2603,15 @@ dependencies = [
[[package]]
name = "hyper-timeout"
version = "0.4.1"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793"
dependencies = [
"hyper 0.14.30",
"hyper 1.4.1",
"hyper-util",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
"tower-service",
]
[[package]]
@@ -3470,7 +3417,7 @@ dependencies = [
"opentelemetry-http",
"opentelemetry-proto",
"opentelemetry_sdk",
"prost 0.13.3",
"prost",
"reqwest 0.12.4",
"thiserror",
]
@@ -3483,8 +3430,8 @@ checksum = "30ee9f20bff9c984511a02f082dc8ede839e4a9bf15cc2487c8d6fea5ad850d9"
dependencies = [
"opentelemetry",
"opentelemetry_sdk",
"prost 0.13.3",
"tonic 0.12.3",
"prost",
"tonic",
]
[[package]]
@@ -4178,16 +4125,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "prettyplease"
version = "0.1.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86"
dependencies = [
"proc-macro2",
"syn 1.0.109",
]
[[package]]
name = "prettyplease"
version = "0.2.17"
@@ -4258,16 +4195,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "prost"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
dependencies = [
"bytes",
"prost-derive 0.11.9",
]
[[package]]
name = "prost"
version = "0.13.3"
@@ -4275,42 +4202,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f"
dependencies = [
"bytes",
"prost-derive 0.13.3",
"prost-derive",
]
[[package]]
name = "prost-build"
version = "0.11.9"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270"
checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck 0.4.1",
"itertools 0.10.5",
"lazy_static",
"heck 0.5.0",
"itertools 0.12.1",
"log",
"multimap",
"once_cell",
"petgraph",
"prettyplease 0.1.25",
"prost 0.11.9",
"prettyplease",
"prost",
"prost-types",
"regex",
"syn 1.0.109",
"syn 2.0.52",
"tempfile",
"which",
]
[[package]]
name = "prost-derive"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
dependencies = [
"anyhow",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
@@ -4328,11 +4241,11 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.11.9"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670"
dependencies = [
"prost 0.11.9",
"prost",
]
[[package]]
@@ -5094,6 +5007,21 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebbbdb961df0ad3f2652da8f3fdc4b36122f568f968f45ad3316f26c025c677b"
dependencies = [
"log",
"once_cell",
"ring",
"rustls-pki-types",
"rustls-webpki 0.102.2",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.6.2"
@@ -5119,6 +5047,19 @@ dependencies = [
"security-framework",
]
[[package]]
name = "rustls-native-certs"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a"
dependencies = [
"openssl-probe",
"rustls-pemfile 2.1.1",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.2"
@@ -5194,6 +5135,7 @@ dependencies = [
"fail",
"futures",
"hex",
"http 1.1.0",
"humantime",
"hyper 0.14.30",
"metrics",
@@ -5750,19 +5692,22 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-stream",
"bytes",
"clap",
"const_format",
"futures",
"futures-core",
"futures-util",
"http-body-util",
"humantime",
"hyper 0.14.30",
"hyper 1.4.1",
"hyper-util",
"metrics",
"once_cell",
"parking_lot 0.12.1",
"prost 0.11.9",
"prost",
"tokio",
"tonic 0.9.2",
"tonic",
"tonic-build",
"tracing",
"utils",
@@ -6306,6 +6251,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
"rustls 0.23.7",
"rustls-pki-types",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.16"
@@ -6397,29 +6353,30 @@ dependencies = [
[[package]]
name = "tonic"
version = "0.9.2"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a"
checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52"
dependencies = [
"async-stream",
"async-trait",
"axum 0.6.20",
"base64 0.21.1",
"axum",
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"h2 0.3.26",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.30",
"h2 0.4.4",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.4.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"prost 0.11.9",
"rustls-native-certs 0.6.2",
"rustls-pemfile 1.0.2",
"prost",
"rustls-native-certs 0.8.0",
"rustls-pemfile 2.1.1",
"socket2",
"tokio",
"tokio-rustls 0.24.0",
"tokio-rustls 0.26.0",
"tokio-stream",
"tower",
"tower-layer",
@@ -6428,37 +6385,17 @@ dependencies = [
]
[[package]]
name = "tonic"
name = "tonic-build"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52"
checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11"
dependencies = [
"async-trait",
"base64 0.22.1",
"bytes",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"percent-encoding",
"pin-project",
"prost 0.13.3",
"tokio-stream",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tonic-build"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6fdaae4c2c638bb70fe42803a26fbd6fc6ac8c72f5c59f67ecc2a2dcabf4b07"
dependencies = [
"prettyplease 0.1.25",
"prettyplease",
"proc-macro2",
"prost-build",
"prost-types",
"quote",
"syn 1.0.109",
"syn 2.0.52",
]
[[package]]
@@ -6864,7 +6801,7 @@ name = "vm_monitor"
version = "0.1.0"
dependencies = [
"anyhow",
"axum 0.7.5",
"axum",
"cgroups-rs",
"clap",
"futures",
@@ -7095,18 +7032,6 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "whoami"
version = "1.5.1"
@@ -7335,9 +7260,10 @@ version = "0.1.0"
dependencies = [
"ahash",
"anyhow",
"axum",
"axum-core",
"base64 0.21.1",
"base64ct",
"bitflags 2.4.1",
"bytes",
"camino",
"cc",
@@ -7365,7 +7291,6 @@ dependencies = [
"hyper 1.4.1",
"hyper-util",
"indexmap 1.9.3",
"itertools 0.10.5",
"itertools 0.12.1",
"lazy_static",
"libc",
@@ -7377,15 +7302,15 @@ dependencies = [
"num-traits",
"once_cell",
"parquet",
"prettyplease",
"proc-macro2",
"prost 0.11.9",
"prost",
"quote",
"rand 0.8.5",
"regex",
"regex-automata 0.4.3",
"regex-syntax 0.8.2",
"reqwest 0.12.4",
"rustls 0.21.11",
"scopeguard",
"serde",
"serde_json",
@@ -7401,9 +7326,10 @@ dependencies = [
"time",
"time-macros",
"tokio",
"tokio-rustls 0.24.0",
"tokio-stream",
"tokio-util",
"toml_edit",
"tonic",
"tower",
"tracing",
"tracing-core",

View File

@@ -130,7 +130,7 @@ pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pin-project-lite = "0.2"
procfs = "0.16"
prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"
prost = "0.13"
rand = "0.8"
redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
@@ -178,7 +178,7 @@ tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
toml = "0.8"
toml_edit = "0.22"
tonic = {version = "0.9", features = ["tls", "tls-roots"]}
tonic = {version = "0.12.3", features = ["tls", "tls-roots"]}
tower-service = "0.3.2"
tracing = "0.1"
tracing-error = "0.2"
@@ -246,7 +246,7 @@ criterion = "0.5.1"
rcgen = "0.12"
rstest = "0.18"
camino-tempfile = "1.0.2"
tonic-build = "0.9"
tonic-build = "0.12"
[patch.crates-io]

View File

@@ -168,27 +168,27 @@ postgres-check-%: postgres-%
neon-pg-ext-%: postgres-%
+@echo "Compiling neon $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
+@echo "Compiling neon_walredo $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_walredo/Makefile install
+@echo "Compiling neon_rmgr $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-rmgr-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-rmgr-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_rmgr/Makefile install
+@echo "Compiling neon_test_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install
+@echo "Compiling neon_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install
@@ -220,7 +220,7 @@ neon-pg-clean-ext-%:
walproposer-lib: neon-pg-ext-v17
+@echo "Compiling walproposer-lib"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config COPT='$(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile walproposer-lib
cp $(POSTGRES_INSTALL_DIR)/v17/lib/libpgport.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
@@ -333,7 +333,7 @@ postgres-%-pgindent: postgres-%-pg-bsd-indent postgres-%-typedefs.list
# Indent pgxn/neon.
.PHONY: neon-pgindent
neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config COPT='$(COPT)' \
FIND_TYPEDEF=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/find_typedef \
INDENT=$(POSTGRES_INSTALL_DIR)/build/v17/src/tools/pg_bsd_indent/pg_bsd_indent \
PGINDENT_SCRIPT=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/pgindent/pgindent \

View File

@@ -402,8 +402,7 @@ fn start_postgres(
) -> Result<(Option<PostgresHandle>, StartPostgresResult)> {
// We got all we need, update the state.
let mut state = compute.state.lock().unwrap();
state.status = ComputeStatus::Init;
compute.state_changed.notify_all();
state.set_status(ComputeStatus::Init, &compute.state_changed);
info!(
"running compute with features: {:?}",

View File

@@ -109,6 +109,18 @@ impl ComputeState {
metrics: ComputeMetrics::default(),
}
}
pub fn set_status(&mut self, status: ComputeStatus, state_changed: &Condvar) {
let prev = self.status;
info!("Changing compute status from {} to {}", prev, status);
self.status = status;
state_changed.notify_all();
}
pub fn set_failed_status(&mut self, err: anyhow::Error, state_changed: &Condvar) {
self.error = Some(format!("{err:?}"));
self.set_status(ComputeStatus::Failed, state_changed);
}
}
impl Default for ComputeState {
@@ -303,15 +315,12 @@ impl ComputeNode {
pub fn set_status(&self, status: ComputeStatus) {
let mut state = self.state.lock().unwrap();
state.status = status;
self.state_changed.notify_all();
state.set_status(status, &self.state_changed);
}
pub fn set_failed_status(&self, err: anyhow::Error) {
let mut state = self.state.lock().unwrap();
state.error = Some(format!("{err:?}"));
state.status = ComputeStatus::Failed;
self.state_changed.notify_all();
state.set_failed_status(err, &self.state_changed);
}
pub fn get_status(&self) -> ComputeStatus {
@@ -1475,6 +1484,28 @@ LIMIT 100",
info!("Pageserver config changed");
}
}
// Gather info about installed extensions
pub fn get_installed_extensions(&self) -> Result<()> {
let connstr = self.connstr.clone();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to create runtime");
let result = rt
.block_on(crate::installed_extensions::get_installed_extensions(
connstr,
))
.expect("failed to get installed extensions");
info!(
"{}",
serde_json::to_string(&result).expect("failed to serialize extensions list")
);
Ok(())
}
}
pub fn forward_termination_signal() {
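
The `set_status` helper introduced above folds three steps (log the transition, store the new status, wake the waiters) into one call made while the state lock is held. A minimal std-only sketch of that pattern follows; `Status`, `State`, `Node`, `wait_for` and `demo` are hypothetical names for illustration and not the types in the diff, and `println!` with `{:?}` stands in for `info!` with `{}`.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

// Hypothetical two-state status; the real ComputeStatus has more variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Status {
    Init,
    Running,
}

struct State {
    status: Status,
}

impl State {
    // One place that logs the transition, stores the new status and wakes waiters.
    // Callers must already hold the mutex that guards `self`.
    fn set_status(&mut self, status: Status, state_changed: &Condvar) {
        let prev = self.status;
        println!("Changing status from {:?} to {:?}", prev, status);
        self.status = status;
        state_changed.notify_all();
    }
}

struct Node {
    state: Mutex<State>,
    state_changed: Condvar,
}

// Blocks until the node reaches `want`, re-checking the status after every wakeup.
fn wait_for(node: &Node, want: Status) -> Status {
    let mut state = node.state.lock().unwrap();
    while state.status != want {
        state = node.state_changed.wait(state).unwrap();
    }
    state.status
}

// Spawns a waiter, performs one transition, and returns what the waiter observed.
fn demo() -> Status {
    let node = Arc::new(Node {
        state: Mutex::new(State { status: Status::Init }),
        state_changed: Condvar::new(),
    });
    let waiter = {
        let node = Arc::clone(&node);
        thread::spawn(move || wait_for(&node, Status::Running))
    };
    node.state
        .lock()
        .unwrap()
        .set_status(Status::Running, &node.state_changed);
    waiter.join().unwrap()
}
```

Because the waiter re-checks the status under the lock, it does not matter whether the notification arrives before or after it starts waiting.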

View File

@@ -24,8 +24,7 @@ fn configurator_main_loop(compute: &Arc<ComputeNode>) {
// Re-check the status after waking up
if state.status == ComputeStatus::ConfigurationPending {
info!("got configuration request");
state.status = ComputeStatus::Configuration;
compute.state_changed.notify_all();
state.set_status(ComputeStatus::Configuration, &compute.state_changed);
drop(state);
let mut new_status = ComputeStatus::Failed;

View File

@@ -165,6 +165,32 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
// get the list of installed extensions
// currently only used in python tests
// TODO: call it from cplane
(&Method::GET, "/installed_extensions") => {
info!("serving /installed_extensions GET request");
let status = compute.get_status();
if status != ComputeStatus::Running {
let msg = format!(
"invalid compute status for extensions request: {:?}",
status
);
error!(msg);
return Response::new(Body::from(msg));
}
let connstr = compute.connstr.clone();
let res = crate::installed_extensions::get_installed_extensions(connstr).await;
match res {
Ok(res) => render_json(Body::from(serde_json::to_string(&res).unwrap())),
Err(e) => render_json_error(
&format!("could not get list of installed extensions: {}", e),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
// download extension files from remote extension storage on demand
(&Method::POST, route) if route.starts_with("/extension_server/") => {
info!("serving {:?} POST request", route);
@@ -288,8 +314,7 @@ async fn handle_configure_request(
return Err((msg, StatusCode::PRECONDITION_FAILED));
}
state.pspec = Some(parsed_spec);
state.status = ComputeStatus::ConfigurationPending;
compute.state_changed.notify_all();
state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed);
drop(state);
info!("set new spec and notified waiters");
}
@@ -362,15 +387,15 @@ async fn handle_terminate_request(compute: &Arc<ComputeNode>) -> Result<(), (Str
}
if state.status != ComputeStatus::Empty && state.status != ComputeStatus::Running {
let msg = format!(
"invalid compute status for termination request: {:?}",
state.status.clone()
"invalid compute status for termination request: {}",
state.status
);
return Err((msg, StatusCode::PRECONDITION_FAILED));
}
state.status = ComputeStatus::TerminationPending;
compute.state_changed.notify_all();
state.set_status(ComputeStatus::TerminationPending, &compute.state_changed);
drop(state);
}
forward_termination_signal();
info!("sent signal and notified waiters");
@@ -384,7 +409,8 @@ async fn handle_terminate_request(compute: &Arc<ComputeNode>) -> Result<(), (Str
while state.status != ComputeStatus::Terminated {
state = c.state_changed.wait(state).unwrap();
info!(
"waiting for compute to become Terminated, current status: {:?}",
"waiting for compute to become {}, current status: {:?}",
ComputeStatus::Terminated,
state.status
);
}

View File

@@ -53,6 +53,20 @@ paths:
schema:
$ref: "#/components/schemas/ComputeInsights"
/installed_extensions:
get:
tags:
- Info
summary: Get installed extensions.
description: ""
operationId: getInstalledExtensions
responses:
200:
description: List of installed extensions
content:
application/json:
schema:
$ref: "#/components/schemas/InstalledExtensions"
/info:
get:
tags:
@@ -395,6 +409,24 @@ components:
- configuration
example: running
InstalledExtensions:
type: object
properties:
extensions:
description: Contains list of installed extensions.
type: array
items:
type: object
properties:
extname:
type: string
versions:
type: array
items:
type: string
n_databases:
type: integer
#
# Errors
#

View File

@@ -0,0 +1,80 @@
use compute_api::responses::{InstalledExtension, InstalledExtensions};
use std::collections::HashMap;
use std::collections::HashSet;
use url::Url;
use anyhow::Result;
use postgres::{Client, NoTls};
use tokio::task;
/// We don't reuse get_existing_dbs() just for code clarity
/// and to make database listing query here more explicit.
///
/// Limit the number of databases to 500 to avoid excessive load.
fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
// `pg_database.datconnlimit = -2` means that the database is in the
// invalid state
let databases = client
.query(
"SELECT datname FROM pg_catalog.pg_database
WHERE datallowconn
AND datconnlimit <> - 2
LIMIT 500",
&[],
)?
.iter()
.map(|row| {
let db: String = row.get("datname");
db
})
.collect();
Ok(databases)
}
/// Connect to every database (see list_dbs above) and get the list of installed extensions.
/// The same extension can be installed in multiple databases with different versions;
/// we collect every distinct version seen across all databases.
pub async fn get_installed_extensions(connstr: Url) -> Result<InstalledExtensions> {
let mut connstr = connstr.clone();
task::spawn_blocking(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
let databases: Vec<String> = list_dbs(&mut client)?;
let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
for db in databases.iter() {
connstr.set_path(db);
let mut db_client = Client::connect(connstr.as_str(), NoTls)?;
let extensions: Vec<(String, String)> = db_client
.query(
"SELECT extname, extversion FROM pg_catalog.pg_extension;",
&[],
)?
.iter()
.map(|row| (row.get("extname"), row.get("extversion")))
.collect();
for (extname, v) in extensions.iter() {
let version = v.to_string();
extensions_map
.entry(extname.to_string())
.and_modify(|e| {
e.versions.insert(version.clone());
// count the number of databases where the extension is installed
e.n_databases += 1;
})
.or_insert(InstalledExtension {
extname: extname.to_string(),
versions: HashSet::from([version.clone()]),
n_databases: 1,
});
}
}
Ok(InstalledExtensions {
extensions: extensions_map.values().cloned().collect(),
})
})
.await?
}
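
The aggregation loop above can be exercised without a database. The sketch below assumes the per-database query results are already in memory; `aggregate` and its `per_db` parameter are hypothetical names, and the struct is a local copy of the shape used in the diff rather than the `compute_api` type itself.

```rust
use std::collections::{HashMap, HashSet};

// Local stand-in with the same fields as the struct in the diff.
#[derive(Clone, Debug)]
struct InstalledExtension {
    extname: String,
    versions: HashSet<String>,
    n_databases: u32,
}

// `per_db` holds, for each database, the (extname, extversion) rows it reported.
fn aggregate(per_db: &[Vec<(&str, &str)>]) -> HashMap<String, InstalledExtension> {
    let mut map: HashMap<String, InstalledExtension> = HashMap::new();
    for rows in per_db {
        for &(extname, version) in rows {
            map.entry(extname.to_string())
                .and_modify(|e| {
                    // Seen before: remember this version and count one more database.
                    e.versions.insert(version.to_string());
                    e.n_databases += 1;
                })
                .or_insert_with(|| InstalledExtension {
                    extname: extname.to_string(),
                    versions: HashSet::from([version.to_string()]),
                    n_databases: 1,
                });
        }
    }
    map
}
```

Every distinct version string is kept in the set, so an extension installed at two versions in two databases ends up with `versions.len() == 2` and `n_databases == 2`.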

View File

@@ -15,6 +15,7 @@ pub mod catalog;
pub mod compute;
pub mod disk_quota;
pub mod extension_server;
pub mod installed_extensions;
pub mod local_proxy;
pub mod lsn_lease;
mod migration;

View File

@@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::fs::File;
use std::path::Path;
use std::str::FromStr;
@@ -189,6 +190,15 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
let mut xact = client.transaction()?;
let existing_roles: Vec<Role> = get_existing_roles(&mut xact)?;
let mut jwks_roles = HashSet::new();
if let Some(local_proxy) = &spec.local_proxy_config {
for jwks_setting in local_proxy.jwks.iter().flatten() {
for role_name in &jwks_setting.role_names {
jwks_roles.insert(role_name.clone());
}
}
}
// Print a list of existing Postgres roles (only in debug mode)
if span_enabled!(Level::INFO) {
let mut vec = Vec::new();
@@ -308,6 +318,9 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser",
name.pg_quote()
);
if jwks_roles.contains(name.as_str()) {
query = format!("CREATE ROLE {}", name.pg_quote());
}
info!("running role create query: '{}'", &query);
query.push_str(&role.to_pg_options());
xact.execute(query.as_str(), &[])?;
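
The role handling above can be read as two steps: collect every role name mentioned in the JWKS settings, then emit a bare `CREATE ROLE` for those roles instead of the privileged variant. The following is a rough std-only sketch of that logic. `JwksSettings` and `LocalProxyConfig` are simplified stand-ins for the spec types, and `quote_ident` is a hypothetical replacement for `pg_quote()` that only does basic double-quote escaping.

```rust
use std::collections::HashSet;

// Simplified stand-in for one JWKS entry in the local proxy config.
struct JwksSettings {
    role_names: Vec<String>,
}

// Simplified stand-in for the local proxy config; `jwks` may be absent.
struct LocalProxyConfig {
    jwks: Option<Vec<JwksSettings>>,
}

fn collect_jwks_roles(cfg: Option<&LocalProxyConfig>) -> HashSet<String> {
    let mut roles = HashSet::new();
    if let Some(cfg) = cfg {
        // `Option::iter()` yields zero or one `&Vec`; `flatten()` walks its elements.
        for setting in cfg.jwks.iter().flatten() {
            for role in &setting.role_names {
                roles.insert(role.clone());
            }
        }
    }
    roles
}

// Hypothetical helper: wrap in double quotes and double any embedded quotes.
fn quote_ident(name: &str) -> String {
    format!("\"{}\"", name.replace('"', "\"\""))
}

// JWKS roles get a plain CREATE ROLE; all other roles get the privileged form.
fn create_role_query(name: &str, jwks_roles: &HashSet<String>) -> String {
    if jwks_roles.contains(name) {
        format!("CREATE ROLE {}", quote_ident(name))
    } else {
        format!(
            "CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser",
            quote_ident(name)
        )
    }
}
```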

View File

@@ -1,5 +1,8 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
use std::collections::HashSet;
use std::fmt::Display;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize, Serializer};
@@ -58,6 +61,21 @@ pub enum ComputeStatus {
Terminated,
}
impl Display for ComputeStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ComputeStatus::Empty => f.write_str("empty"),
ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"),
ComputeStatus::Init => f.write_str("init"),
ComputeStatus::Running => f.write_str("running"),
ComputeStatus::Configuration => f.write_str("configuration"),
ComputeStatus::Failed => f.write_str("failed"),
ComputeStatus::TerminationPending => f.write_str("termination-pending"),
ComputeStatus::Terminated => f.write_str("terminated"),
}
}
}
fn rfc3339_serialize<S>(x: &Option<DateTime<Utc>>, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
@@ -138,3 +156,15 @@ pub enum ControlPlaneComputeStatus {
// should be able to start with provided spec.
Attached,
}
#[derive(Clone, Debug, Default, Serialize)]
pub struct InstalledExtension {
pub extname: String,
pub versions: HashSet<String>,
pub n_databases: u32, // Number of databases using this extension
}
#[derive(Clone, Debug, Default, Serialize)]
pub struct InstalledExtensions {
pub extensions: Vec<InstalledExtension>,
}
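
The `Display` impl above is what lets log and error messages switch from `{:?}` to `{}`. A small sketch of the difference, using a local copy of the enum; `termination_error` is a hypothetical helper that only reproduces the message format from the terminate handler.

```rust
use std::fmt;

// Local copy of the status enum, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ComputeStatus {
    Empty,
    ConfigurationPending,
    Init,
    Running,
    Configuration,
    Failed,
    TerminationPending,
    Terminated,
}

impl fmt::Display for ComputeStatus {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Kebab-case names, matching the strings written in the diff.
        let s = match self {
            ComputeStatus::Empty => "empty",
            ComputeStatus::ConfigurationPending => "configuration-pending",
            ComputeStatus::Init => "init",
            ComputeStatus::Running => "running",
            ComputeStatus::Configuration => "configuration",
            ComputeStatus::Failed => "failed",
            ComputeStatus::TerminationPending => "termination-pending",
            ComputeStatus::Terminated => "terminated",
        };
        f.write_str(s)
    }
}

// Hypothetical helper mirroring the error text built in the terminate handler.
fn termination_error(status: ComputeStatus) -> String {
    format!("invalid compute status for termination request: {}", status)
}
```

With this in place `{}` prints `configuration-pending`, while `{:?}` still prints the variant name `ConfigurationPending`.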

View File

@@ -104,8 +104,7 @@ pub struct ConfigToml {
pub image_compression: ImageCompressionAlgorithm,
pub ephemeral_bytes_per_memory_kb: usize,
pub l0_flush: Option<crate::models::L0FlushConfig>,
pub virtual_file_direct_io: crate::models::virtual_file::DirectIoMode,
pub io_buffer_alignment: usize,
pub virtual_file_io_mode: Option<crate::models::virtual_file::IoMode>,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -388,10 +387,7 @@ impl Default for ConfigToml {
image_compression: (DEFAULT_IMAGE_COMPRESSION),
ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),
l0_flush: None,
virtual_file_direct_io: crate::models::virtual_file::DirectIoMode::default(),
io_buffer_alignment: DEFAULT_IO_BUFFER_ALIGNMENT,
virtual_file_io_mode: None,
tenant_config: TenantConfigToml::default(),
}
}

View File

@@ -972,8 +972,6 @@ pub struct TopTenantShardsResponse {
}
pub mod virtual_file {
use std::path::PathBuf;
#[derive(
Copy,
Clone,
@@ -994,50 +992,45 @@ pub mod virtual_file {
}
/// Direct IO modes for a pageserver.
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(tag = "mode", rename_all = "kebab-case", deny_unknown_fields)]
pub enum DirectIoMode {
/// Direct IO disabled (uses usual buffered IO).
#[default]
Disabled,
/// Direct IO disabled (performs checks and perf simulations).
Evaluate {
/// Alignment check level
alignment_check: DirectIoAlignmentCheckLevel,
/// Latency padded for performance simulation.
latency_padding: DirectIoLatencyPadding,
},
/// Direct IO enabled.
Enabled {
/// Actions to perform on alignment error.
on_alignment_error: DirectIoOnAlignmentErrorAction,
},
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Hash,
strum_macros::EnumString,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
Debug,
)]
#[strum(serialize_all = "kebab-case")]
#[repr(u8)]
pub enum IoMode {
/// Uses buffered IO.
Buffered,
/// Uses direct IO, error out if the operation fails.
#[cfg(target_os = "linux")]
Direct,
}
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum DirectIoAlignmentCheckLevel {
#[default]
Error,
Log,
None,
impl IoMode {
pub const fn preferred() -> Self {
Self::Buffered
}
}
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum DirectIoOnAlignmentErrorAction {
Error,
#[default]
FallbackToBuffered,
}
impl TryFrom<u8> for IoMode {
type Error = u8;
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum DirectIoLatencyPadding {
/// Pad virtual file operations with IO to a fake file.
FakeFileRW { path: PathBuf },
#[default]
None,
fn try_from(value: u8) -> Result<Self, Self::Error> {
Ok(match value {
v if v == (IoMode::Buffered as u8) => IoMode::Buffered,
#[cfg(target_os = "linux")]
v if v == (IoMode::Direct as u8) => IoMode::Direct,
x => return Err(x),
})
}
}
}
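
A `#[repr(u8)]` enum with a `TryFrom<u8>` impl, as added above, can be stored in an atomic and read back safely. The sketch below shows one way that might be used; the `IO_MODE` static and the `set_io_mode`/`get_io_mode` functions are assumptions made for illustration and are not taken from the diff.

```rust
use std::convert::TryFrom;
use std::sync::atomic::{AtomicU8, Ordering};

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum IoMode {
    Buffered,
    #[cfg(target_os = "linux")]
    Direct,
}

impl TryFrom<u8> for IoMode {
    type Error = u8;

    // Returns the unknown discriminant as the error, as in the diff.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Ok(match value {
            v if v == (IoMode::Buffered as u8) => IoMode::Buffered,
            #[cfg(target_os = "linux")]
            v if v == (IoMode::Direct as u8) => IoMode::Direct,
            x => return Err(x),
        })
    }
}

// Hypothetical process-wide setting, stored as the enum's u8 discriminant.
static IO_MODE: AtomicU8 = AtomicU8::new(IoMode::Buffered as u8);

fn set_io_mode(mode: IoMode) {
    IO_MODE.store(mode as u8, Ordering::Relaxed);
}

fn get_io_mode() -> IoMode {
    // Only valid discriminants are ever stored, so the conversion cannot fail here.
    IoMode::try_from(IO_MODE.load(Ordering::Relaxed)).unwrap()
}
```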

View File

@@ -496,26 +496,12 @@ impl RemoteStorage for AzureBlobStorage {
builder = builder.if_match(IfMatchCondition::NotMatch(etag.to_string()))
}
self.download_for_builder(builder, cancel).await
}
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
let mut builder = blob_client.get();
let range: Range = if let Some(end_exclusive) = end_exclusive {
(start_inclusive..end_exclusive).into()
} else {
(start_inclusive..).into()
};
builder = builder.range(range);
if let Some((start, end)) = opts.byte_range() {
builder = builder.range(match end {
Some(end) => Range::Range(start..end),
None => Range::RangeFrom(start..),
});
}
self.download_for_builder(builder, cancel).await
}

View File

@@ -19,7 +19,8 @@ mod simulate_failures;
mod support;
use std::{
collections::HashMap, fmt::Debug, num::NonZeroU32, pin::Pin, sync::Arc, time::SystemTime,
collections::HashMap, fmt::Debug, num::NonZeroU32, ops::Bound, pin::Pin, sync::Arc,
time::SystemTime,
};
use anyhow::Context;
@@ -162,11 +163,60 @@ pub struct Listing {
}
/// Options for downloads. The default value is a plain GET.
#[derive(Default)]
pub struct DownloadOpts {
/// If given, returns [`DownloadError::Unmodified`] if the object still has
/// the same ETag (using If-None-Match).
pub etag: Option<Etag>,
/// The start of the byte range to download, or unbounded.
pub byte_start: Bound<u64>,
/// The end of the byte range to download, or unbounded. Must be after the
/// start bound.
pub byte_end: Bound<u64>,
}
impl Default for DownloadOpts {
fn default() -> Self {
Self {
etag: Default::default(),
byte_start: Bound::Unbounded,
byte_end: Bound::Unbounded,
}
}
}
impl DownloadOpts {
/// Returns the byte range with inclusive start and exclusive end, or None
/// if unbounded.
pub fn byte_range(&self) -> Option<(u64, Option<u64>)> {
if self.byte_start == Bound::Unbounded && self.byte_end == Bound::Unbounded {
return None;
}
let start = match self.byte_start {
Bound::Excluded(i) => i + 1,
Bound::Included(i) => i,
Bound::Unbounded => 0,
};
let end = match self.byte_end {
Bound::Excluded(i) => Some(i),
Bound::Included(i) => Some(i + 1),
Bound::Unbounded => None,
};
if let Some(end) = end {
assert!(start < end, "range end {end} at or before start {start}");
}
Some((start, end))
}
/// Returns the byte range as an RFC 2616 Range header value with inclusive
/// bounds, or None if unbounded.
pub fn byte_range_header(&self) -> Option<String> {
self.byte_range()
.map(|(start, end)| (start, end.map(|end| end - 1))) // make end inclusive
.map(|(start, end)| match end {
Some(end) => format!("bytes={start}-{end}"),
None => format!("bytes={start}-"),
})
}
}
/// Storage (potentially remote) API to manage its state.
@@ -257,21 +307,6 @@ pub trait RemoteStorage: Send + Sync + 'static {
cancel: &CancellationToken,
) -> Result<Download, DownloadError>;
/// Streams a given byte range of the remote storage entry contents.
///
/// The returned download stream will obey initial timeout and cancellation signal by erroring
/// on whichever happens first. Only one of the reasons will fail the stream, which is usually
/// enough for `tokio::io::copy_buf` usage. If needed the error can be filtered out.
///
/// Returns the metadata, if any was stored with the file previously.
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError>;
/// Delete a single path from remote storage.
///
/// If the operation fails because of timeout or cancellation, the root cause of the error will be
@@ -425,33 +460,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
}
}
pub async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
match self {
Self::LocalFs(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
Self::AwsS3(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
Self::AzureBlob(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
Self::Unreliable(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
}
}
/// See [`RemoteStorage::delete`]
pub async fn delete(
&self,
@@ -573,20 +581,6 @@ impl GenericRemoteStorage {
})
}
/// Downloads the storage object into the `to_path` provided.
/// `byte_range` can be specified to download only part of the file, if needed.
pub async fn download_storage_object(
&self,
byte_range: Option<(u64, Option<u64>)>,
from: &RemotePath,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
match byte_range {
Some((start, end)) => self.download_byte_range(from, start, end, cancel).await,
None => self.download(from, &DownloadOpts::default(), cancel).await,
}
}
/// The name of the bucket/container/etc.
pub fn bucket_name(&self) -> Option<&str> {
match self {
@@ -660,6 +654,76 @@ impl ConcurrencyLimiter {
mod tests {
use super::*;
/// DownloadOpts::byte_range() should generate (inclusive, exclusive) ranges
/// with optional end bound, or None when unbounded.
#[test]
fn download_opts_byte_range() {
// Consider using test_case or a similar table-driven test framework.
let cases = [
// (byte_start, byte_end, expected)
(Bound::Unbounded, Bound::Unbounded, None),
(Bound::Unbounded, Bound::Included(7), Some((0, Some(8)))),
(Bound::Unbounded, Bound::Excluded(7), Some((0, Some(7)))),
(Bound::Included(3), Bound::Unbounded, Some((3, None))),
(Bound::Included(3), Bound::Included(7), Some((3, Some(8)))),
(Bound::Included(3), Bound::Excluded(7), Some((3, Some(7)))),
(Bound::Excluded(3), Bound::Unbounded, Some((4, None))),
(Bound::Excluded(3), Bound::Included(7), Some((4, Some(8)))),
(Bound::Excluded(3), Bound::Excluded(7), Some((4, Some(7)))),
// 1-sized ranges are fine, 0 aren't and will panic (separate test).
(Bound::Included(3), Bound::Included(3), Some((3, Some(4)))),
(Bound::Included(3), Bound::Excluded(4), Some((3, Some(4)))),
];
for (byte_start, byte_end, expect) in cases {
let opts = DownloadOpts {
byte_start,
byte_end,
..Default::default()
};
let result = opts.byte_range();
assert_eq!(
result, expect,
"byte_start={byte_start:?} byte_end={byte_end:?}"
);
// Check generated HTTP header, which uses an inclusive range.
let expect_header = expect.map(|(start, end)| match end {
Some(end) => format!("bytes={start}-{}", end - 1), // inclusive end
None => format!("bytes={start}-"),
});
assert_eq!(
opts.byte_range_header(),
expect_header,
"byte_start={byte_start:?} byte_end={byte_end:?}"
);
}
}
/// DownloadOpts::byte_range() zero-sized byte range should panic.
#[test]
#[should_panic]
fn download_opts_byte_range_zero() {
DownloadOpts {
byte_start: Bound::Included(3),
byte_end: Bound::Excluded(3),
..Default::default()
}
.byte_range();
}
/// DownloadOpts::byte_range() negative byte range should panic.
#[test]
#[should_panic]
fn download_opts_byte_range_negative() {
DownloadOpts {
byte_start: Bound::Included(3),
byte_end: Bound::Included(2),
..Default::default()
}
.byte_range();
}
#[test]
fn test_object_name() {
let k = RemotePath::new(Utf8Path::new("a/b/c")).unwrap();
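
`DownloadOpts` above stores the range as two `Bound<u64>` values and converts them to an inclusive HTTP `Range` header. The same mapping can be sketched as a free function over any Rust range. `range_header` is a hypothetical helper, not part of the crate; it assumes the same rules as the diff: an unbounded start becomes 0, a fully unbounded range yields no header, and an empty or reversed range panics.

```rust
use std::ops::{Bound, RangeBounds};

// Converts a Rust range into an HTTP `Range` header value with inclusive bounds,
// or `None` when the range is unbounded on both sides.
fn range_header(range: impl RangeBounds<u64>) -> Option<String> {
    let start = match range.start_bound() {
        Bound::Included(&i) => Some(i),
        Bound::Excluded(&i) => Some(i + 1),
        Bound::Unbounded => None,
    };
    let end_exclusive = match range.end_bound() {
        Bound::Included(&i) => Some(i + 1),
        Bound::Excluded(&i) => Some(i),
        Bound::Unbounded => None,
    };
    if start.is_none() && end_exclusive.is_none() {
        return None;
    }
    let start = start.unwrap_or(0);
    match end_exclusive {
        Some(end) => {
            assert!(start < end, "range end {end} at or before start {start}");
            // The header uses an inclusive end, hence `end - 1`.
            Some(format!("bytes={start}-{}", end - 1))
        }
        None => Some(format!("bytes={start}-")),
    }
}
```

For example, `range_header(3..8)` corresponds to `byte_start: Bound::Included(3)` and `byte_end: Bound::Excluded(8)` in the struct form.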

View File

@@ -506,54 +506,7 @@ impl RemoteStorage for LocalFs {
return Err(DownloadError::Unmodified);
}
let source = ReaderStream::new(
fs::OpenOptions::new()
.read(true)
.open(&target_path)
.await
.with_context(|| {
format!("Failed to open source file {target_path:?} to use in the download")
})
.map_err(DownloadError::Other)?,
);
let metadata = self
.read_storage_metadata(&target_path)
.await
.map_err(DownloadError::Other)?;
let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
let source = crate::support::DownloadStream::new(cancel_or_timeout, source);
Ok(Download {
metadata,
last_modified: file_metadata
.modified()
.map_err(|e| DownloadError::Other(anyhow::anyhow!(e).context("Reading mtime")))?,
etag,
download_stream: Box::pin(source),
})
}
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
if let Some(end_exclusive) = end_exclusive {
if end_exclusive <= start_inclusive {
return Err(DownloadError::Other(anyhow::anyhow!("Invalid range, start ({start_inclusive}) is not less than end_exclusive ({end_exclusive:?})")));
};
if start_inclusive == end_exclusive.saturating_sub(1) {
return Err(DownloadError::Other(anyhow::anyhow!("Invalid range, start ({start_inclusive}) and end_exclusive ({end_exclusive:?}) difference is zero bytes")));
}
}
let target_path = from.with_base(&self.storage_root);
let file_metadata = file_metadata(&target_path).await?;
let mut source = tokio::fs::OpenOptions::new()
let mut file = fs::OpenOptions::new()
.read(true)
.open(&target_path)
.await
@@ -562,31 +515,29 @@ impl RemoteStorage for LocalFs {
})
.map_err(DownloadError::Other)?;
let len = source
.metadata()
.await
.context("query file length")
.map_err(DownloadError::Other)?
.len();
let mut take = file_metadata.len();
if let Some((start, end)) = opts.byte_range() {
if start > 0 {
file.seek(io::SeekFrom::Start(start))
.await
.context("Failed to seek to the range start in a local storage file")
.map_err(DownloadError::Other)?;
}
if let Some(end) = end {
take = end - start;
}
}
source
.seek(io::SeekFrom::Start(start_inclusive))
.await
.context("Failed to seek to the range start in a local storage file")
.map_err(DownloadError::Other)?;
let source = ReaderStream::new(file.take(take));
let metadata = self
.read_storage_metadata(&target_path)
.await
.map_err(DownloadError::Other)?;
let source = source.take(end_exclusive.unwrap_or(len) - start_inclusive);
let source = ReaderStream::new(source);
let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
let source = crate::support::DownloadStream::new(cancel_or_timeout, source);
let etag = mock_etag(&file_metadata);
Ok(Download {
metadata,
last_modified: file_metadata
@@ -688,7 +639,7 @@ mod fs_tests {
use super::*;
use camino_tempfile::tempdir;
use std::{collections::HashMap, io::Write};
use std::{collections::HashMap, io::Write, ops::Bound};
async fn read_and_check_metadata(
storage: &LocalFs,
@@ -804,10 +755,12 @@ mod fs_tests {
let (first_part_local, second_part_local) = uploaded_bytes.split_at(3);
let first_part_download = storage
.download_byte_range(
.download(
&upload_target,
0,
Some(first_part_local.len() as u64),
&DownloadOpts {
byte_end: Bound::Excluded(first_part_local.len() as u64),
..Default::default()
},
&cancel,
)
.await?;
@@ -823,10 +776,15 @@ mod fs_tests {
);
let second_part_download = storage
.download_byte_range(
.download(
&upload_target,
first_part_local.len() as u64,
Some((first_part_local.len() + second_part_local.len()) as u64),
&DownloadOpts {
byte_start: Bound::Included(first_part_local.len() as u64),
byte_end: Bound::Excluded(
(first_part_local.len() + second_part_local.len()) as u64,
),
..Default::default()
},
&cancel,
)
.await?;
@@ -842,7 +800,14 @@ mod fs_tests {
);
let suffix_bytes = storage
.download_byte_range(&upload_target, 13, None, &cancel)
.download(
&upload_target,
&DownloadOpts {
byte_start: Bound::Included(13),
..Default::default()
},
&cancel,
)
.await?
.download_stream;
let suffix_bytes = aggregate(suffix_bytes).await?;
@@ -850,7 +815,7 @@ mod fs_tests {
assert_eq!(upload_name, suffix);
let all_bytes = storage
.download_byte_range(&upload_target, 0, None, &cancel)
.download(&upload_target, &DownloadOpts::default(), &cancel)
.await?
.download_stream;
let all_bytes = aggregate(all_bytes).await?;
@@ -861,48 +826,26 @@ mod fs_tests {
}
#[tokio::test]
async fn download_file_range_negative() -> anyhow::Result<()> {
let (storage, cancel) = create_storage()?;
#[should_panic(expected = "at or before start")]
async fn download_file_range_negative() {
let (storage, cancel) = create_storage().unwrap();
let upload_name = "upload_1";
let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel).await?;
let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel)
.await
.unwrap();
let start = 1_000_000_000;
let end = start + 1;
match storage
.download_byte_range(
storage
.download(
&upload_target,
start,
Some(end), // exclusive end
&DownloadOpts {
byte_start: Bound::Included(10),
byte_end: Bound::Excluded(10),
..Default::default()
},
&cancel,
)
.await
{
Ok(_) => panic!("Should not allow downloading wrong ranges"),
Err(e) => {
let error_string = e.to_string();
assert!(error_string.contains("zero bytes"));
assert!(error_string.contains(&start.to_string()));
assert!(error_string.contains(&end.to_string()));
}
}
let start = 10000;
let end = 234;
assert!(start > end, "Should test an incorrect range");
match storage
.download_byte_range(&upload_target, start, Some(end), &cancel)
.await
{
Ok(_) => panic!("Should not allow downloading wrong ranges"),
Err(e) => {
let error_string = e.to_string();
assert!(error_string.contains("Invalid range"));
assert!(error_string.contains(&start.to_string()));
assert!(error_string.contains(&end.to_string()));
}
}
Ok(())
.unwrap();
}
#[tokio::test]
@@ -945,10 +888,12 @@ mod fs_tests {
let (first_part_local, _) = uploaded_bytes.split_at(3);
let partial_download_with_metadata = storage
.download_byte_range(
.download(
&upload_target,
0,
Some(first_part_local.len() as u64),
&DownloadOpts {
byte_end: Bound::Excluded(first_part_local.len() as u64),
..Default::default()
},
&cancel,
)
.await?;
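
The reworked `LocalFs::download` above serves a byte range by seeking to the start and wrapping the file in `take`. A synchronous std-only sketch of the same idea follows; the real code uses tokio's async file API. `read_range` and its parameters are hypothetical names chosen for illustration.

```rust
use std::io::{self, Read, Seek, SeekFrom};

// Reads bytes `[start, end)` from `src`; `end == None` means "until the end".
// `total_len` plays the role of the file length taken from metadata.
fn read_range<R: Read + Seek>(
    mut src: R,
    total_len: u64,
    start: u64,
    end: Option<u64>,
) -> io::Result<Vec<u8>> {
    // Default cap: the whole file. `take` only limits, so a cap larger than the
    // remaining bytes simply reads to EOF.
    let mut take = total_len;
    if start > 0 {
        src.seek(SeekFrom::Start(start))?;
    }
    if let Some(end) = end {
        take = end - start;
    }
    let mut buf = Vec::new();
    src.take(take).read_to_end(&mut buf)?;
    Ok(buf)
}
```

An end bound beyond the real end of the data is harmless here, which matches the "end beyond real end" case exercised in the integration tests.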

View File

@@ -804,34 +804,7 @@ impl RemoteStorage for S3Bucket {
bucket: self.bucket_name.clone(),
key: self.relative_path_to_s3_object(from),
etag: opts.etag.as_ref().map(|e| e.to_string()),
range: None,
},
cancel,
)
.await
}
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
// S3 accepts ranges as https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
// and needs both ends to be inclusive
let end_inclusive = end_exclusive.map(|end| end.saturating_sub(1));
let range = Some(match end_inclusive {
Some(end_inclusive) => format!("bytes={start_inclusive}-{end_inclusive}"),
None => format!("bytes={start_inclusive}-"),
});
self.download_object(
GetObjectRequest {
bucket: self.bucket_name.clone(),
key: self.relative_path_to_s3_object(from),
etag: None,
range,
range: opts.byte_range_header(),
},
cancel,
)

View File

@@ -170,28 +170,13 @@ impl RemoteStorage for UnreliableWrapper {
opts: &DownloadOpts,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
// Note: We treat any byte range as an "attempt" of the same operation.
// We don't pay attention to the ranges. That's good enough for now.
self.attempt(RemoteOp::Download(from.clone()))
.map_err(DownloadError::Other)?;
self.inner.download(from, opts, cancel).await
}
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
// Note: We treat any download_byte_range as an "attempt" of the same
// operation. We don't pay attention to the ranges. That's good enough
// for now.
self.attempt(RemoteOp::Download(from.clone()))
.map_err(DownloadError::Other)?;
self.inner
.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
self.delete_inner(path, true, cancel).await
}

View File

@@ -2,6 +2,7 @@ use anyhow::Context;
use camino::Utf8Path;
use futures::StreamExt;
use remote_storage::{DownloadError, DownloadOpts, ListingMode, ListingObject, RemotePath};
use std::ops::Bound;
use std::sync::Arc;
use std::{collections::HashSet, num::NonZeroU32};
use test_context::test_context;
@@ -293,7 +294,15 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// Full range (end specified)
let dl = ctx
.client
.download_byte_range(&path, 0, Some(len as u64), &cancel)
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(0),
byte_end: Bound::Excluded(len as u64),
..Default::default()
},
&cancel,
)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig);
@@ -301,7 +310,15 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// partial range (end specified)
let dl = ctx
.client
.download_byte_range(&path, 4, Some(10), &cancel)
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(4),
byte_end: Bound::Excluded(10),
..Default::default()
},
&cancel,
)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[4..10]);
@@ -309,7 +326,15 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// partial range (end beyond real end)
let dl = ctx
.client
.download_byte_range(&path, 8, Some(len as u64 * 100), &cancel)
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(8),
byte_end: Bound::Excluded(len as u64 * 100),
..Default::default()
},
&cancel,
)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[8..]);
@@ -317,7 +342,14 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// Partial range (end unspecified)
let dl = ctx
.client
.download_byte_range(&path, 4, None, &cancel)
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(4),
..Default::default()
},
&cancel,
)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[4..]);
@@ -325,7 +357,14 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// Full range (end unspecified)
let dl = ctx
.client
.download_byte_range(&path, 0, None, &cancel)
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(0),
..Default::default()
},
&cancel,
)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig);

View File

@@ -79,8 +79,7 @@ pub struct Config {
/// memory.
///
/// The default value of `0.15` means that we *guarantee* sending upscale requests if the
/// cgroup is using more than 85% of total memory (even if we're *not* separately reserving
/// memory for the file cache).
/// cgroup is using more than 85% of total memory.
cgroup_min_overhead_fraction: f64,
cgroup_downscale_threshold_buffer_bytes: u64,
@@ -97,24 +96,12 @@ impl Default for Config {
}
impl Config {
fn cgroup_threshold(&self, total_mem: u64, file_cache_disk_size: u64) -> u64 {
// If the file cache is in tmpfs, then it will count towards shmem usage of the cgroup,
// and thus be non-reclaimable, so we should allow for additional memory usage.
//
// If the file cache sits on disk, our desired stable system state is for it to be fully
// page cached (its contents should only be paged to/from disk in situations where we can't
// upscale fast enough). Page-cached memory is reclaimable, so we need to lower the
// threshold for non-reclaimable memory so we scale up *before* the kernel starts paging
// out the file cache.
let memory_remaining_for_cgroup = total_mem.saturating_sub(file_cache_disk_size);
// Even if we're not separately making room for the file cache (if it's in tmpfs), we still
// want our threshold to be met gracefully instead of letting postgres get OOM-killed.
fn cgroup_threshold(&self, total_mem: u64) -> u64 {
// We want our threshold to be met gracefully instead of letting postgres get OOM-killed
// (or if there's room, spilling to swap).
// So we guarantee that there's at least `cgroup_min_overhead_fraction` of total memory
// remaining above the threshold.
let max_threshold = (total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64;
memory_remaining_for_cgroup.min(max_threshold)
(total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64
}
}
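
Reviewer note: to make the behavioural change in `cgroup_threshold` concrete, the sketch below places the removed and the new calculation side by side as free functions. The function names and the standalone `min_overhead_fraction` parameter are illustrative assumptions; the arithmetic is taken from the lines in this hunk.

```rust
/// Removed behaviour: the file cache size was subtracted from total memory,
/// and the result was capped by the overhead-based maximum.
fn old_cgroup_threshold(total_mem: u64, file_cache_disk_size: u64, min_overhead_fraction: f64) -> u64 {
    let memory_remaining_for_cgroup = total_mem.saturating_sub(file_cache_disk_size);
    let max_threshold = (total_mem as f64 * (1.0 - min_overhead_fraction)) as u64;
    memory_remaining_for_cgroup.min(max_threshold)
}

/// New behaviour: only the overhead fraction matters.
fn new_cgroup_threshold(total_mem: u64, min_overhead_fraction: f64) -> u64 {
    (total_mem as f64 * (1.0 - min_overhead_fraction)) as u64
}
```

For a total of 1000 bytes, a 500-byte file cache and a fraction of 0.25, the old function yields 500 while the new one yields 750. The two agree whenever the file cache size is zero.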
@@ -149,11 +136,6 @@ impl Runner {
let mem = get_total_system_memory();
let mut file_cache_disk_size = 0;
// We need to process file cache initialization before cgroup initialization, so that the memory
// allocated to the file cache is appropriately taken into account when we decide the cgroup's
// memory limits.
if let Some(connstr) = &args.pgconnstr {
info!("initializing file cache");
let config = FileCacheConfig::default();
@@ -184,7 +166,6 @@ impl Runner {
info!("file cache size actually got set to {actual_size}")
}
file_cache_disk_size = actual_size;
state.filecache = Some(file_cache);
}
@@ -207,7 +188,7 @@ impl Runner {
cgroup.watch(hist_tx).await
});
let threshold = state.config.cgroup_threshold(mem, file_cache_disk_size);
let threshold = state.config.cgroup_threshold(mem);
info!(threshold, "set initial cgroup threshold",);
state.cgroup = Some(CgroupState {
@@ -259,9 +240,7 @@ impl Runner {
return Ok((false, status.to_owned()));
}
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, expected_file_cache_size);
let new_threshold = self.config.cgroup_threshold(usable_system_memory);
let current = last_history.avg_non_reclaimable;
@@ -282,13 +261,11 @@ impl Runner {
// The downscaling has been approved. Downscale the file cache, then the cgroup.
let mut status = vec![];
let mut file_cache_disk_size = 0;
if let Some(file_cache) = &mut self.filecache {
let actual_usage = file_cache
.set_file_cache_size(expected_file_cache_size)
.await
.context("failed to set file cache size")?;
file_cache_disk_size = actual_usage;
let message = format!(
"set file cache size to {} MiB",
bytes_to_mebibytes(actual_usage),
@@ -298,9 +275,7 @@ impl Runner {
}
if let Some(cgroup) = &mut self.cgroup {
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
let new_threshold = self.config.cgroup_threshold(usable_system_memory);
let message = format!(
"set cgroup memory threshold from {} MiB to {} MiB, of new total {} MiB",
@@ -329,7 +304,6 @@ impl Runner {
let new_mem = resources.mem;
let usable_system_memory = new_mem.saturating_sub(self.config.sys_buffer_bytes);
let mut file_cache_disk_size = 0;
if let Some(file_cache) = &mut self.filecache {
let expected_usage = file_cache.config.calculate_cache_size(usable_system_memory);
info!(
@@ -342,7 +316,6 @@ impl Runner {
.set_file_cache_size(expected_usage)
.await
.context("failed to set file cache size")?;
file_cache_disk_size = actual_usage;
if actual_usage != expected_usage {
warn!(
@@ -354,9 +327,7 @@ impl Runner {
}
if let Some(cgroup) = &mut self.cgroup {
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
let new_threshold = self.config.cgroup_threshold(usable_system_memory);
info!(
"set cgroup memory threshold from {} MiB to {} MiB of new total {} MiB",

View File

@@ -164,11 +164,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let conf: &'static PageServerConf = Box::leak(Box::new(
pageserver::config::PageServerConf::dummy_conf(temp_dir.path().to_path_buf()),
));
virtual_file::init(
16384,
virtual_file::io_engine_for_bench(),
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
);
virtual_file::init(16384, virtual_file::io_engine_for_bench());
page_cache::init(conf.page_cache_size);
{

View File

@@ -540,10 +540,13 @@ impl Client {
.map_err(Error::ReceiveBody)
}
/// Configs io buffer alignment at runtime.
pub async fn put_io_alignment(&self, align: usize) -> Result<()> {
let uri = format!("{}/v1/io_alignment", self.mgmt_api_endpoint);
self.request(Method::PUT, uri, align)
/// Configs io mode at runtime.
pub async fn put_io_mode(
&self,
mode: &pageserver_api::models::virtual_file::IoMode,
) -> Result<()> {
let uri = format!("{}/v1/io_mode", self.mgmt_api_endpoint);
self.request(Method::PUT, uri, mode)
.await?
.json()
.await

View File

@@ -152,11 +152,7 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> {
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
// Initialize virtual_file (file descriptor cache) and page cache which are needed to access layer persistent B-Tree.
pageserver::virtual_file::init(
10,
virtual_file::api::IoEngineKind::StdFs,
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
);
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
pageserver::page_cache::init(100);
let mut total_delta_layers = 0usize;

View File

@@ -59,7 +59,7 @@ pub(crate) enum LayerCmd {
async fn read_delta_file(path: impl AsRef<Path>, ctx: &RequestContext) -> Result<()> {
let path = Utf8Path::from_path(path.as_ref()).expect("non-Unicode path");
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, 1);
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
page_cache::init(100);
let file = VirtualFile::open(path, ctx).await?;
let file_id = page_cache::next_file_id();
@@ -190,11 +190,7 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> {
new_tenant_id,
new_timeline_id,
} => {
pageserver::virtual_file::init(
10,
virtual_file::api::IoEngineKind::StdFs,
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
);
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
pageserver::page_cache::init(100);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);

View File

@@ -26,7 +26,7 @@ use pageserver::{
tenant::{dump_layerfile_from_path, metadata::TimelineMetadata},
virtual_file,
};
use pageserver_api::{config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT, shard::TenantShardId};
use pageserver_api::shard::TenantShardId;
use postgres_ffi::ControlFileData;
use remote_storage::{RemotePath, RemoteStorageConfig};
use tokio_util::sync::CancellationToken;
@@ -205,11 +205,7 @@ fn read_pg_control_file(control_file_path: &Utf8Path) -> anyhow::Result<()> {
async fn print_layerfile(path: &Utf8Path) -> anyhow::Result<()> {
// Basic initialization of things that don't change after startup
virtual_file::init(
10,
virtual_file::api::IoEngineKind::StdFs,
DEFAULT_IO_BUFFER_ALIGNMENT,
);
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
page_cache::init(100);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
dump_layerfile_from_path(path, true, &ctx).await

View File

@@ -59,9 +59,9 @@ pub(crate) struct Args {
#[clap(long)]
set_io_engine: Option<pageserver_api::models::virtual_file::IoEngineKind>,
/// Before starting the benchmark, live-reconfigure the pageserver to use specified alignment for io buffers.
/// Before starting the benchmark, live-reconfigure the pageserver to use specified io mode (buffered vs. direct).
#[clap(long)]
set_io_alignment: Option<usize>,
set_io_mode: Option<pageserver_api::models::virtual_file::IoMode>,
targets: Option<Vec<TenantTimelineId>>,
}
@@ -129,8 +129,8 @@ async fn main_impl(
mgmt_api_client.put_io_engine(engine_str).await?;
}
if let Some(align) = args.set_io_alignment {
mgmt_api_client.put_io_alignment(align).await?;
if let Some(mode) = &args.set_io_mode {
mgmt_api_client.put_io_mode(mode).await?;
}
// discover targets

View File

@@ -125,8 +125,7 @@ fn main() -> anyhow::Result<()> {
// after setting up logging, log the effective IO engine choice and read path implementations
info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine");
info!(?conf.virtual_file_direct_io, "starting with virtual_file Direct IO settings");
info!(?conf.io_buffer_alignment, "starting with setting for IO buffer alignment");
info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode");
// The tenants directory contains all the pageserver local disk state.
// Create if not exists and make sure all the contents are durable before proceeding.
@@ -168,11 +167,7 @@ fn main() -> anyhow::Result<()> {
let scenario = failpoint_support::init();
// Basic initialization of things that don't change after startup
virtual_file::init(
conf.max_file_descriptors,
conf.virtual_file_io_engine,
conf.io_buffer_alignment,
);
virtual_file::init(conf.max_file_descriptors, conf.virtual_file_io_engine);
page_cache::init(conf.page_cache_size);
start_pageserver(launch_ts, conf).context("Failed to start pageserver")?;

View File

@@ -174,9 +174,7 @@ pub struct PageServerConf {
pub l0_flush: crate::l0_flush::L0FlushConfig,
/// Direct IO settings
pub virtual_file_direct_io: virtual_file::DirectIoMode,
pub io_buffer_alignment: usize,
pub virtual_file_io_mode: virtual_file::IoMode,
}
/// Token for authentication to safekeepers
@@ -325,11 +323,10 @@ impl PageServerConf {
image_compression,
ephemeral_bytes_per_memory_kb,
l0_flush,
virtual_file_direct_io,
virtual_file_io_mode,
concurrent_tenant_warmup,
concurrent_tenant_size_logical_size_queries,
virtual_file_io_engine,
io_buffer_alignment,
tenant_config,
} = config_toml;
@@ -368,8 +365,6 @@ impl PageServerConf {
max_vectored_read_bytes,
image_compression,
ephemeral_bytes_per_memory_kb,
virtual_file_direct_io,
io_buffer_alignment,
// ------------------------------------------------------------
// fields that require additional validation or custom handling
@@ -408,6 +403,7 @@ impl PageServerConf {
l0_flush: l0_flush
.map(crate::l0_flush::L0FlushConfig::from)
.unwrap_or_default(),
virtual_file_io_mode: virtual_file_io_mode.unwrap_or(virtual_file::IoMode::preferred()),
};
// ------------------------------------------------------------

View File

@@ -17,6 +17,7 @@ use hyper::header;
use hyper::StatusCode;
use hyper::{Body, Request, Response, Uri};
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::models::virtual_file::IoMode;
use pageserver_api::models::AuxFilePolicy;
use pageserver_api::models::DownloadRemoteLayersTaskSpawnRequest;
use pageserver_api::models::IngestAuxFilesRequest;
@@ -703,6 +704,8 @@ async fn timeline_archival_config_handler(
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Warn);
let request_data: TimelineArchivalConfigRequest = json_request(&mut request).await?;
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
let state = get_state(&request);
@@ -713,7 +716,7 @@ async fn timeline_archival_config_handler(
.get_attached_tenant_shard(tenant_shard_id)?;
tenant
.apply_timeline_archival_config(timeline_id, request_data.state)
.apply_timeline_archival_config(timeline_id, request_data.state, ctx)
.await?;
Ok::<_, ApiError>(())
}
@@ -2379,17 +2382,13 @@ async fn put_io_engine_handler(
json_response(StatusCode::OK, ())
}
async fn put_io_alignment_handler(
async fn put_io_mode_handler(
mut r: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
check_permission(&r, None)?;
let align: usize = json_request(&mut r).await?;
crate::virtual_file::set_io_buffer_alignment(align).map_err(|align| {
ApiError::PreconditionFailed(
format!("Requested io alignment ({align}) is not a power of two").into(),
)
})?;
let mode: IoMode = json_request(&mut r).await?;
crate::virtual_file::set_io_mode(mode);
json_response(StatusCode::OK, ())
}
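
Reviewer note: the new handler no longer has an error path, presumably because an enum value is already validated when the JSON body is deserialized, whereas the old alignment value needed a power-of-two check. The body of `set_io_mode` is not part of this diff. The sketch below shows one plausible way to keep such a process-wide setting, assuming two variants named after the "buffered vs. direct" wording in the CLI help. The variant names and the atomic storage are assumptions.

```rust
use std::sync::atomic::{AtomicU8, Ordering};

/// Hypothetical stand-in for `IoMode`; the real variants are not shown in the diff.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum IoMode {
    Buffered = 0,
    Direct = 1,
}

/// Process-wide setting, stored as the enum discriminant.
static IO_MODE: AtomicU8 = AtomicU8::new(IoMode::Buffered as u8);

/// Infallible: every `IoMode` value is valid by construction.
fn set_io_mode(mode: IoMode) {
    IO_MODE.store(mode as u8, Ordering::Relaxed);
}

fn get_io_mode() -> IoMode {
    match IO_MODE.load(Ordering::Relaxed) {
        0 => IoMode::Buffered,
        1 => IoMode::Direct,
        other => unreachable!("invalid IoMode discriminant {other}"),
    }
}
```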
@@ -3080,9 +3079,7 @@ pub fn make_router(
|r| api_handler(r, timeline_collect_keyspace),
)
.put("/v1/io_engine", |r| api_handler(r, put_io_engine_handler))
.put("/v1/io_alignment", |r| {
api_handler(r, put_io_alignment_handler)
})
.put("/v1/io_mode", |r| api_handler(r, put_io_mode_handler))
.put(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/force_aux_policy_switch",
|r| api_handler(r, force_aux_policy_switch_handler),

View File

@@ -38,6 +38,7 @@ use std::future::Future;
use std::sync::Weak;
use std::time::SystemTime;
use storage_broker::BrokerClientChannel;
use timeline::offload::offload_timeline;
use tokio::io::BufReader;
use tokio::sync::watch;
use tokio::task::JoinSet;
@@ -287,9 +288,13 @@ pub struct Tenant {
/// During timeline creation, we first insert the TimelineId to the
/// creating map, then `timelines`, then remove it from the creating map.
/// **Lock order**: if acquring both, acquire`timelines` before `timelines_creating`
/// **Lock order**: if acquiring both, acquire `timelines` before `timelines_creating`
timelines_creating: std::sync::Mutex<HashSet<TimelineId>>,
/// Possibly offloaded and archived timelines
/// **Lock order**: if acquiring both, acquire `timelines` before `timelines_offloaded`
timelines_offloaded: Mutex<HashMap<TimelineId, Arc<OffloadedTimeline>>>,
// This mutex prevents creation of new timelines during GC.
// Adding yet another mutex (in addition to `timelines`) is needed because holding
// `timelines` mutex during all GC iteration
@@ -484,6 +489,65 @@ impl WalRedoManager {
}
}
pub struct OffloadedTimeline {
pub tenant_shard_id: TenantShardId,
pub timeline_id: TimelineId,
pub ancestor_timeline_id: Option<TimelineId>,
// TODO: once we persist offloaded state, make this lazily constructed
pub remote_client: Arc<RemoteTimelineClient>,
/// Prevent two tasks from deleting the timeline at the same time. If held, the
/// timeline is being deleted. If 'true', the timeline has already been deleted.
pub delete_progress: Arc<tokio::sync::Mutex<DeleteTimelineFlow>>,
}
impl OffloadedTimeline {
fn from_timeline(timeline: &Timeline) -> Self {
Self {
tenant_shard_id: timeline.tenant_shard_id,
timeline_id: timeline.timeline_id,
ancestor_timeline_id: timeline.get_ancestor_timeline_id(),
remote_client: timeline.remote_client.clone(),
delete_progress: timeline.delete_progress.clone(),
}
}
}
#[derive(Clone)]
pub enum TimelineOrOffloaded {
Timeline(Arc<Timeline>),
Offloaded(Arc<OffloadedTimeline>),
}
impl TimelineOrOffloaded {
pub fn tenant_shard_id(&self) -> TenantShardId {
match self {
TimelineOrOffloaded::Timeline(timeline) => timeline.tenant_shard_id,
TimelineOrOffloaded::Offloaded(offloaded) => offloaded.tenant_shard_id,
}
}
pub fn timeline_id(&self) -> TimelineId {
match self {
TimelineOrOffloaded::Timeline(timeline) => timeline.timeline_id,
TimelineOrOffloaded::Offloaded(offloaded) => offloaded.timeline_id,
}
}
pub fn delete_progress(&self) -> &Arc<tokio::sync::Mutex<DeleteTimelineFlow>> {
match self {
TimelineOrOffloaded::Timeline(timeline) => &timeline.delete_progress,
TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.delete_progress,
}
}
pub fn remote_client(&self) -> &Arc<RemoteTimelineClient> {
match self {
TimelineOrOffloaded::Timeline(timeline) => &timeline.remote_client,
TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.remote_client,
}
}
}
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum GetTimelineError {
#[error("Timeline is shutting down")]
@@ -1406,52 +1470,192 @@ impl Tenant {
}
}
pub(crate) async fn apply_timeline_archival_config(
&self,
fn check_to_be_archived_has_no_unarchived_children(
timeline_id: TimelineId,
state: TimelineArchivalState,
timelines: &std::sync::MutexGuard<'_, HashMap<TimelineId, Arc<Timeline>>>,
) -> Result<(), TimelineArchivalError> {
let children: Vec<TimelineId> = timelines
.iter()
.filter_map(|(id, entry)| {
if entry.get_ancestor_timeline_id() != Some(timeline_id) {
return None;
}
if entry.is_archived() == Some(true) {
return None;
}
Some(*id)
})
.collect();
if !children.is_empty() {
return Err(TimelineArchivalError::HasUnarchivedChildren(children));
}
Ok(())
}
fn check_ancestor_of_to_be_unarchived_is_not_archived(
ancestor_timeline_id: TimelineId,
timelines: &std::sync::MutexGuard<'_, HashMap<TimelineId, Arc<Timeline>>>,
offloaded_timelines: &std::sync::MutexGuard<
'_,
HashMap<TimelineId, Arc<OffloadedTimeline>>,
>,
) -> Result<(), TimelineArchivalError> {
let has_archived_parent =
if let Some(ancestor_timeline) = timelines.get(&ancestor_timeline_id) {
ancestor_timeline.is_archived() == Some(true)
} else if offloaded_timelines.contains_key(&ancestor_timeline_id) {
true
} else {
error!("ancestor timeline {ancestor_timeline_id} not found");
if cfg!(debug_assertions) {
panic!("ancestor timeline {ancestor_timeline_id} not found");
}
return Err(TimelineArchivalError::NotFound);
};
if has_archived_parent {
return Err(TimelineArchivalError::HasArchivedParent(
ancestor_timeline_id,
));
}
Ok(())
}
fn check_to_be_unarchived_timeline_has_no_archived_parent(
timeline: &Arc<Timeline>,
) -> Result<(), TimelineArchivalError> {
if let Some(ancestor_timeline) = timeline.ancestor_timeline() {
if ancestor_timeline.is_archived() == Some(true) {
return Err(TimelineArchivalError::HasArchivedParent(
ancestor_timeline.timeline_id,
));
}
}
Ok(())
}
/// Loads the specified (offloaded) timeline from S3 and attaches it as a loaded timeline
async fn unoffload_timeline(
self: &Arc<Self>,
timeline_id: TimelineId,
ctx: RequestContext,
) -> Result<Arc<Timeline>, TimelineArchivalError> {
let cancel = self.cancel.clone();
let timeline_preload = self
.load_timeline_metadata(timeline_id, self.remote_storage.clone(), cancel)
.await;
let index_part = match timeline_preload.index_part {
Ok(index_part) => {
debug!("remote index part exists for timeline {timeline_id}");
index_part
}
Err(DownloadError::NotFound) => {
error!(%timeline_id, "index_part not found on remote");
return Err(TimelineArchivalError::NotFound);
}
Err(e) => {
// Some (possibly ephemeral) error happened during index_part download.
warn!(%timeline_id, "Failed to load index_part from remote storage, failed creation? ({e})");
return Err(TimelineArchivalError::Other(
anyhow::Error::new(e).context("downloading index_part from remote storage"),
));
}
};
let index_part = match index_part {
MaybeDeletedIndexPart::IndexPart(index_part) => index_part,
MaybeDeletedIndexPart::Deleted(_index_part) => {
info!("timeline is deleted according to index_part.json");
return Err(TimelineArchivalError::NotFound);
}
};
let remote_metadata = index_part.metadata.clone();
let timeline_resources = self.build_timeline_resources(timeline_id);
self.load_remote_timeline(
timeline_id,
index_part,
remote_metadata,
timeline_resources,
&ctx,
)
.await
.with_context(|| {
format!(
"failed to load remote timeline {} for tenant {}",
timeline_id, self.tenant_shard_id
)
})?;
let timelines = self.timelines.lock().unwrap();
if let Some(timeline) = timelines.get(&timeline_id) {
let mut offloaded_timelines = self.timelines_offloaded.lock().unwrap();
if offloaded_timelines.remove(&timeline_id).is_none() {
warn!("timeline already removed from offloaded timelines");
}
Ok(Arc::clone(timeline))
} else {
warn!("timeline not available directly after attach");
Err(TimelineArchivalError::Other(anyhow::anyhow!(
"timeline not available directly after attach"
)))
}
}
pub(crate) async fn apply_timeline_archival_config(
self: &Arc<Self>,
timeline_id: TimelineId,
new_state: TimelineArchivalState,
ctx: RequestContext,
) -> Result<(), TimelineArchivalError> {
info!("setting timeline archival config");
let timeline = {
// First part: figure out what is needed to do, and do validation
let timeline_or_unarchive_offloaded = 'outer: {
let timelines = self.timelines.lock().unwrap();
let Some(timeline) = timelines.get(&timeline_id) else {
return Err(TimelineArchivalError::NotFound);
let offloaded_timelines = self.timelines_offloaded.lock().unwrap();
let Some(offloaded) = offloaded_timelines.get(&timeline_id) else {
return Err(TimelineArchivalError::NotFound);
};
if new_state == TimelineArchivalState::Archived {
// It's offloaded already, so nothing to do
return Ok(());
}
if let Some(ancestor_timeline_id) = offloaded.ancestor_timeline_id {
Self::check_ancestor_of_to_be_unarchived_is_not_archived(
ancestor_timeline_id,
&timelines,
&offloaded_timelines,
)?;
}
break 'outer None;
};
if state == TimelineArchivalState::Unarchived {
if let Some(ancestor_timeline) = timeline.ancestor_timeline() {
if ancestor_timeline.is_archived() == Some(true) {
return Err(TimelineArchivalError::HasArchivedParent(
ancestor_timeline.timeline_id,
));
}
// Do some validation. We release the timelines lock below, so there is potential
// for race conditions: these checks exist mainly to prevent misunderstandings of
// the API's capabilities, rather than to serve as the sole defense of its invariants.
match new_state {
TimelineArchivalState::Unarchived => {
Self::check_to_be_unarchived_timeline_has_no_archived_parent(timeline)?
}
TimelineArchivalState::Archived => {
Self::check_to_be_archived_has_no_unarchived_children(timeline_id, &timelines)?
}
}
// Ensure that there are no non-archived child timelines
let children: Vec<TimelineId> = timelines
.iter()
.filter_map(|(id, entry)| {
if entry.get_ancestor_timeline_id() != Some(timeline_id) {
return None;
}
if entry.is_archived() == Some(true) {
return None;
}
Some(*id)
})
.collect();
if !children.is_empty() && state == TimelineArchivalState::Archived {
return Err(TimelineArchivalError::HasUnarchivedChildren(children));
}
Arc::clone(timeline)
Some(Arc::clone(timeline))
};
// Second part: unarchive timeline (if needed)
let timeline = if let Some(timeline) = timeline_or_unarchive_offloaded {
timeline
} else {
// Turn offloaded timeline into a non-offloaded one
self.unoffload_timeline(timeline_id, ctx).await?
};
// Third part: upload new timeline archival state and block until it is present in S3
let upload_needed = timeline
.remote_client
.schedule_index_upload_for_timeline_archival_state(state)?;
.schedule_index_upload_for_timeline_archival_state(new_state)?;
if upload_needed {
info!("Uploading new state");
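
Reviewer note: the validation rules introduced in this hunk can be exercised in isolation. The sketch below restates them over simplified stand-in types: `TimelineId` is a plain integer, `TimelineInfo` replaces `Arc<Timeline>`, and `archived` is a `bool` rather than the `Option<bool>` returned by `is_archived()`. All type and function names here are illustrative assumptions, not the types from the repository.

```rust
use std::collections::{HashMap, HashSet};

type TimelineId = u32;

/// Simplified stand-in for a loaded timeline.
struct TimelineInfo {
    ancestor: Option<TimelineId>,
    archived: bool,
}

#[derive(Debug, PartialEq)]
enum ArchivalError {
    HasUnarchivedChildren(Vec<TimelineId>),
    HasArchivedParent(TimelineId),
    NotFound,
}

/// Archiving is rejected while any direct child is still unarchived.
fn check_can_archive(
    id: TimelineId,
    timelines: &HashMap<TimelineId, TimelineInfo>,
) -> Result<(), ArchivalError> {
    let mut children: Vec<TimelineId> = timelines
        .iter()
        .filter(|(_, t)| t.ancestor == Some(id) && !t.archived)
        .map(|(child_id, _)| *child_id)
        .collect();
    children.sort_unstable(); // deterministic order for error reporting
    if children.is_empty() {
        Ok(())
    } else {
        Err(ArchivalError::HasUnarchivedChildren(children))
    }
}

/// Unarchiving is rejected while the ancestor is archived. An offloaded
/// ancestor counts as archived; an unknown ancestor is reported as not found.
fn check_can_unarchive(
    ancestor_id: TimelineId,
    loaded: &HashMap<TimelineId, TimelineInfo>,
    offloaded: &HashSet<TimelineId>,
) -> Result<(), ArchivalError> {
    let has_archived_parent = if let Some(ancestor) = loaded.get(&ancestor_id) {
        ancestor.archived
    } else if offloaded.contains(&ancestor_id) {
        true
    } else {
        return Err(ArchivalError::NotFound);
    };
    if has_archived_parent {
        return Err(ArchivalError::HasArchivedParent(ancestor_id));
    }
    Ok(())
}
```

As the comment in the diff points out, these checks run before the lock is released, so they guard against API misuse rather than against every possible race.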
@@ -1884,7 +2088,7 @@ impl Tenant {
///
/// Returns whether we have pending compaction task.
async fn compaction_iteration(
&self,
self: &Arc<Self>,
cancel: &CancellationToken,
ctx: &RequestContext,
) -> Result<bool, timeline::CompactionError> {
@@ -1905,21 +2109,28 @@ impl Tenant {
// while holding the lock. Then drop the lock and actually perform the
// compactions. We don't want to block everything else while the
// compaction runs.
let timelines_to_compact = {
let timelines_to_compact_or_offload;
{
let timelines = self.timelines.lock().unwrap();
let timelines_to_compact = timelines
timelines_to_compact_or_offload = timelines
.iter()
.filter_map(|(timeline_id, timeline)| {
if timeline.is_active() {
Some((*timeline_id, timeline.clone()))
} else {
let (is_active, can_offload) = (timeline.is_active(), timeline.can_offload());
let has_no_unoffloaded_children = {
!timelines
.iter()
.any(|(_id, tl)| tl.get_ancestor_timeline_id() == Some(*timeline_id))
};
let can_offload = can_offload && has_no_unoffloaded_children;
if (is_active, can_offload) == (false, false) {
None
} else {
Some((*timeline_id, timeline.clone(), (is_active, can_offload)))
}
})
.collect::<Vec<_>>();
drop(timelines);
timelines_to_compact
};
}
// Before doing any I/O work, check our circuit breaker
if self.compaction_circuit_breaker.lock().unwrap().is_broken() {
@@ -1929,20 +2140,34 @@ impl Tenant {
let mut has_pending_task = false;
for (timeline_id, timeline) in &timelines_to_compact {
has_pending_task |= timeline
.compact(cancel, EnumSet::empty(), ctx)
.instrument(info_span!("compact_timeline", %timeline_id))
.await
.inspect_err(|e| match e {
timeline::CompactionError::ShuttingDown => (),
timeline::CompactionError::Other(e) => {
self.compaction_circuit_breaker
.lock()
.unwrap()
.fail(&CIRCUIT_BREAKERS_BROKEN, e);
}
})?;
for (timeline_id, timeline, (can_compact, can_offload)) in &timelines_to_compact_or_offload
{
let pending_task_left = if *can_compact {
Some(
timeline
.compact(cancel, EnumSet::empty(), ctx)
.instrument(info_span!("compact_timeline", %timeline_id))
.await
.inspect_err(|e| match e {
timeline::CompactionError::ShuttingDown => (),
timeline::CompactionError::Other(e) => {
self.compaction_circuit_breaker
.lock()
.unwrap()
.fail(&CIRCUIT_BREAKERS_BROKEN, e);
}
})?,
)
} else {
None
};
has_pending_task |= pending_task_left.unwrap_or(false);
if pending_task_left == Some(false) && *can_offload {
offload_timeline(self, timeline)
.instrument(info_span!("offload_timeline", %timeline_id))
.await
.map_err(timeline::CompactionError::Other)?;
}
}
self.compaction_circuit_breaker
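
Reviewer note: the selection logic in `compaction_iteration` now decides two things per timeline: whether to compact it and whether it may be offloaded afterwards. The sketch below models that decision over simplified stand-in types so it can be read and tested on its own. `TimelineInfo`, `select_work` and `should_offload` are hypothetical names, and `is_archived` stands in for `can_offload()`, which in this diff only checks the archived flag.

```rust
use std::collections::HashMap;

type TimelineId = u32;

/// Simplified stand-in for the fields the selection logic reads.
struct TimelineInfo {
    ancestor: Option<TimelineId>,
    is_active: bool,
    is_archived: bool,
}

/// Returns `(id, can_compact, can_offload)` for every timeline that needs work.
fn select_work(timelines: &HashMap<TimelineId, TimelineInfo>) -> Vec<(TimelineId, bool, bool)> {
    let mut selected: Vec<_> = timelines
        .iter()
        .filter_map(|(id, tl)| {
            // A timeline with any loaded child must stay loaded.
            let has_loaded_children = timelines
                .values()
                .any(|other| other.ancestor == Some(*id));
            let can_offload = tl.is_archived && !has_loaded_children;
            if !tl.is_active && !can_offload {
                None
            } else {
                Some((*id, tl.is_active, can_offload))
            }
        })
        .collect();
    selected.sort_unstable_by_key(|(id, _, _)| *id); // deterministic order
    selected
}

/// Mirrors the condition in the loop: offload only after a compaction pass
/// that ran and reported no pending work.
fn should_offload(pending_task_left: Option<bool>, can_offload: bool) -> bool {
    pending_task_left == Some(false) && can_offload
}
```

One consequence worth double-checking in review: because `pending_task_left` is `None` when compaction is skipped, a timeline that is selected only because it can be offloaded (inactive, archived, no children) does not appear to reach `offload_timeline` in this loop.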
@@ -2852,6 +3077,7 @@ impl Tenant {
constructed_at: Instant::now(),
timelines: Mutex::new(HashMap::new()),
timelines_creating: Mutex::new(HashSet::new()),
timelines_offloaded: Mutex::new(HashMap::new()),
gc_cs: tokio::sync::Mutex::new(()),
walredo_mgr,
remote_storage,

View File

@@ -84,7 +84,7 @@ impl Drop for EphemeralFile {
fn drop(&mut self) {
// unlink the file
// we are clear to do this, because we have entered a gate
let path = &self.buffered_writer.as_inner().as_inner().path;
let path = self.buffered_writer.as_inner().as_inner().path();
let res = std::fs::remove_file(path);
if let Err(e) = res {
if e.kind() != std::io::ErrorKind::NotFound {
@@ -356,7 +356,7 @@ mod tests {
}
let file_contents =
std::fs::read(&file.buffered_writer.as_inner().as_inner().path).unwrap();
std::fs::read(file.buffered_writer.as_inner().as_inner().path()).unwrap();
assert_eq!(file_contents, &content[0..cap]);
let buffer_contents = file.buffered_writer.inspect_buffer();
@@ -392,7 +392,7 @@ mod tests {
.buffered_writer
.as_inner()
.as_inner()
.path
.path()
.metadata()
.unwrap();
assert_eq!(

View File

@@ -141,14 +141,14 @@ impl GcBlock {
Ok(())
}
pub(crate) fn before_delete(&self, timeline: &super::Timeline) {
pub(crate) fn before_delete(&self, timeline_id: &super::TimelineId) {
let unblocked = {
let mut g = self.reasons.lock().unwrap();
if g.is_empty() {
return;
}
g.remove(&timeline.timeline_id);
g.remove(timeline_id);
BlockingReasons::clean_and_summarize(g).is_none()
};

View File

@@ -950,6 +950,7 @@ impl<'a> TenantDownloader<'a> {
let cancel = &self.secondary_state.cancel;
let opts = DownloadOpts {
etag: prev_etag.cloned(),
..Default::default()
};
backoff::retry(

View File

@@ -573,7 +573,7 @@ impl DeltaLayerWriterInner {
ensure!(
metadata.len() <= S3_UPLOAD_LIMIT,
"Created delta layer file at {} of size {} above limit {S3_UPLOAD_LIMIT}!",
file.path,
file.path(),
metadata.len()
);
@@ -791,7 +791,7 @@ impl DeltaLayerInner {
max_vectored_read_bytes: Option<MaxVectoredReadBytes>,
ctx: &RequestContext,
) -> anyhow::Result<Self> {
let file = VirtualFile::open(path, ctx)
let file = VirtualFile::open_v2(path, ctx)
.await
.context("open layer file")?;
@@ -1022,7 +1022,7 @@ impl DeltaLayerInner {
blob_meta.key,
PageReconstructError::Other(anyhow!(
"Failed to read blobs from virtual file {}: {}",
self.file.path,
self.file.path(),
kind
)),
);
@@ -1048,7 +1048,7 @@ impl DeltaLayerInner {
meta.meta.key,
PageReconstructError::Other(anyhow!(e).context(format!(
"Failed to decompress blob from virtual file {}",
self.file.path,
self.file.path(),
))),
);
@@ -1066,7 +1066,7 @@ impl DeltaLayerInner {
meta.meta.key,
PageReconstructError::Other(anyhow!(e).context(format!(
"Failed to deserialize blob from virtual file {}",
self.file.path,
self.file.path(),
))),
);
@@ -1198,7 +1198,6 @@ impl DeltaLayerInner {
let mut prev: Option<(Key, Lsn, BlobRef)> = None;
let mut read_builder: Option<ChunkedVectoredReadBuilder> = None;
let align = virtual_file::get_io_buffer_alignment();
let max_read_size = self
.max_vectored_read_bytes
@@ -1247,7 +1246,6 @@ impl DeltaLayerInner {
offsets.end.pos(),
meta,
max_read_size,
align,
))
}
} else {

View File

@@ -389,7 +389,7 @@ impl ImageLayerInner {
max_vectored_read_bytes: Option<MaxVectoredReadBytes>,
ctx: &RequestContext,
) -> anyhow::Result<Self> {
let file = VirtualFile::open(path, ctx)
let file = VirtualFile::open_v2(path, ctx)
.await
.context("open layer file")?;
let file_id = page_cache::next_file_id();
@@ -626,7 +626,7 @@ impl ImageLayerInner {
meta.meta.key,
PageReconstructError::Other(anyhow!(e).context(format!(
"Failed to decompress blob from virtual file {}",
self.file.path,
self.file.path(),
))),
);
@@ -647,7 +647,7 @@ impl ImageLayerInner {
blob_meta.key,
PageReconstructError::from(anyhow!(
"Failed to read blobs from virtual file {}: {}",
self.file.path,
self.file.path(),
kind
)),
);

View File

@@ -7,6 +7,7 @@ pub(crate) mod handle;
mod init;
pub mod layer_manager;
pub(crate) mod logical_size;
pub mod offload;
pub mod span;
pub mod uninit;
mod walreceiver;
@@ -1556,6 +1557,17 @@ impl Timeline {
}
}
/// Checks if the internal state of the timeline is consistent with it being able to be offloaded.
/// This is necessary but not sufficient for offloading of the timeline as it might have
/// child timelines that are not offloaded yet.
pub(crate) fn can_offload(&self) -> bool {
if self.remote_client.is_archived() != Some(true) {
return false;
}
true
}
/// Outermost timeline compaction operation; downloads needed layers. Returns whether we have pending
/// compaction tasks.
pub(crate) async fn compact(
@@ -1818,7 +1830,6 @@ impl Timeline {
self.current_state() == TimelineState::Active
}
#[allow(unused)]
pub(crate) fn is_archived(&self) -> Option<bool> {
self.remote_client.is_archived()
}

View File

@@ -15,7 +15,7 @@ use crate::{
tenant::{
metadata::TimelineMetadata,
remote_timeline_client::{PersistIndexPartWithDeletedFlagError, RemoteTimelineClient},
CreateTimelineCause, DeleteTimelineError, Tenant,
CreateTimelineCause, DeleteTimelineError, Tenant, TimelineOrOffloaded,
},
};
@@ -24,12 +24,14 @@ use super::{Timeline, TimelineResources};
/// Mark timeline as deleted in S3 so we won't pick it up next time
/// during attach or pageserver restart.
/// See comment in persist_index_part_with_deleted_flag.
async fn set_deleted_in_remote_index(timeline: &Timeline) -> Result<(), DeleteTimelineError> {
match timeline
.remote_client
async fn set_deleted_in_remote_index(
timeline: &TimelineOrOffloaded,
) -> Result<(), DeleteTimelineError> {
let res = timeline
.remote_client()
.persist_index_part_with_deleted_flag()
.await
{
.await;
match res {
// If we (now, or already) marked it successfully as deleted, we can proceed
Ok(()) | Err(PersistIndexPartWithDeletedFlagError::AlreadyDeleted(_)) => (),
// Bail out otherwise
@@ -127,9 +129,9 @@ pub(super) async fn delete_local_timeline_directory(
}
/// Removes remote layers and an index file after them.
async fn delete_remote_layers_and_index(timeline: &Timeline) -> anyhow::Result<()> {
async fn delete_remote_layers_and_index(timeline: &TimelineOrOffloaded) -> anyhow::Result<()> {
timeline
.remote_client
.remote_client()
.delete_all()
.await
.context("delete_all")
@@ -137,27 +139,41 @@ async fn delete_remote_layers_and_index(timeline: &Timeline) -> anyhow::Result<(
/// It is important that this gets called when DeletionGuard is being held.
/// For more context see comments in [`DeleteTimelineFlow::prepare`]
async fn remove_timeline_from_tenant(
async fn remove_maybe_offloaded_timeline_from_tenant(
tenant: &Tenant,
timeline: &Timeline,
timeline: &TimelineOrOffloaded,
_: &DeletionGuard, // using it as a witness
) -> anyhow::Result<()> {
// Remove the timeline from the map.
// This observes the locking order between timelines and timelines_offloaded
let mut timelines = tenant.timelines.lock().unwrap();
let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
let offloaded_children_exist = timelines_offloaded
.iter()
.any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id()));
let children_exist = timelines
.iter()
.any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline.timeline_id));
// XXX this can happen because `branch_timeline` doesn't check `TimelineState::Stopping`.
// We already deleted the layer files, so it's probably best to panic.
// (Ideally, above remove_dir_all is atomic so we don't see this timeline after a restart)
if children_exist {
.any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline.timeline_id()));
// XXX this can happen because of race conditions with branch creation.
// We already deleted the remote layer files, so it's probably best to panic.
if children_exist || offloaded_children_exist {
panic!("Timeline grew children while we removed layer files");
}
timelines
.remove(&timeline.timeline_id)
.expect("timeline that we were deleting was concurrently removed from 'timelines' map");
match timeline {
TimelineOrOffloaded::Timeline(timeline) => {
timelines.remove(&timeline.timeline_id).expect(
"timeline that we were deleting was concurrently removed from 'timelines' map",
);
}
TimelineOrOffloaded::Offloaded(timeline) => {
timelines_offloaded
.remove(&timeline.timeline_id)
.expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map");
}
}
drop(timelines_offloaded);
drop(timelines);
Ok(())
@@ -207,9 +223,11 @@ impl DeleteTimelineFlow {
guard.mark_in_progress()?;
// Now that the Timeline is in Stopping state, request all the related tasks to shut down.
timeline.shutdown(super::ShutdownMode::Hard).await;
if let TimelineOrOffloaded::Timeline(timeline) = &timeline {
timeline.shutdown(super::ShutdownMode::Hard).await;
}
tenant.gc_block.before_delete(&timeline);
tenant.gc_block.before_delete(&timeline.timeline_id());
fail::fail_point!("timeline-delete-before-index-deleted-at", |_| {
Err(anyhow::anyhow!(
@@ -285,15 +303,16 @@ impl DeleteTimelineFlow {
guard.mark_in_progress()?;
let timeline = TimelineOrOffloaded::Timeline(timeline);
Self::schedule_background(guard, tenant.conf, tenant, timeline);
Ok(())
}
fn prepare(
pub(super) fn prepare(
tenant: &Tenant,
timeline_id: TimelineId,
) -> Result<(Arc<Timeline>, DeletionGuard), DeleteTimelineError> {
) -> Result<(TimelineOrOffloaded, DeletionGuard), DeleteTimelineError> {
// Note the interaction between this guard and deletion guard.
// Here we attempt to lock deletion guard when we're holding a lock on timelines.
// This is important because when you take into account `remove_timeline_from_tenant`
@@ -307,8 +326,14 @@ impl DeleteTimelineFlow {
let timelines = tenant.timelines.lock().unwrap();
let timeline = match timelines.get(&timeline_id) {
Some(t) => t,
None => return Err(DeleteTimelineError::NotFound),
Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)),
None => {
let offloaded_timelines = tenant.timelines_offloaded.lock().unwrap();
match offloaded_timelines.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)),
None => return Err(DeleteTimelineError::NotFound),
}
}
};
// Ensure that there are no child timelines **attached to that pageserver**,
@@ -334,30 +359,32 @@ impl DeleteTimelineFlow {
// to remove the timeline from it.
// Whenever two locks are taken in different orders, this can result in a deadlock.
let delete_progress = Arc::clone(&timeline.delete_progress);
let delete_progress = Arc::clone(timeline.delete_progress());
let delete_lock_guard = match delete_progress.try_lock_owned() {
Ok(guard) => DeletionGuard(guard),
Err(_) => {
// Unfortunately, if the lock fails, the Arc is consumed.
return Err(DeleteTimelineError::AlreadyInProgress(Arc::clone(
&timeline.delete_progress,
timeline.delete_progress(),
)));
}
};
timeline.set_state(TimelineState::Stopping);
if let TimelineOrOffloaded::Timeline(timeline) = &timeline {
timeline.set_state(TimelineState::Stopping);
}
Ok((Arc::clone(timeline), delete_lock_guard))
Ok((timeline, delete_lock_guard))
}
fn schedule_background(
guard: DeletionGuard,
conf: &'static PageServerConf,
tenant: Arc<Tenant>,
timeline: Arc<Timeline>,
timeline: TimelineOrOffloaded,
) {
let tenant_shard_id = timeline.tenant_shard_id;
let timeline_id = timeline.timeline_id;
let tenant_shard_id = timeline.tenant_shard_id();
let timeline_id = timeline.timeline_id();
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
@@ -368,7 +395,9 @@ impl DeleteTimelineFlow {
async move {
if let Err(err) = Self::background(guard, conf, &tenant, &timeline).await {
error!("Error: {err:#}");
timeline.set_broken(format!("{err:#}"))
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
timeline.set_broken(format!("{err:#}"))
}
};
Ok(())
}
@@ -380,15 +409,19 @@ impl DeleteTimelineFlow {
mut guard: DeletionGuard,
conf: &PageServerConf,
tenant: &Tenant,
timeline: &Timeline,
timeline: &TimelineOrOffloaded,
) -> Result<(), DeleteTimelineError> {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await?;
// Offloaded timelines have no local state
// TODO: once we persist offloaded information, delete the timeline from there, too
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await?;
}
delete_remote_layers_and_index(timeline).await?;
pausable_failpoint!("in_progress_delete");
remove_timeline_from_tenant(tenant, timeline, &guard).await?;
remove_maybe_offloaded_timeline_from_tenant(tenant, timeline, &guard).await?;
*guard = Self::Finished;
@@ -400,7 +433,7 @@ impl DeleteTimelineFlow {
}
}
struct DeletionGuard(OwnedMutexGuard<DeleteTimelineFlow>);
pub(super) struct DeletionGuard(OwnedMutexGuard<DeleteTimelineFlow>);
impl Deref for DeletionGuard {
type Target = DeleteTimelineFlow;

View File

@@ -0,0 +1,69 @@
use std::sync::Arc;
use crate::tenant::{OffloadedTimeline, Tenant, TimelineOrOffloaded};
use super::{
delete::{delete_local_timeline_directory, DeleteTimelineFlow, DeletionGuard},
Timeline,
};
pub(crate) async fn offload_timeline(
tenant: &Tenant,
timeline: &Arc<Timeline>,
) -> anyhow::Result<()> {
tracing::info!("offloading archived timeline");
let (timeline, guard) = DeleteTimelineFlow::prepare(tenant, timeline.timeline_id)?;
let TimelineOrOffloaded::Timeline(timeline) = timeline else {
tracing::error!("timeline already offloaded, but given timeline object");
return Ok(());
};
// TODO extend guard mechanism above with method
// to make deletions possible while offloading is in progress
// TODO mark timeline as offloaded in S3
let conf = &tenant.conf;
delete_local_timeline_directory(conf, tenant.tenant_shard_id, &timeline).await?;
remove_timeline_from_tenant(tenant, &timeline, &guard).await?;
{
let mut offloaded_timelines = tenant.timelines_offloaded.lock().unwrap();
offloaded_timelines.insert(
timeline.timeline_id,
Arc::new(OffloadedTimeline::from_timeline(&timeline)),
);
}
Ok(())
}
/// It is important that this gets called when DeletionGuard is being held.
/// For more context see comments in [`DeleteTimelineFlow::prepare`]
async fn remove_timeline_from_tenant(
tenant: &Tenant,
timeline: &Timeline,
_: &DeletionGuard, // using it as a witness
) -> anyhow::Result<()> {
// Remove the timeline from the map.
let mut timelines = tenant.timelines.lock().unwrap();
let children_exist = timelines
.iter()
.any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline.timeline_id));
// XXX this can happen because `branch_timeline` doesn't check `TimelineState::Stopping`.
// We already deleted the layer files, so it's probably best to panic.
// (Ideally, above remove_dir_all is atomic so we don't see this timeline after a restart)
if children_exist {
panic!("Timeline grew children while we removed layer files");
}
timelines
.remove(&timeline.timeline_id)
.expect("timeline that we were deleting was concurrently removed from 'timelines' map");
drop(timelines);
Ok(())
}
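
Editorial sketch (not part of the commit): the deletion and offload changes above lean on one enum, `TimelineOrOffloaded`, that wraps either a loaded timeline or its offloaded stub and exposes shared accessors such as `timeline_id()`. The minimal version below illustrates that pattern under assumptions: `TimelineId`, `Timeline`, and `OffloadedTimeline` are hypothetical stand-ins reduced to a single field, and `shutdown_if_loaded` is an invented helper that only mirrors the `if let TimelineOrOffloaded::Timeline(..)` guard used in the diff.

```rust
use std::sync::Arc;

// Hypothetical stand-ins for the pageserver types; only the ID field matters here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimelineId(pub u64);

pub struct Timeline {
    pub timeline_id: TimelineId,
}

pub struct OffloadedTimeline {
    pub timeline_id: TimelineId,
}

/// Either a fully loaded timeline or the stub kept after offloading.
pub enum TimelineOrOffloaded {
    Timeline(Arc<Timeline>),
    Offloaded(Arc<OffloadedTimeline>),
}

impl TimelineOrOffloaded {
    /// Shared accessor, so callers do not have to match on the variant.
    pub fn timeline_id(&self) -> TimelineId {
        match self {
            TimelineOrOffloaded::Timeline(t) => t.timeline_id,
            TimelineOrOffloaded::Offloaded(t) => t.timeline_id,
        }
    }
}

/// Invented helper: returns true when the loaded-only step would run.
/// In the diff, this guard wraps calls such as `timeline.shutdown(..)`.
pub fn shutdown_if_loaded(timeline: &TimelineOrOffloaded) -> bool {
    if let TimelineOrOffloaded::Timeline(_loaded) = timeline {
        // Only a loaded timeline has tasks and local state to tear down.
        true
    } else {
        // Offloaded timelines have no local state, so the step is skipped.
        false
    }
}
```

The accessor keeps call sites such as `tenant.gc_block.before_delete(&timeline.timeline_id())` independent of the variant, while variant-specific work stays behind an explicit `if let`.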

View File

@@ -194,8 +194,6 @@ pub(crate) struct ChunkedVectoredReadBuilder {
/// Start offset and metadata for each blob in this read
blobs_at: VecMap<u64, BlobMeta>,
max_read_size: Option<usize>,
/// Chunk size reads are coalesced into.
chunk_size: usize,
}
/// Computes x / d rounded up.
@@ -204,6 +202,7 @@ fn div_round_up(x: usize, d: usize) -> usize {
}
impl ChunkedVectoredReadBuilder {
const CHUNK_SIZE: usize = virtual_file::get_io_buffer_alignment();
/// Start building a new vectored read.
///
/// Note that by design, this does not check against reading more than `max_read_size` to
@@ -214,21 +213,19 @@ impl ChunkedVectoredReadBuilder {
end_offset: u64,
meta: BlobMeta,
max_read_size: Option<usize>,
chunk_size: usize,
) -> Self {
let mut blobs_at = VecMap::default();
blobs_at
.append(start_offset, meta)
.expect("First insertion always succeeds");
let start_blk_no = start_offset as usize / chunk_size;
let end_blk_no = div_round_up(end_offset as usize, chunk_size);
let start_blk_no = start_offset as usize / Self::CHUNK_SIZE;
let end_blk_no = div_round_up(end_offset as usize, Self::CHUNK_SIZE);
Self {
start_blk_no,
end_blk_no,
blobs_at,
max_read_size,
chunk_size,
}
}
@@ -237,18 +234,12 @@ impl ChunkedVectoredReadBuilder {
end_offset: u64,
meta: BlobMeta,
max_read_size: usize,
align: usize,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), align)
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size))
}
pub(crate) fn new_streaming(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
align: usize,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, None, align)
pub(crate) fn new_streaming(start_offset: u64, end_offset: u64, meta: BlobMeta) -> Self {
Self::new_impl(start_offset, end_offset, meta, None)
}
/// Attempts to extend the current read with a new blob if the new blob resides in the same or the immediate next chunk.
@@ -256,12 +247,12 @@ impl ChunkedVectoredReadBuilder {
/// The resulting size also must be below the max read size.
pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
tracing::trace!(start, end, "trying to extend");
let start_blk_no = start as usize / self.chunk_size;
let end_blk_no = div_round_up(end as usize, self.chunk_size);
let start_blk_no = start as usize / Self::CHUNK_SIZE;
let end_blk_no = div_round_up(end as usize, Self::CHUNK_SIZE);
let not_limited_by_max_read_size = {
if let Some(max_read_size) = self.max_read_size {
let coalesced_size = (end_blk_no - self.start_blk_no) * self.chunk_size;
let coalesced_size = (end_blk_no - self.start_blk_no) * Self::CHUNK_SIZE;
coalesced_size <= max_read_size
} else {
true
@@ -292,12 +283,12 @@ impl ChunkedVectoredReadBuilder {
}
pub(crate) fn size(&self) -> usize {
(self.end_blk_no - self.start_blk_no) * self.chunk_size
(self.end_blk_no - self.start_blk_no) * Self::CHUNK_SIZE
}
pub(crate) fn build(self) -> VectoredRead {
let start = (self.start_blk_no * self.chunk_size) as u64;
let end = (self.end_blk_no * self.chunk_size) as u64;
let start = (self.start_blk_no * Self::CHUNK_SIZE) as u64;
let end = (self.end_blk_no * Self::CHUNK_SIZE) as u64;
VectoredRead {
start,
end,
@@ -328,18 +319,14 @@ pub struct VectoredReadPlanner {
prev: Option<(Key, Lsn, u64, BlobFlag)>,
max_read_size: usize,
align: usize,
}
impl VectoredReadPlanner {
pub fn new(max_read_size: usize) -> Self {
let align = virtual_file::get_io_buffer_alignment();
Self {
blobs: BTreeMap::new(),
prev: None,
max_read_size,
align,
}
}
@@ -418,7 +405,6 @@ impl VectoredReadPlanner {
end_offset,
BlobMeta { key, lsn },
self.max_read_size,
self.align,
);
let prev_read_builder = current_read_builder.replace(next_read_builder);
@@ -472,13 +458,13 @@ impl<'a> VectoredBlobReader<'a> {
);
if cfg!(debug_assertions) {
let align = virtual_file::get_io_buffer_alignment() as u64;
const ALIGN: u64 = virtual_file::get_io_buffer_alignment() as u64;
debug_assert_eq!(
read.start % align,
read.start % ALIGN,
0,
"Read start at {} does not satisfy the required io buffer alignment ({} bytes)",
read.start,
align
ALIGN
);
}
@@ -553,22 +539,18 @@ pub struct StreamingVectoredReadPlanner {
max_cnt: usize,
/// Size of the current batch
cnt: usize,
align: usize,
}
impl StreamingVectoredReadPlanner {
pub fn new(max_read_size: u64, max_cnt: usize) -> Self {
assert!(max_cnt > 0);
assert!(max_read_size > 0);
let align = virtual_file::get_io_buffer_alignment();
Self {
read_builder: None,
prev: None,
max_cnt,
max_read_size,
cnt: 0,
align,
}
}
@@ -621,7 +603,6 @@ impl StreamingVectoredReadPlanner {
start_offset,
end_offset,
BlobMeta { key, lsn },
self.align,
))
};
}
@@ -656,9 +637,9 @@ mod tests {
use super::*;
fn validate_read(read: &VectoredRead, offset_range: &[(Key, Lsn, u64, BlobFlag)]) {
let align = virtual_file::get_io_buffer_alignment() as u64;
assert_eq!(read.start % align, 0);
assert_eq!(read.start / align, offset_range.first().unwrap().2 / align);
const ALIGN: u64 = virtual_file::get_io_buffer_alignment() as u64;
assert_eq!(read.start % ALIGN, 0);
assert_eq!(read.start / ALIGN, offset_range.first().unwrap().2 / ALIGN);
let expected_offsets_in_read: Vec<_> = offset_range.iter().map(|o| o.2).collect();
@@ -676,32 +657,27 @@ mod tests {
fn planner_chunked_coalesce_all_test() {
use crate::virtual_file;
let chunk_size = virtual_file::get_io_buffer_alignment() as u64;
const CHUNK_SIZE: u64 = virtual_file::get_io_buffer_alignment() as u64;
// The test explicitly does not check chunk size < 512
if chunk_size < 512 {
return;
}
let max_read_size = chunk_size as usize * 8;
let max_read_size = CHUNK_SIZE as usize * 8;
let key = Key::MIN;
let lsn = Lsn(0);
let blob_descriptions = [
(key, lsn, chunk_size / 8, BlobFlag::None), // Read 1 BEGIN
(key, lsn, chunk_size / 4, BlobFlag::Ignore), // Gap
(key, lsn, chunk_size / 2, BlobFlag::None),
(key, lsn, chunk_size - 2, BlobFlag::Ignore), // Gap
(key, lsn, chunk_size, BlobFlag::None),
(key, lsn, chunk_size * 2 - 1, BlobFlag::None),
(key, lsn, chunk_size * 2 + 1, BlobFlag::Ignore), // Gap
(key, lsn, chunk_size * 3 + 1, BlobFlag::None),
(key, lsn, chunk_size * 5 + 1, BlobFlag::None),
(key, lsn, chunk_size * 6 + 1, BlobFlag::Ignore), // skipped chunk size, but not a chunk: should coalesce.
(key, lsn, chunk_size * 7 + 1, BlobFlag::None),
(key, lsn, chunk_size * 8, BlobFlag::None), // Read 2 BEGIN (b/c max_read_size)
(key, lsn, chunk_size * 9, BlobFlag::Ignore), // ==== skipped a chunk
(key, lsn, chunk_size * 10, BlobFlag::None), // Read 3 BEGIN (cannot coalesce)
(key, lsn, CHUNK_SIZE / 8, BlobFlag::None), // Read 1 BEGIN
(key, lsn, CHUNK_SIZE / 4, BlobFlag::Ignore), // Gap
(key, lsn, CHUNK_SIZE / 2, BlobFlag::None),
(key, lsn, CHUNK_SIZE - 2, BlobFlag::Ignore), // Gap
(key, lsn, CHUNK_SIZE, BlobFlag::None),
(key, lsn, CHUNK_SIZE * 2 - 1, BlobFlag::None),
(key, lsn, CHUNK_SIZE * 2 + 1, BlobFlag::Ignore), // Gap
(key, lsn, CHUNK_SIZE * 3 + 1, BlobFlag::None),
(key, lsn, CHUNK_SIZE * 5 + 1, BlobFlag::None),
(key, lsn, CHUNK_SIZE * 6 + 1, BlobFlag::Ignore), // skipped chunk size, but not a chunk: should coalesce.
(key, lsn, CHUNK_SIZE * 7 + 1, BlobFlag::None),
(key, lsn, CHUNK_SIZE * 8, BlobFlag::None), // Read 2 BEGIN (b/c max_read_size)
(key, lsn, CHUNK_SIZE * 9, BlobFlag::Ignore), // ==== skipped a chunk
(key, lsn, CHUNK_SIZE * 10, BlobFlag::None), // Read 3 BEGIN (cannot coalesce)
];
let ranges = [
@@ -780,19 +756,19 @@ mod tests {
#[test]
fn planner_replacement_test() {
let chunk_size = virtual_file::get_io_buffer_alignment() as u64;
let max_read_size = 128 * chunk_size as usize;
const CHUNK_SIZE: u64 = virtual_file::get_io_buffer_alignment() as u64;
let max_read_size = 128 * CHUNK_SIZE as usize;
let first_key = Key::MIN;
let second_key = first_key.next();
let lsn = Lsn(0);
let blob_descriptions = vec![
(first_key, lsn, 0, BlobFlag::None), // First in read 1
(first_key, lsn, chunk_size, BlobFlag::None), // Last in read 1
(second_key, lsn, 2 * chunk_size, BlobFlag::ReplaceAll),
(second_key, lsn, 3 * chunk_size, BlobFlag::None),
(second_key, lsn, 4 * chunk_size, BlobFlag::ReplaceAll), // First in read 2
(second_key, lsn, 5 * chunk_size, BlobFlag::None), // Last in read 2
(first_key, lsn, CHUNK_SIZE, BlobFlag::None), // Last in read 1
(second_key, lsn, 2 * CHUNK_SIZE, BlobFlag::ReplaceAll),
(second_key, lsn, 3 * CHUNK_SIZE, BlobFlag::None),
(second_key, lsn, 4 * CHUNK_SIZE, BlobFlag::ReplaceAll), // First in read 2
(second_key, lsn, 5 * CHUNK_SIZE, BlobFlag::None), // Last in read 2
];
let ranges = [&blob_descriptions[0..2], &blob_descriptions[4..]];
@@ -802,7 +778,7 @@ mod tests {
planner.handle(key, lsn, offset, flag);
}
planner.handle_range_end(6 * chunk_size);
planner.handle_range_end(6 * CHUNK_SIZE);
let reads = planner.finish();
assert_eq!(reads.len(), 2);
@@ -947,7 +923,6 @@ mod tests {
let reserved_bytes = blobs.iter().map(|bl| bl.len()).max().unwrap() * 2 + 16;
let mut buf = BytesMut::with_capacity(reserved_bytes);
let align = virtual_file::get_io_buffer_alignment();
let vectored_blob_reader = VectoredBlobReader::new(&file);
let meta = BlobMeta {
key: Key::MIN,
@@ -959,8 +934,7 @@ mod tests {
if idx + 1 == offsets.len() {
continue;
}
let read_builder =
ChunkedVectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, align);
let read_builder = ChunkedVectoredReadBuilder::new(*offset, *end, meta, 16 * 4096);
let read = read_builder.build();
let result = vectored_blob_reader.read_blobs(&read, buf, &ctx).await?;
assert_eq!(result.blobs.len(), 1);
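
Editorial sketch (not part of the commit): the hunks above replace the per-instance `chunk_size` field with an associated constant derived from a `const fn`. The block below reproduces only the block-number arithmetic to show how a blob's byte range widens to chunk boundaries. The value 512 is an assumption for illustration (the actual `DEFAULT_IO_BUFFER_ALIGNMENT` is not shown in this diff), `aligned_range` is a hypothetical helper, and the body of `div_round_up` is a conventional rounding-up division rather than a copy of the original.

```rust
/// Assumed alignment for illustration; the diff returns
/// DEFAULT_IO_BUFFER_ALIGNMENT, whose value is not shown here.
const fn get_io_buffer_alignment() -> usize {
    512
}

/// Because the getter is a `const fn`, it can initialize a constant.
const CHUNK_SIZE: usize = get_io_buffer_alignment();

/// Computes x / d rounded up.
fn div_round_up(x: usize, d: usize) -> usize {
    (x + (d - 1)) / d
}

/// Hypothetical helper: returns the chunk-aligned byte range that covers
/// the blob stored at `start_offset..end_offset`.
fn aligned_range(start_offset: u64, end_offset: u64) -> (u64, u64) {
    // Round the start down and the end up to whole chunks.
    let start_blk_no = start_offset as usize / CHUNK_SIZE;
    let end_blk_no = div_round_up(end_offset as usize, CHUNK_SIZE);
    (
        (start_blk_no * CHUNK_SIZE) as u64,
        (end_blk_no * CHUNK_SIZE) as u64,
    )
}
```

With the assumed 512-byte chunk, a blob at bytes 600..1100 is read as 512..1536: the start rounds down to chunk 1 and the end rounds up to chunk 3.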

View File

@@ -23,10 +23,12 @@ use pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
use pageserver_api::shard::TenantShardId;
use std::fs::File;
use std::io::{Error, ErrorKind, Seek, SeekFrom};
#[cfg(target_os = "linux")]
use std::os::unix::fs::OpenOptionsExt;
use tokio_epoll_uring::{BoundedBuf, IoBuf, IoBufMut, Slice};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::time::Instant;
@@ -38,7 +40,7 @@ pub use io_engine::FeatureTestResult as IoEngineFeatureTestResult;
mod metadata;
mod open_options;
use self::owned_buffers_io::write::OwnedAsyncWriter;
pub(crate) use api::DirectIoMode;
pub(crate) use api::IoMode;
pub(crate) use io_engine::IoEngineKind;
pub(crate) use metadata::Metadata;
pub(crate) use open_options::*;
@@ -61,6 +63,171 @@ pub(crate) mod owned_buffers_io {
}
}
#[derive(Debug)]
pub struct VirtualFile {
inner: VirtualFileInner,
_mode: IoMode,
}
impl VirtualFile {
/// Open a file in read-only mode. Like File::open.
pub async fn open<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let inner = VirtualFileInner::open(path, ctx).await?;
Ok(VirtualFile {
inner,
_mode: IoMode::Buffered,
})
}
/// Open a file in read-only mode. Like File::open.
///
/// `O_DIRECT` will be enabled based on `virtual_file_io_mode`.
pub async fn open_v2<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
Self::open_with_options_v2(path.as_ref(), OpenOptions::new().read(true), ctx).await
}
pub async fn create<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let inner = VirtualFileInner::create(path, ctx).await?;
Ok(VirtualFile {
inner,
_mode: IoMode::Buffered,
})
}
pub async fn create_v2<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
VirtualFile::open_with_options_v2(
path.as_ref(),
OpenOptions::new().write(true).create(true).truncate(true),
ctx,
)
.await
}
pub async fn open_with_options<P: AsRef<Utf8Path>>(
path: P,
open_options: &OpenOptions,
ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
) -> Result<Self, std::io::Error> {
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(VirtualFile {
inner,
_mode: IoMode::Buffered,
})
}
pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
path: P,
open_options: &OpenOptions,
ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
) -> Result<Self, std::io::Error> {
let file = match get_io_mode() {
IoMode::Buffered => {
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
VirtualFile {
inner,
_mode: IoMode::Buffered,
}
}
#[cfg(target_os = "linux")]
IoMode::Direct => {
let inner = VirtualFileInner::open_with_options(
path,
open_options.clone().custom_flags(nix::libc::O_DIRECT),
ctx,
)
.await?;
VirtualFile {
inner,
_mode: IoMode::Direct,
}
}
};
Ok(file)
}
pub fn path(&self) -> &Utf8Path {
self.inner.path.as_path()
}
pub async fn crashsafe_overwrite<B: BoundedBuf<Buf = Buf> + Send, Buf: IoBuf + Send>(
final_path: Utf8PathBuf,
tmp_path: Utf8PathBuf,
content: B,
) -> std::io::Result<()> {
VirtualFileInner::crashsafe_overwrite(final_path, tmp_path, content).await
}
pub async fn sync_all(&self) -> Result<(), Error> {
self.inner.sync_all().await
}
pub async fn sync_data(&self) -> Result<(), Error> {
self.inner.sync_data().await
}
pub async fn metadata(&self) -> Result<Metadata, Error> {
self.inner.metadata().await
}
pub fn remove(self) {
self.inner.remove();
}
pub async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
self.inner.seek(pos).await
}
pub async fn read_exact_at<Buf>(
&self,
slice: Slice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> Result<Slice<Buf>, Error>
where
Buf: IoBufMut + Send,
{
self.inner.read_exact_at(slice, offset, ctx).await
}
pub async fn read_exact_at_page(
&self,
page: PageWriteGuard<'static>,
offset: u64,
ctx: &RequestContext,
) -> Result<PageWriteGuard<'static>, Error> {
self.inner.read_exact_at_page(page, offset, ctx).await
}
pub async fn write_all_at<Buf: IoBuf + Send>(
&self,
buf: FullSlice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), Error>) {
self.inner.write_all_at(buf, offset, ctx).await
}
pub async fn write_all<Buf: IoBuf + Send>(
&mut self,
buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<usize, Error>) {
self.inner.write_all(buf, ctx).await
}
}
///
/// A virtual file descriptor. You can use this just like std::fs::File, but internally
/// the underlying file is closed if the system is low on file descriptors,
@@ -77,7 +244,7 @@ pub(crate) mod owned_buffers_io {
/// 'tag' field is used to detect whether the handle still is valid or not.
///
#[derive(Debug)]
pub struct VirtualFile {
pub struct VirtualFileInner {
/// Lazy handle to the global file descriptor cache. The slot that this points to
/// might contain our File, or it may be empty, or it may contain a File that
/// belongs to a different VirtualFile.
@@ -350,12 +517,12 @@ macro_rules! with_file {
}};
}
impl VirtualFile {
impl VirtualFileInner {
/// Open a file in read-only mode. Like File::open.
pub async fn open<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<VirtualFile, std::io::Error> {
) -> Result<VirtualFileInner, std::io::Error> {
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await
}
@@ -364,7 +531,7 @@ impl VirtualFile {
pub async fn create<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<VirtualFile, std::io::Error> {
) -> Result<VirtualFileInner, std::io::Error> {
Self::open_with_options(
path.as_ref(),
OpenOptions::new().write(true).create(true).truncate(true),
@@ -382,7 +549,7 @@ impl VirtualFile {
path: P,
open_options: &OpenOptions,
_ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
) -> Result<VirtualFile, std::io::Error> {
) -> Result<VirtualFileInner, std::io::Error> {
let path_ref = path.as_ref();
let path_str = path_ref.to_string();
let parts = path_str.split('/').collect::<Vec<&str>>();
@@ -423,7 +590,7 @@ impl VirtualFile {
reopen_options.create_new(false);
reopen_options.truncate(false);
let vfile = VirtualFile {
let vfile = VirtualFileInner {
handle: RwLock::new(handle),
pos: 0,
path: path_ref.to_path_buf(),
@@ -1034,6 +1201,21 @@ impl tokio_epoll_uring::IoFd for FileGuard {
#[cfg(test)]
impl VirtualFile {
pub(crate) async fn read_blk(
&self,
blknum: u32,
ctx: &RequestContext,
) -> Result<crate::tenant::block_io::BlockLease<'_>, std::io::Error> {
self.inner.read_blk(blknum, ctx).await
}
async fn read_to_end(&mut self, buf: &mut Vec<u8>, ctx: &RequestContext) -> Result<(), Error> {
self.inner.read_to_end(buf, ctx).await
}
}
#[cfg(test)]
impl VirtualFileInner {
pub(crate) async fn read_blk(
&self,
blknum: u32,
@@ -1067,7 +1249,7 @@ impl VirtualFile {
}
}
impl Drop for VirtualFile {
impl Drop for VirtualFileInner {
/// If a VirtualFile is dropped, close the underlying file if it was open.
fn drop(&mut self) {
let handle = self.handle.get_mut();
@@ -1143,15 +1325,10 @@ impl OpenFiles {
/// server startup.
///
#[cfg(not(test))]
pub fn init(num_slots: usize, engine: IoEngineKind, io_buffer_alignment: usize) {
pub fn init(num_slots: usize, engine: IoEngineKind) {
if OPEN_FILES.set(OpenFiles::new(num_slots)).is_err() {
panic!("virtual_file::init called twice");
}
if set_io_buffer_alignment(io_buffer_alignment).is_err() {
panic!(
"IO buffer alignment needs to be a power of two and greater than 512, got {io_buffer_alignment}"
);
}
io_engine::init(engine);
crate::metrics::virtual_file_descriptor_cache::SIZE_MAX.set(num_slots as u64);
}
@@ -1175,47 +1352,20 @@ fn get_open_files() -> &'static OpenFiles {
}
}
static IO_BUFFER_ALIGNMENT: AtomicUsize = AtomicUsize::new(DEFAULT_IO_BUFFER_ALIGNMENT);
/// Returns true if the alignment is a power of two and is greater or equal to 512.
fn is_valid_io_buffer_alignment(align: usize) -> bool {
align.is_power_of_two() && align >= 512
}
/// Sets IO buffer alignment requirement. Returns error if the alignment requirement is
/// not a power of two or less than 512 bytes.
#[allow(unused)]
pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
if is_valid_io_buffer_alignment(align) {
IO_BUFFER_ALIGNMENT.store(align, std::sync::atomic::Ordering::Relaxed);
Ok(())
} else {
Err(align)
}
}
/// Gets the io buffer alignment.
///
/// This function should be used for getting the actual alignment value to use.
pub(crate) fn get_io_buffer_alignment() -> usize {
let align = IO_BUFFER_ALIGNMENT.load(std::sync::atomic::Ordering::Relaxed);
if cfg!(test) {
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT";
if let Some(test_align) = utils::env::var(env_var_name) {
if is_valid_io_buffer_alignment(test_align) {
test_align
} else {
panic!("IO buffer alignment needs to be a power of two and greater than 512, got {test_align}");
}
} else {
align
}
} else {
align
}
pub(crate) const fn get_io_buffer_alignment() -> usize {
DEFAULT_IO_BUFFER_ALIGNMENT
}
static IO_MODE: AtomicU8 = AtomicU8::new(IoMode::preferred() as u8);
pub(crate) fn set_io_mode(mode: IoMode) {
IO_MODE.store(mode as u8, std::sync::atomic::Ordering::Relaxed);
}
pub(crate) fn get_io_mode() -> IoMode {
IoMode::try_from(IO_MODE.load(Ordering::Relaxed)).unwrap()
}
#[cfg(test)]
mod tests {
use crate::context::DownloadBehavior;
@@ -1524,7 +1674,7 @@ mod tests {
// Open the file many times.
let mut files = Vec::new();
for _ in 0..VIRTUAL_FILES {
let f = VirtualFile::open_with_options(
let f = VirtualFileInner::open_with_options(
&test_file_path,
OpenOptions::new().read(true),
&ctx,
@@ -1576,7 +1726,7 @@ mod tests {
let path = testdir.join("myfile");
let tmp_path = testdir.join("myfile.tmp");
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
.await
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
@@ -1585,7 +1735,7 @@ mod tests {
assert!(!tmp_path.exists());
drop(file);
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec())
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec())
.await
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
@@ -1608,7 +1758,7 @@ mod tests {
std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap();
assert!(tmp_path.exists());
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
.await
.unwrap();


@@ -146,6 +146,8 @@ ConstructDeltaMessage()
if (RootTable.role_table)
{
JsonbValue roles;
HASH_SEQ_STATUS status;
RoleEntry *entry;
roles.type = jbvString;
roles.val.string.val = "roles";
@@ -153,9 +155,6 @@ ConstructDeltaMessage()
pushJsonbValue(&state, WJB_KEY, &roles);
pushJsonbValue(&state, WJB_BEGIN_ARRAY, NULL);
HASH_SEQ_STATUS status;
RoleEntry *entry;
hash_seq_init(&status, RootTable.role_table);
while ((entry = hash_seq_search(&status)) != NULL)
{
@@ -190,10 +189,12 @@ ConstructDeltaMessage()
}
pushJsonbValue(&state, WJB_END_ARRAY, NULL);
}
JsonbValue *result = pushJsonbValue(&state, WJB_END_OBJECT, NULL);
Jsonb *jsonb = JsonbValueToJsonb(result);
{
JsonbValue *result = pushJsonbValue(&state, WJB_END_OBJECT, NULL);
Jsonb *jsonb = JsonbValueToJsonb(result);
return JsonbToCString(NULL, &jsonb->root, 0 /* estimated_len */ );
return JsonbToCString(NULL, &jsonb->root, 0 /* estimated_len */ );
}
}
#define ERROR_SIZE 1024
@@ -272,32 +273,28 @@ SendDeltasToControlPlane()
curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, ErrorWriteCallback);
}
char *message = ConstructDeltaMessage();
ErrorString str;
str.size = 0;
curl_easy_setopt(handle, CURLOPT_POSTFIELDS, message);
curl_easy_setopt(handle, CURLOPT_WRITEDATA, &str);
const int num_retries = 5;
CURLcode curl_status;
for (int i = 0; i < num_retries; i++)
{
if ((curl_status = curl_easy_perform(handle)) == 0)
break;
elog(LOG, "Curl request failed on attempt %d: %s", i, CurlErrorBuf);
pg_usleep(1000 * 1000);
}
if (curl_status != CURLE_OK)
{
elog(ERROR, "Failed to perform curl request: %s", CurlErrorBuf);
}
else
{
char *message = ConstructDeltaMessage();
ErrorString str;
const int num_retries = 5;
CURLcode curl_status;
long response_code;
str.size = 0;
curl_easy_setopt(handle, CURLOPT_POSTFIELDS, message);
curl_easy_setopt(handle, CURLOPT_WRITEDATA, &str);
for (int i = 0; i < num_retries; i++)
{
if ((curl_status = curl_easy_perform(handle)) == 0)
break;
elog(LOG, "Curl request failed on attempt %d: %s", i, CurlErrorBuf);
pg_usleep(1000 * 1000);
}
if (curl_status != CURLE_OK)
elog(ERROR, "Failed to perform curl request: %s", CurlErrorBuf);
if (curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &response_code) != CURLE_UNKNOWN_OPTION)
{
if (response_code != 200)
@@ -376,10 +373,11 @@ MergeTable()
if (old_table->db_table)
{
InitDbTableIfNeeded();
DbEntry *entry;
HASH_SEQ_STATUS status;
InitDbTableIfNeeded();
hash_seq_init(&status, old_table->db_table);
while ((entry = hash_seq_search(&status)) != NULL)
{
@@ -421,10 +419,11 @@ MergeTable()
if (old_table->role_table)
{
InitRoleTableIfNeeded();
RoleEntry *entry;
HASH_SEQ_STATUS status;
InitRoleTableIfNeeded();
hash_seq_init(&status, old_table->role_table);
while ((entry = hash_seq_search(&status)) != NULL)
{
@@ -515,9 +514,12 @@ RoleIsNeonSuperuser(const char *role_name)
static void
HandleCreateDb(CreatedbStmt *stmt)
{
InitDbTableIfNeeded();
DefElem *downer = NULL;
ListCell *option;
bool found = false;
DbEntry *entry;
InitDbTableIfNeeded();
foreach(option, stmt->options)
{
@@ -526,13 +528,11 @@ HandleCreateDb(CreatedbStmt *stmt)
if (strcmp(defel->defname, "owner") == 0)
downer = defel;
}
bool found = false;
DbEntry *entry = hash_search(
CurrentDdlTable->db_table,
stmt->dbname,
HASH_ENTER,
&found);
entry = hash_search(CurrentDdlTable->db_table,
stmt->dbname,
HASH_ENTER,
&found);
if (!found)
memset(entry->old_name, 0, sizeof(entry->old_name));
@@ -554,21 +554,24 @@ HandleCreateDb(CreatedbStmt *stmt)
static void
HandleAlterOwner(AlterOwnerStmt *stmt)
{
const char *name;
bool found = false;
DbEntry *entry;
const char *new_owner;
if (stmt->objectType != OBJECT_DATABASE)
return;
InitDbTableIfNeeded();
const char *name = strVal(stmt->object);
bool found = false;
DbEntry *entry = hash_search(
CurrentDdlTable->db_table,
name,
HASH_ENTER,
&found);
name = strVal(stmt->object);
entry = hash_search(CurrentDdlTable->db_table,
name,
HASH_ENTER,
&found);
if (!found)
memset(entry->old_name, 0, sizeof(entry->old_name));
const char *new_owner = get_rolespec_name(stmt->newowner);
new_owner = get_rolespec_name(stmt->newowner);
if (RoleIsNeonSuperuser(new_owner))
elog(ERROR, "can't alter owner to neon_superuser");
entry->owner = get_role_oid(new_owner, false);
@@ -578,21 +581,23 @@ HandleAlterOwner(AlterOwnerStmt *stmt)
static void
HandleDbRename(RenameStmt *stmt)
{
bool found = false;
DbEntry *entry;
DbEntry *entry_for_new_name;
Assert(stmt->renameType == OBJECT_DATABASE);
InitDbTableIfNeeded();
bool found = false;
DbEntry *entry = hash_search(
CurrentDdlTable->db_table,
stmt->subname,
HASH_FIND,
&found);
DbEntry *entry_for_new_name = hash_search(
CurrentDdlTable->db_table,
stmt->newname,
HASH_ENTER,
NULL);
entry = hash_search(CurrentDdlTable->db_table,
stmt->subname,
HASH_FIND,
&found);
entry_for_new_name = hash_search(CurrentDdlTable->db_table,
stmt->newname,
HASH_ENTER,
NULL);
entry_for_new_name->type = Op_Set;
if (found)
{
if (entry->old_name[0] != '\0')
@@ -600,8 +605,7 @@ HandleDbRename(RenameStmt *stmt)
else
strlcpy(entry_for_new_name->old_name, entry->name, NAMEDATALEN);
entry_for_new_name->owner = entry->owner;
hash_search(
CurrentDdlTable->db_table,
hash_search(CurrentDdlTable->db_table,
stmt->subname,
HASH_REMOVE,
NULL);
@@ -616,14 +620,15 @@ HandleDbRename(RenameStmt *stmt)
static void
HandleDropDb(DropdbStmt *stmt)
{
InitDbTableIfNeeded();
bool found = false;
DbEntry *entry = hash_search(
CurrentDdlTable->db_table,
stmt->dbname,
HASH_ENTER,
&found);
DbEntry *entry;
InitDbTableIfNeeded();
entry = hash_search(CurrentDdlTable->db_table,
stmt->dbname,
HASH_ENTER,
&found);
entry->type = Op_Delete;
entry->owner = InvalidOid;
if (!found)
@@ -633,16 +638,14 @@ HandleDropDb(DropdbStmt *stmt)
static void
HandleCreateRole(CreateRoleStmt *stmt)
{
InitRoleTableIfNeeded();
bool found = false;
RoleEntry *entry = hash_search(
CurrentDdlTable->role_table,
stmt->role,
HASH_ENTER,
&found);
DefElem *dpass = NULL;
RoleEntry *entry;
DefElem *dpass;
ListCell *option;
InitRoleTableIfNeeded();
dpass = NULL;
foreach(option, stmt->options)
{
DefElem *defel = lfirst(option);
@@ -650,6 +653,11 @@ HandleCreateRole(CreateRoleStmt *stmt)
if (strcmp(defel->defname, "password") == 0)
dpass = defel;
}
entry = hash_search(CurrentDdlTable->role_table,
stmt->role,
HASH_ENTER,
&found);
if (!found)
memset(entry->old_name, 0, sizeof(entry->old_name));
if (dpass && dpass->arg)
@@ -662,14 +670,18 @@ HandleCreateRole(CreateRoleStmt *stmt)
static void
HandleAlterRole(AlterRoleStmt *stmt)
{
InitRoleTableIfNeeded();
DefElem *dpass = NULL;
ListCell *option;
const char *role_name = stmt->role->rolename;
DefElem *dpass;
ListCell *option;
bool found = false;
RoleEntry *entry;
InitRoleTableIfNeeded();
if (RoleIsNeonSuperuser(role_name) && !superuser())
elog(ERROR, "can't ALTER neon_superuser");
dpass = NULL;
foreach(option, stmt->options)
{
DefElem *defel = lfirst(option);
@@ -680,13 +692,11 @@ HandleAlterRole(AlterRoleStmt *stmt)
/* We only care about updates to the password */
if (!dpass)
return;
bool found = false;
RoleEntry *entry = hash_search(
CurrentDdlTable->role_table,
role_name,
HASH_ENTER,
&found);
entry = hash_search(CurrentDdlTable->role_table,
role_name,
HASH_ENTER,
&found);
if (!found)
memset(entry->old_name, 0, sizeof(entry->old_name));
if (dpass->arg)
@@ -699,20 +709,22 @@ HandleAlterRole(AlterRoleStmt *stmt)
static void
HandleRoleRename(RenameStmt *stmt)
{
InitRoleTableIfNeeded();
Assert(stmt->renameType == OBJECT_ROLE);
bool found = false;
RoleEntry *entry = hash_search(
CurrentDdlTable->role_table,
stmt->subname,
HASH_FIND,
&found);
RoleEntry *entry;
RoleEntry *entry_for_new_name;
RoleEntry *entry_for_new_name = hash_search(
CurrentDdlTable->role_table,
stmt->newname,
HASH_ENTER,
NULL);
Assert(stmt->renameType == OBJECT_ROLE);
InitRoleTableIfNeeded();
entry = hash_search(CurrentDdlTable->role_table,
stmt->subname,
HASH_FIND,
&found);
entry_for_new_name = hash_search(CurrentDdlTable->role_table,
stmt->newname,
HASH_ENTER,
NULL);
entry_for_new_name->type = Op_Set;
if (found)
@@ -738,9 +750,10 @@ HandleRoleRename(RenameStmt *stmt)
static void
HandleDropRole(DropRoleStmt *stmt)
{
InitRoleTableIfNeeded();
ListCell *item;
InitRoleTableIfNeeded();
foreach(item, stmt->roles)
{
RoleSpec *spec = lfirst(item);
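
Note on the retry loop in SendDeltasToControlPlane() above: it attempts the request up to `num_retries` times, stops at the first success, and raises an error only after the last attempt has failed. A minimal standalone sketch of that pattern follows. The names `demo_retry`, `demo_attempt_fn` and `demo_succeed_when_zero` are hypothetical and do not exist in the extension; the one-second sleep and the logging of the original are left out so that the sketch has no PostgreSQL or libcurl dependency.

```c
#include <stdbool.h>

/* Hypothetical callback type: returns true when the attempt succeeded. */
typedef bool (*demo_attempt_fn) (void *arg);

/*
 * Call attempt() up to num_retries times.  Returns the zero-based index
 * of the attempt that succeeded, or -1 if every attempt failed.  The
 * original code logs each failure and sleeps for one second (pg_usleep)
 * before the next iteration; both are omitted here.
 */
static int
demo_retry(demo_attempt_fn attempt, void *arg, int num_retries)
{
	for (int i = 0; i < num_retries; i++)
	{
		if (attempt(arg))
			return i;
	}
	return -1;
}

/* Example operation: fails until the counter behind arg reaches zero. */
static bool
demo_succeed_when_zero(void *arg)
{
	int		   *remaining = arg;

	if (*remaining == 0)
		return true;
	(*remaining)--;
	return false;
}
```

Returning the attempt index rather than a plain boolean lets a caller tell "succeeded after retries" apart from "succeeded immediately", which roughly corresponds to the per-attempt log line in the original.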


@@ -170,12 +170,14 @@ lfc_disable(char const *op)
if (lfc_desc > 0)
{
int rc;
/*
* If the reason of error is ENOSPC, then truncation of file may
* help to reclaim some space
*/
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_TRUNCATE);
int rc = ftruncate(lfc_desc, 0);
rc = ftruncate(lfc_desc, 0);
pgstat_report_wait_end();
if (rc < 0)
@@ -616,7 +618,7 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
*/
if (entry->bitmap[chunk_offs >> 5] == 0)
{
bool has_remaining_pages;
bool has_remaining_pages = false;
for (int i = 0; i < CHUNK_BITMAP_SIZE; i++)
{
@@ -666,7 +668,6 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
BufferTag tag;
FileCacheEntry *entry;
ssize_t rc;
bool result = true;
uint32 hash;
uint64 generation;
uint32 entry_offset;
@@ -925,10 +926,10 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
/* We can reuse a hole that was left behind when the LFC was shrunk previously */
FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->holes));
uint32 offset = hole->offset;
bool found;
bool hole_found;
hash_search_with_hash_value(lfc_hash, &hole->key, hole->hash, HASH_REMOVE, &found);
CriticalAssert(found);
hash_search_with_hash_value(lfc_hash, &hole->key, hole->hash, HASH_REMOVE, &hole_found);
CriticalAssert(hole_found);
lfc_ctl->used += 1;
entry->offset = offset; /* reuse the hole */
@@ -1004,7 +1005,7 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
Datum result;
HeapTuple tuple;
char const *key;
uint64 value;
uint64 value = 0;
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];


@@ -116,8 +116,6 @@ addSHLL(HyperLogLogState *cState, uint32 hash)
{
uint8 count;
uint32 index;
size_t i;
size_t j;
TimestampTz now = GetCurrentTimestamp();
/* Use the first "k" (registerWidth) bits as a zero based index */


@@ -89,7 +89,6 @@ typedef struct
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
static void walproposer_shmem_request(void);
#endif
static shmem_startup_hook_type prev_shmem_startup_hook;
static PagestoreShmemState *pagestore_shared;
@@ -441,8 +440,8 @@ pageserver_connect(shardno_t shard_no, int elevel)
return false;
}
shard->state = PS_Connecting_Startup;
/* fallthrough */
}
/* FALLTHROUGH */
case PS_Connecting_Startup:
{
char *pagestream_query;
@@ -453,8 +452,6 @@ pageserver_connect(shardno_t shard_no, int elevel)
do
{
WaitEvent event;
switch (poll_result)
{
default: /* unknown/unused states are handled as a failed connection */
@@ -585,8 +582,8 @@ pageserver_connect(shardno_t shard_no, int elevel)
}
shard->state = PS_Connecting_PageStream;
/* fallthrough */
}
/* FALLTHROUGH */
case PS_Connecting_PageStream:
{
neon_shard_log(shard_no, DEBUG5, "Connection state: Connecting_PageStream");
@@ -631,8 +628,8 @@ pageserver_connect(shardno_t shard_no, int elevel)
}
shard->state = PS_Connected;
/* fallthrough */
}
/* FALLTHROUGH */
case PS_Connected:
/*
* We successfully connected. Future connections to this PageServer


@@ -94,7 +94,6 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters)
metric_t *metrics = palloc((NUM_METRICS + 1) * sizeof(metric_t));
uint64 bucket_accum;
int i = 0;
Datum getpage_wait_str;
metrics[i].name = "getpage_wait_seconds_count";
metrics[i].is_bucket = false;
@@ -224,7 +223,6 @@ neon_get_perf_counters(PG_FUNCTION_ARGS)
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
Datum values[3];
bool nulls[3];
Datum getpage_wait_str;
neon_per_backend_counters totals = {0};
metric_t *metrics;


@@ -7,6 +7,7 @@
#define NEON_PGVERSIONCOMPAT_H
#include "fmgr.h"
#include "storage/buf_internals.h"
#if PG_MAJORVERSION_NUM < 17
#define NRelFileInfoBackendIsTemp(rinfo) (rinfo.backend != InvalidBackendId)
@@ -20,11 +21,24 @@
NInfoGetRelNumber(a) == NInfoGetRelNumber(b) \
)
/* buftag population & RelFileNode/RelFileLocator rework */
/* These macros were turned into static inline functions in v16 */
#if PG_MAJORVERSION_NUM < 16
static inline bool
BufferTagsEqual(const BufferTag *tag1, const BufferTag *tag2)
{
return BUFFERTAGS_EQUAL(*tag1, *tag2);
}
#define InitBufferTag(tag, rfn, fn, bn) INIT_BUFFERTAG(*tag, *rfn, fn, bn)
static inline void
InitBufferTag(BufferTag *tag, const RelFileNode *rnode,
ForkNumber forkNum, BlockNumber blockNum)
{
INIT_BUFFERTAG(*tag, *rnode, forkNum, blockNum);
}
#endif
/* RelFileNode -> RelFileLocator rework */
#if PG_MAJORVERSION_NUM < 16
#define USE_RELFILENODE
#define RELFILEINFO_HDR "storage/relfilenode.h"
@@ -73,8 +87,6 @@
#define USE_RELFILELOCATOR
#define BUFFERTAGS_EQUAL(a, b) BufferTagsEqual(&(a), &(b))
#define RELFILEINFO_HDR "storage/relfilelocator.h"
#define NRelFileInfo RelFileLocator
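
Note on the compatibility shim above: for PostgreSQL versions before 16, the value-taking macros are wrapped in static inline functions that take pointers, so callers can use one spelling (`BufferTagsEqual(&a, &b)`) on every supported version. The sketch below shows the same wrapping technique on a stand-in type. `DemoTag`, `DEMO_TAGS_EQUAL` and `DemoTagsEqual` are hypothetical names chosen for illustration and are not part of PostgreSQL or the extension.

```c
#include <stdbool.h>

/* Hypothetical stand-in for BufferTag. */
typedef struct
{
	unsigned	rel;
	unsigned	fork;
	unsigned	block;
} DemoTag;

/* Old-style interface: a macro that takes the structs by value. */
#define DEMO_TAGS_EQUAL(a, b) \
	((a).rel == (b).rel && (a).fork == (b).fork && (a).block == (b).block)

/*
 * New-style interface: a static inline function that takes pointers.
 * Unlike the macro, it type-checks its arguments and evaluates each
 * of them exactly once.
 */
static inline bool
DemoTagsEqual(const DemoTag *tag1, const DemoTag *tag2)
{
	return DEMO_TAGS_EQUAL(*tag1, *tag2);
}
```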


@@ -213,32 +213,6 @@ extern const f_smgr *smgr_neon(ProcNumber backend, NRelFileInfo rinfo);
extern void smgr_init_neon(void);
extern void readahead_buffer_resize(int newsize, void *extra);
/* Neon storage manager functionality */
extern void neon_init(void);
extern void neon_open(SMgrRelation reln);
extern void neon_close(SMgrRelation reln, ForkNumber forknum);
extern void neon_create(SMgrRelation reln, ForkNumber forknum, bool isRedo);
extern bool neon_exists(SMgrRelation reln, ForkNumber forknum);
extern void neon_unlink(NRelFileInfoBackend rnode, ForkNumber forknum, bool isRedo);
#if PG_MAJORVERSION_NUM < 16
extern void neon_extend(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, char *buffer, bool skipFsync);
#else
extern void neon_extend(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, const void *buffer, bool skipFsync);
extern void neon_zeroextend(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, int nbuffers, bool skipFsync);
#endif
#if PG_MAJORVERSION_NUM >=17
extern bool neon_prefetch(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, int nblocks);
#else
extern bool neon_prefetch(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum);
#endif
/*
* LSN values associated with each request to the pageserver
*/
@@ -278,13 +252,7 @@ extern PGDLLEXPORT void neon_read_at_lsn(NRelFileInfo rnode, ForkNumber forkNum,
extern PGDLLEXPORT void neon_read_at_lsn(NRelFileInfo rnode, ForkNumber forkNum, BlockNumber blkno,
neon_request_lsns request_lsns, void *buffer);
#endif
extern void neon_writeback(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, BlockNumber nblocks);
extern BlockNumber neon_nblocks(SMgrRelation reln, ForkNumber forknum);
extern int64 neon_dbsize(Oid dbNode);
extern void neon_truncate(SMgrRelation reln, ForkNumber forknum,
BlockNumber nblocks);
extern void neon_immedsync(SMgrRelation reln, ForkNumber forknum);
/* utils for neon relsize cache */
extern void relsize_hash_init(void);


@@ -118,6 +118,8 @@ static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
static bool neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id);
static bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) = NULL;
static BlockNumber neon_nblocks(SMgrRelation reln, ForkNumber forknum);
/*
* Prefetch implementation:
*
@@ -215,7 +217,7 @@ typedef struct PrfHashEntry
sizeof(BufferTag) \
)
#define SH_EQUAL(tb, a, b) (BUFFERTAGS_EQUAL((a)->buftag, (b)->buftag))
#define SH_EQUAL(tb, a, b) (BufferTagsEqual(&(a)->buftag, &(b)->buftag))
#define SH_SCOPE static inline
#define SH_DEFINE
#define SH_DECLARE
@@ -736,7 +738,7 @@ static void
prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns)
{
bool found;
uint64 mySlotNo = slot->my_ring_index;
uint64 mySlotNo PG_USED_FOR_ASSERTS_ONLY = slot->my_ring_index;
NeonGetPageRequest request = {
.req.tag = T_NeonGetPageRequest,
@@ -853,7 +855,7 @@ Retry:
Assert(slot->status != PRFS_UNUSED);
Assert(MyPState->ring_last <= ring_index &&
ring_index < MyPState->ring_unused);
Assert(BUFFERTAGS_EQUAL(slot->buftag, hashkey.buftag));
Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag));
/*
* If the caller specified a request LSN to use, only accept
@@ -1463,7 +1465,6 @@ log_newpages_copy(NRelFileInfo * rinfo, ForkNumber forkNum, BlockNumber blkno,
BlockNumber blknos[XLR_MAX_BLOCK_ID];
Page pageptrs[XLR_MAX_BLOCK_ID];
int nregistered = 0;
XLogRecPtr result = 0;
for (int i = 0; i < nblocks; i++)
{
@@ -1776,7 +1777,7 @@ neon_wallog_page(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, co
/*
* neon_init() -- Initialize private state
*/
void
static void
neon_init(void)
{
Size prfs_size;
@@ -2166,7 +2167,7 @@ neon_prefetch_response_usable(neon_request_lsns *request_lsns,
/*
* neon_exists() -- Does the physical file exist?
*/
bool
static bool
neon_exists(SMgrRelation reln, ForkNumber forkNum)
{
bool exists;
@@ -2272,7 +2273,7 @@ neon_exists(SMgrRelation reln, ForkNumber forkNum)
*
* If isRedo is true, it's okay for the relation to exist already.
*/
void
static void
neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
{
switch (reln->smgr_relpersistence)
@@ -2348,7 +2349,7 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
* Note: any failure should be reported as WARNING not ERROR, because
* we are usually not in a transaction anymore when this is called.
*/
void
static void
neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo)
{
/*
@@ -2372,7 +2373,7 @@ neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo)
* EOF). Note that we assume writing a block beyond current EOF
* causes intervening file space to become filled with zeroes.
*/
void
static void
#if PG_MAJORVERSION_NUM < 16
neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
char *buffer, bool skipFsync)
@@ -2464,7 +2465,7 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
}
#if PG_MAJORVERSION_NUM >= 16
void
static void
neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
int nblocks, bool skipFsync)
{
@@ -2560,7 +2561,7 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
/*
* neon_open() -- Initialize newly-opened relation.
*/
void
static void
neon_open(SMgrRelation reln)
{
/*
@@ -2578,7 +2579,7 @@ neon_open(SMgrRelation reln)
/*
* neon_close() -- Close the specified relation, if it isn't closed already.
*/
void
static void
neon_close(SMgrRelation reln, ForkNumber forknum)
{
/*
@@ -2593,13 +2594,12 @@ neon_close(SMgrRelation reln, ForkNumber forknum)
/*
* neon_prefetch() -- Initiate asynchronous read of the specified block of a relation
*/
bool
static bool
neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
int nblocks)
{
uint64 ring_index PG_USED_FOR_ASSERTS_ONLY;
BufferTag tag;
bool io_initiated = false;
switch (reln->smgr_relpersistence)
{
@@ -2623,7 +2623,6 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
while (nblocks > 0)
{
int iterblocks = Min(nblocks, PG_IOV_MAX);
int seqlen = 0;
bits8 lfc_present[PG_IOV_MAX / 8];
memset(lfc_present, 0, sizeof(lfc_present));
@@ -2635,8 +2634,6 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
continue;
}
io_initiated = true;
tag.blockNum = blocknum;
for (int i = 0; i < PG_IOV_MAX / 8; i++)
@@ -2659,7 +2656,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
/*
* neon_prefetch() -- Initiate asynchronous read of the specified block of a relation
*/
bool
static bool
neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum)
{
uint64 ring_index PG_USED_FOR_ASSERTS_ONLY;
@@ -2703,7 +2700,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum)
* This accepts a range of blocks because flushing several pages at once is
* considerably more efficient than doing so individually.
*/
void
static void
neon_writeback(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, BlockNumber nblocks)
{
@@ -2924,10 +2921,10 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
* neon_read() -- Read the specified block from a relation.
*/
#if PG_MAJORVERSION_NUM < 16
void
static void
neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, char *buffer)
#else
void
static void
neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer)
#endif
{
@@ -3036,7 +3033,7 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
#endif /* PG_MAJORVERSION_NUM <= 16 */
#if PG_MAJORVERSION_NUM >= 17
void
static void
neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
void **buffers, BlockNumber nblocks)
{
@@ -3200,6 +3197,7 @@ hexdump_page(char *page)
}
#endif
#if PG_MAJORVERSION_NUM < 17
/*
* neon_write() -- Write the supplied block at the appropriate location.
*
@@ -3207,7 +3205,7 @@ hexdump_page(char *page)
* relation (ie, those before the current EOF). To extend a relation,
* use mdextend().
*/
void
static void
#if PG_MAJORVERSION_NUM < 16
neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, char *buffer, bool skipFsync)
#else
@@ -3273,11 +3271,12 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
#endif
#endif
}
#endif
#if PG_MAJORVERSION_NUM >= 17
void
static void
neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
const void **buffers, BlockNumber nblocks, bool skipFsync)
{
@@ -3327,7 +3326,7 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
/*
* neon_nblocks() -- Get the number of blocks stored in a relation.
*/
BlockNumber
static BlockNumber
neon_nblocks(SMgrRelation reln, ForkNumber forknum)
{
NeonResponse *resp;
@@ -3464,7 +3463,7 @@ neon_dbsize(Oid dbNode)
/*
* neon_truncate() -- Truncate relation to specified number of blocks.
*/
void
static void
neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber nblocks)
{
XLogRecPtr lsn;
@@ -3533,7 +3532,7 @@ neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber nblocks)
* crash before the next checkpoint syncs the newly-inactive segment, that
* segment may survive recovery, reintroducing unwanted data into the table.
*/
void
static void
neon_immedsync(SMgrRelation reln, ForkNumber forknum)
{
switch (reln->smgr_relpersistence)
@@ -3563,8 +3562,8 @@ neon_immedsync(SMgrRelation reln, ForkNumber forknum)
}
#if PG_MAJORVERSION_NUM >= 17
void
neon_regisersync(SMgrRelation reln, ForkNumber forknum)
static void
neon_registersync(SMgrRelation reln, ForkNumber forknum)
{
switch (reln->smgr_relpersistence)
{
@@ -3748,6 +3747,8 @@ neon_read_slru_segment(SMgrRelation reln, const char* path, int segno, void* buf
SlruKind kind;
int n_blocks;
shardno_t shard_no = 0; /* All SLRUs are at shard 0 */
NeonResponse *resp;
NeonGetSlruSegmentRequest request;
/*
* Compute a request LSN to use, similar to neon_get_request_lsns() but the
@@ -3786,8 +3787,7 @@ neon_read_slru_segment(SMgrRelation reln, const char* path, int segno, void* buf
else
return -1;
NeonResponse *resp;
NeonGetSlruSegmentRequest request = {
request = (NeonGetSlruSegmentRequest) {
.req.tag = T_NeonGetSlruSegmentRequest,
.req.lsn = request_lsn,
.req.not_modified_since = not_modified_since,
@@ -3894,7 +3894,7 @@ static const struct f_smgr neon_smgr =
.smgr_truncate = neon_truncate,
.smgr_immedsync = neon_immedsync,
#if PG_MAJORVERSION_NUM >= 17
.smgr_registersync = neon_regisersync,
.smgr_registersync = neon_registersync,
#endif
.smgr_start_unlogged_build = neon_start_unlogged_build,
.smgr_finish_unlogged_build_phase_1 = neon_finish_unlogged_build_phase_1,


@@ -252,8 +252,6 @@ WalProposerPoll(WalProposer *wp)
/* timeout expired: poll state */
if (rc == 0 || TimeToReconnect(wp, now) <= 0)
{
TimestampTz now;
/*
* If no WAL was generated during timeout (and we have already
* collected the quorum), then send empty keepalive message
@@ -269,8 +267,7 @@ WalProposerPoll(WalProposer *wp)
now = wp->api.get_current_timestamp(wp);
for (int i = 0; i < wp->n_safekeepers; i++)
{
Safekeeper *sk = &wp->safekeeper[i];
sk = &wp->safekeeper[i];
if (TimestampDifferenceExceeds(sk->latestMsgReceivedAt, now,
wp->config->safekeeper_connection_timeout))
{
@@ -1080,7 +1077,7 @@ SendProposerElected(Safekeeper *sk)
ProposerElected msg;
TermHistory *th;
term_t lastCommonTerm;
int i;
int idx;
/* Now that we are ready to send it's a good moment to create WAL reader */
wp->api.wal_reader_allocate(sk);
@@ -1099,15 +1096,15 @@ SendProposerElected(Safekeeper *sk)
/* We must start somewhere. */
Assert(wp->propTermHistory.n_entries >= 1);
for (i = 0; i < Min(wp->propTermHistory.n_entries, th->n_entries); i++)
for (idx = 0; idx < Min(wp->propTermHistory.n_entries, th->n_entries); idx++)
{
if (wp->propTermHistory.entries[i].term != th->entries[i].term)
if (wp->propTermHistory.entries[idx].term != th->entries[idx].term)
break;
/* term must begin everywhere at the same point */
Assert(wp->propTermHistory.entries[i].lsn == th->entries[i].lsn);
Assert(wp->propTermHistory.entries[idx].lsn == th->entries[idx].lsn);
}
i--; /* step back to the last common term */
if (i < 0)
idx--; /* step back to the last common term */
if (idx < 0)
{
/* safekeeper is empty or no common point, start from the beginning */
sk->startStreamingAt = wp->propTermHistory.entries[0].lsn;
@@ -1128,14 +1125,14 @@ SendProposerElected(Safekeeper *sk)
* proposer, LSN it is currently writing, but then we just pick
* safekeeper pos as it obviously can't be higher.
*/
if (wp->propTermHistory.entries[i].term == wp->propTerm)
if (wp->propTermHistory.entries[idx].term == wp->propTerm)
{
sk->startStreamingAt = sk->voteResponse.flushLsn;
}
else
{
XLogRecPtr propEndLsn = wp->propTermHistory.entries[i + 1].lsn;
XLogRecPtr skEndLsn = (i + 1 < th->n_entries ? th->entries[i + 1].lsn : sk->voteResponse.flushLsn);
XLogRecPtr propEndLsn = wp->propTermHistory.entries[idx + 1].lsn;
XLogRecPtr skEndLsn = (idx + 1 < th->n_entries ? th->entries[idx + 1].lsn : sk->voteResponse.flushLsn);
sk->startStreamingAt = Min(propEndLsn, skEndLsn);
}
@@ -1149,7 +1146,7 @@ SendProposerElected(Safekeeper *sk)
msg.termHistory = &wp->propTermHistory;
msg.timelineStartLsn = wp->timelineStartLsn;
lastCommonTerm = i >= 0 ? wp->propTermHistory.entries[i].term : 0;
lastCommonTerm = idx >= 0 ? wp->propTermHistory.entries[idx].term : 0;
wp_log(LOG,
"sending elected msg to node " UINT64_FORMAT " term=" UINT64_FORMAT ", startStreamingAt=%X/%X (lastCommonTerm=" UINT64_FORMAT "), termHistory.n_entries=%u to %s:%s, timelineStartLsn=%X/%X",
sk->greetResponse.nodeId, msg.term, LSN_FORMAT_ARGS(msg.startStreamingAt), lastCommonTerm, msg.termHistory->n_entries, sk->host, sk->port, LSN_FORMAT_ARGS(msg.timelineStartLsn));
@@ -1641,7 +1638,7 @@ UpdateDonorShmem(WalProposer *wp)
* Process AppendResponse message from safekeeper.
*/
static void
HandleSafekeeperResponse(WalProposer *wp, Safekeeper *sk)
HandleSafekeeperResponse(WalProposer *wp, Safekeeper *fromsk)
{
XLogRecPtr candidateTruncateLsn;
XLogRecPtr newCommitLsn;
@@ -1660,7 +1657,7 @@ HandleSafekeeperResponse(WalProposer *wp, Safekeeper *sk)
* and WAL is committed by the quorum. BroadcastAppendRequest() should be
* called to notify safekeepers about the new commitLsn.
*/
wp->api.process_safekeeper_feedback(wp, sk);
wp->api.process_safekeeper_feedback(wp, fromsk);
/*
* Try to advance truncateLsn -- the last record flushed to all
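
Note on the loop in SendProposerElected() above (where `i` is renamed to `idx`): it walks the proposer's and the safekeeper's term histories in parallel, stops at the first entry whose terms differ, then steps back one position to land on the last common term. A negative result means the histories share nothing. The sketch below isolates that search. `DemoTermSwitch` and `demo_last_common_term` are hypothetical, simplified stand-ins for the walproposer types, and plain `assert` replaces PostgreSQL's `Assert`.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical, simplified term history entry: a term and its start LSN. */
typedef struct
{
	uint64_t	term;
	uint64_t	lsn;
} DemoTermSwitch;

/*
 * Return the index of the last entry on which both histories agree,
 * or -1 if there is no common term (for example, an empty safekeeper).
 */
static int
demo_last_common_term(const DemoTermSwitch *prop, int n_prop,
					  const DemoTermSwitch *sk, int n_sk)
{
	int			n = (n_prop < n_sk) ? n_prop : n_sk;
	int			idx;

	for (idx = 0; idx < n; idx++)
	{
		if (prop[idx].term != sk[idx].term)
			break;
		/* a term must begin at the same point in both histories */
		assert(prop[idx].lsn == sk[idx].lsn);
	}
	return idx - 1;				/* step back to the last common term */
}
```

With proposer terms 1, 2, 5 and safekeeper terms 1, 2, 3, the loop breaks at index 2 and the function returns index 1, the entry for term 2.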


@@ -725,7 +725,7 @@ extern void WalProposerBroadcast(WalProposer *wp, XLogRecPtr startpos, XLogRecPt
extern void WalProposerPoll(WalProposer *wp);
extern void WalProposerFree(WalProposer *wp);
extern WalproposerShmemState *GetWalpropShmemState();
extern WalproposerShmemState *GetWalpropShmemState(void);
/*
* WaitEventSet API doesn't allow to remove socket, so walproposer_pg uses it to
@@ -745,7 +745,7 @@ extern TimeLineID walprop_pg_get_timeline_id(void);
* catch logging.
*/
#ifdef WALPROPOSER_LIB
extern void WalProposerLibLog(WalProposer *wp, int elevel, char *fmt,...);
extern void WalProposerLibLog(WalProposer *wp, int elevel, char *fmt,...) pg_attribute_printf(3, 4);
#define wp_log(elevel, fmt, ...) WalProposerLibLog(wp, elevel, fmt, ## __VA_ARGS__)
#else
#define wp_log(elevel, fmt, ...) elog(elevel, WP_LOG_PREFIX fmt, ## __VA_ARGS__)
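
Note on the `pg_attribute_printf(3, 4)` annotation added above: it tells the compiler that argument 3 is a printf-style format string and that the variadic arguments start at position 4, so mismatched format specifiers can be reported at compile time. The sketch below shows the same idea with the GCC/Clang `format` attribute directly. `DEMO_PRINTF_ATTR` and `demo_log` are hypothetical names, and the assumption that the PostgreSQL macro expands to a comparable attribute on these compilers is not shown in this diff.

```c
#include <stdarg.h>
#include <stdio.h>

/* Hypothetical wrapper: expands to nothing on compilers without the attribute. */
#if defined(__GNUC__)
#define DEMO_PRINTF_ATTR(fmt_idx, args_idx) \
	__attribute__((format(printf, fmt_idx, args_idx)))
#else
#define DEMO_PRINTF_ATTR(fmt_idx, args_idx)
#endif

/* fmt is argument 3 and the variadic arguments start at 4. */
static int	demo_log(char *out, size_t outlen, const char *fmt, ...)
			DEMO_PRINTF_ATTR(3, 4);

/* Format a message into out; returns what vsnprintf() returns. */
static int
demo_log(char *out, size_t outlen, const char *fmt, ...)
{
	va_list		ap;
	int			n;

	va_start(ap, fmt);
	n = vsnprintf(out, outlen, fmt, ap);
	va_end(ap);
	return n;
}
```

With the attribute in place, a call such as `demo_log(buf, sizeof(buf), "term=%s", 7)` would typically draw a format warning from GCC or Clang when format warnings are enabled, instead of failing at run time.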


@@ -286,6 +286,9 @@ safekeepers_cmp(char *old, char *new)
static void
assign_neon_safekeepers(const char *newval, void *extra)
{
char *newval_copy;
char *oldval;
if (!am_walproposer)
return;
@@ -295,8 +298,8 @@ assign_neon_safekeepers(const char *newval, void *extra)
}
/* Copy values because we will modify them in split_safekeepers_list() */
char *newval_copy = pstrdup(newval);
char *oldval = pstrdup(wal_acceptors_list);
newval_copy = pstrdup(newval);
oldval = pstrdup(wal_acceptors_list);
/*
* TODO: restarting through FATAL is stupid and introduces 1s delay before
@@ -538,7 +541,7 @@ nwp_shmem_startup_hook(void)
}
WalproposerShmemState *
GetWalpropShmemState()
GetWalpropShmemState(void)
{
Assert(walprop_shared != NULL);
return walprop_shared;


@@ -44,27 +44,6 @@ infobits_desc(StringInfo buf, uint8 infobits, const char *keyname)
appendStringInfoString(buf, "]");
}
static void
truncate_flags_desc(StringInfo buf, uint8 flags)
{
appendStringInfoString(buf, "flags: [");
if (flags & XLH_TRUNCATE_CASCADE)
appendStringInfoString(buf, "CASCADE, ");
if (flags & XLH_TRUNCATE_RESTART_SEQS)
appendStringInfoString(buf, "RESTART_SEQS, ");
if (buf->data[buf->len - 1] == ' ')
{
/* Truncate-away final unneeded ", " */
Assert(buf->data[buf->len - 2] == ',');
buf->len -= 2;
buf->data[buf->len] = '\0';
}
appendStringInfoString(buf, "]");
}
void
neon_rm_desc(StringInfo buf, XLogReaderState *record)
{


@@ -136,7 +136,7 @@ static bool redo_block_filter(XLogReaderState *record, uint8 block_id);
static void GetPage(StringInfo input_message);
static void Ping(StringInfo input_message);
static ssize_t buffered_read(void *buf, size_t count);
static void CreateFakeSharedMemoryAndSemaphores();
static void CreateFakeSharedMemoryAndSemaphores(void);
static BufferTag target_redo_tag;
@@ -170,6 +170,40 @@ close_range_syscall(unsigned int start_fd, unsigned int count, unsigned int flag
return syscall(__NR_close_range, start_fd, count, flags);
}
static PgSeccompRule allowed_syscalls[] =
{
/* Hard requirements */
PG_SCMP_ALLOW(exit_group),
PG_SCMP_ALLOW(pselect6),
PG_SCMP_ALLOW(read),
PG_SCMP_ALLOW(select),
PG_SCMP_ALLOW(write),
/* Memory allocation */
PG_SCMP_ALLOW(brk),
#ifndef MALLOC_NO_MMAP
/* TODO: musl doesn't have mallopt */
PG_SCMP_ALLOW(mmap),
PG_SCMP_ALLOW(munmap),
#endif
/*
* getpid() is called on assertion failure, in ExceptionalCondition.
* It's not really needed, but seems pointless to hide it either. The
* system call unlikely to expose a kernel vulnerability, and the PID
* is stored in MyProcPid anyway.
*/
PG_SCMP_ALLOW(getpid),
/* Enable those for a proper shutdown. */
#if 0
PG_SCMP_ALLOW(munmap),
PG_SCMP_ALLOW(shmctl),
PG_SCMP_ALLOW(shmdt),
PG_SCMP_ALLOW(unlink), /* shm_unlink */
#endif
};
static void
enter_seccomp_mode(void)
{
@@ -183,44 +217,12 @@ enter_seccomp_mode(void)
(errcode(ERRCODE_SYSTEM_ERROR),
errmsg("seccomp: could not close files >= fd 3")));
PgSeccompRule syscalls[] =
{
/* Hard requirements */
PG_SCMP_ALLOW(exit_group),
PG_SCMP_ALLOW(pselect6),
PG_SCMP_ALLOW(read),
PG_SCMP_ALLOW(select),
PG_SCMP_ALLOW(write),
/* Memory allocation */
PG_SCMP_ALLOW(brk),
#ifndef MALLOC_NO_MMAP
/* TODO: musl doesn't have mallopt */
PG_SCMP_ALLOW(mmap),
PG_SCMP_ALLOW(munmap),
#endif
/*
* getpid() is called on assertion failure, in ExceptionalCondition.
* It's not really needed, but seems pointless to hide it either. The
* system call unlikely to expose a kernel vulnerability, and the PID
* is stored in MyProcPid anyway.
*/
PG_SCMP_ALLOW(getpid),
/* Enable those for a proper shutdown.
PG_SCMP_ALLOW(munmap),
PG_SCMP_ALLOW(shmctl),
PG_SCMP_ALLOW(shmdt),
PG_SCMP_ALLOW(unlink), // shm_unlink
*/
};
#ifdef MALLOC_NO_MMAP
/* Ask glibc not to use mmap() */
mallopt(M_MMAP_MAX, 0);
#endif
seccomp_load_rules(syscalls, lengthof(syscalls));
seccomp_load_rules(allowed_syscalls, lengthof(allowed_syscalls));
}
#endif /* HAVE_LIBSECCOMP */
@@ -449,7 +451,7 @@ WalRedoMain(int argc, char *argv[])
* half-initialized postgres.
*/
static void
CreateFakeSharedMemoryAndSemaphores()
CreateFakeSharedMemoryAndSemaphores(void)
{
PGShmemHeader *shim = NULL;
PGShmemHeader *hdr;
@@ -992,7 +994,7 @@ redo_block_filter(XLogReaderState *record, uint8 block_id)
* If this block isn't one we are currently restoring, then return 'true'
* so that this gets ignored
*/
return !BUFFERTAGS_EQUAL(target_tag, target_redo_tag);
return !BufferTagsEqual(&target_tag, &target_redo_tag);
}
/*

poetry.lock generated

@@ -2095,6 +2095,7 @@ files = [
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@@ -2103,6 +2104,8 @@ files = [
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@@ -2584,6 +2587,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -2729,21 +2733,22 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "responses"
version = "0.21.0"
version = "0.25.3"
description = "A utility library for mocking out the `requests` Python library."
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "responses-0.21.0-py3-none-any.whl", hash = "sha256:2dcc863ba63963c0c3d9ee3fa9507cbe36b7d7b0fccb4f0bdfd9e96c539b1487"},
{file = "responses-0.21.0.tar.gz", hash = "sha256:b82502eb5f09a0289d8e209e7bad71ef3978334f56d09b444253d5ad67bf5253"},
{file = "responses-0.25.3-py3-none-any.whl", hash = "sha256:521efcbc82081ab8daa588e08f7e8a64ce79b91c39f6e62199b19159bea7dbcb"},
{file = "responses-0.25.3.tar.gz", hash = "sha256:617b9247abd9ae28313d57a75880422d55ec63c29d33d629697590a034358dba"},
]
[package.dependencies]
requests = ">=2.0,<3.0"
urllib3 = ">=1.25.10"
pyyaml = "*"
requests = ">=2.30.0,<3.0"
urllib3 = ">=1.25.10,<3.0"
[package.extras]
tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-localserver", "types-mock", "types-requests"]
tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-PyYAML", "types-requests"]
[[package]]
name = "rfc3339-validator"
@@ -3137,6 +3142,16 @@ files = [
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"},
{file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"},
{file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"},
{file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"},
{file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"},


@@ -1,11 +1,12 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import enum
import os
import subprocess
import sys
from typing import List
@enum.unique
@@ -55,12 +56,12 @@ def mypy() -> str:
return "poetry run mypy"
def get_commit_files() -> List[str]:
def get_commit_files() -> list[str]:
files = subprocess.check_output("git diff --cached --name-only --diff-filter=ACM".split())
return files.decode().splitlines()
def check(name: str, suffix: str, cmd: str, changed_files: List[str], no_color: bool = False):
def check(name: str, suffix: str, cmd: str, changed_files: list[str], no_color: bool = False):
print(f"Checking: {name} ", end="")
applicable_files = list(filter(lambda fname: fname.strip().endswith(suffix), changed_files))
if not applicable_files:


@@ -39,7 +39,7 @@ http.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
hyper0.workspace = true
hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
hyper = { workspace = true, features = ["server", "http1", "http2"] }
hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
http-body-util = { version = "0.1" }
indexmap.workspace = true


@@ -1,24 +1,18 @@
use crate::{
auth,
cache::Cached,
compute,
auth, compute,
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::{self, provider::NodeInfo, CachedNodeInfo},
control_plane::{self, provider::NodeInfo},
error::{ReportableError, UserFacingError},
proxy::connect_compute::ComputeConnectBackend,
stream::PqStream,
waiters,
};
use async_trait::async_trait;
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::SslMode;
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
#[derive(Debug, Error)]
pub(crate) enum WebAuthError {
#[error(transparent)]
@@ -31,11 +25,6 @@ pub(crate) enum WebAuthError {
Io(#[from] std::io::Error),
}
#[derive(Debug)]
pub struct ConsoleRedirectBackend {
console_uri: reqwest::Url,
}
impl UserFacingError for WebAuthError {
fn to_string_client(&self) -> String {
"Internal error".to_string()
@@ -68,40 +57,7 @@ pub(crate) fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
impl ConsoleRedirectBackend {
pub fn new(console_uri: reqwest::Url) -> Self {
Self { console_uri }
}
pub(crate) async fn authenticate(
&self,
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<ConsoleRedirectNodeInfo> {
authenticate(ctx, auth_config, &self.console_uri, client)
.await
.map(ConsoleRedirectNodeInfo)
}
}
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
#[async_trait]
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
async fn wake_compute(
&self,
_ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
Ok(Cached::new_uncached(self.0.clone()))
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&ComputeCredentialKeys::None
}
}
async fn authenticate(
pub(super) async fn authenticate(
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,


@@ -571,7 +571,7 @@ mod tests {
use bytes::Bytes;
use http::Response;
use http_body_util::Full;
use hyper1::service::service_fn;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng;
use rsa::pkcs8::DecodePrivateKey;
@@ -736,7 +736,7 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
});
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
let server = hyper1::server::conn::http1::Builder::new();
let server = hyper::server::conn::http1::Builder::new();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {


@@ -8,7 +8,6 @@ use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::WebAuthError;
use ipnet::{Ipv4Net, Ipv6Net};
use local::LocalBackend;
@@ -20,8 +19,9 @@ use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{validate_password_and_exchange, AuthError};
use crate::cache::Cached;
use crate::context::RequestMonitoring;
use crate::control_plane::provider::ControlPlaneBackend;
use crate::control_plane::AuthSecret;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::provider::{CachedRoleSecret, ControlPlaneBackend};
use crate::control_plane::{AuthSecret, NodeInfo};
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
@@ -31,22 +31,48 @@ use crate::stream::Stream;
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
config::AuthenticationConfig,
control_plane::{self, provider::CachedNodeInfo, Api},
stream,
control_plane::{
self,
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
/// The [crate::serverless] module can authenticate either using control-plane
/// to get authentication state, or by using JWKs stored in the filesystem.
pub enum ServerlessBackend<'a> {
/// Cloud API (V2).
ControlPlane(&'a ControlPlaneBackend),
/// Local proxy uses configured auth credentials and does not wake compute
Local(&'a LocalBackend),
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
Owned(T),
Borrowed(&'a T),
}
#[cfg(test)]
use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwned::Owned(t) => t,
MaybeOwned::Borrowed(t) => t,
}
}
}
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum Backend<'a, T, D> {
/// Cloud API (V2).
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
/// Authentication via a web browser.
ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D),
/// Local proxy uses configured auth credentials and does not wake compute
Local(MaybeOwned<'a, LocalBackend>),
}
#[cfg(test)]
pub(crate) trait TestBackend: Send + Sync + 'static {
@@ -64,20 +90,63 @@ impl Clone for Box<dyn TestBackend> {
}
}
impl std::fmt::Display for ControlPlaneBackend {
impl std::fmt::Display for Backend<'_, (), ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ControlPlaneBackend::Management(endpoint) => fmt
.debug_tuple("ControlPlane::Management")
.field(&endpoint.url())
Self::ControlPlane(api, ()) => match &**api {
ControlPlaneBackend::Management(endpoint) => fmt
.debug_tuple("ControlPlane::Management")
.field(&endpoint.url())
.finish(),
#[cfg(any(test, feature = "testing"))]
ControlPlaneBackend::PostgresMock(endpoint) => fmt
.debug_tuple("ControlPlane::PostgresMock")
.field(&endpoint.url())
.finish(),
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
},
Self::ConsoleRedirect(url, ()) => fmt
.debug_tuple("ConsoleRedirect")
.field(&url.as_str())
.finish(),
#[cfg(any(test, feature = "testing"))]
ControlPlaneBackend::PostgresMock(endpoint) => fmt
.debug_tuple("ControlPlane::PostgresMock")
.field(&endpoint.url())
.finish(),
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
Self::Local(_) => fmt.debug_tuple("Local").finish(),
}
}
}
impl<T, D> Backend<'_, T, D> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub(crate) fn as_ref(&self) -> Backend<'_, &T, &D> {
match self {
Self::ControlPlane(c, x) => Backend::ControlPlane(MaybeOwned::Borrowed(c), x),
Self::ConsoleRedirect(c, x) => Backend::ConsoleRedirect(MaybeOwned::Borrowed(c), x),
Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
}
}
}
impl<'a, T, D> Backend<'a, T, D> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`Backend<T, D>`] to [`Backend<R, D>`] by applying
/// a function to the contained `T` value.
pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R, D> {
match self {
Self::ControlPlane(c, x) => Backend::ControlPlane(c, f(x)),
Self::ConsoleRedirect(c, x) => Backend::ConsoleRedirect(c, x),
Self::Local(l) => Backend::Local(l),
}
}
}
impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub(crate) fn transpose(self) -> Result<Backend<'a, T, D>, E> {
match self {
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
Self::ConsoleRedirect(c, x) => Ok(Backend::ConsoleRedirect(c, x)),
Self::Local(l) => Ok(Backend::Local(l)),
}
}
}
@@ -170,6 +239,7 @@ impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
@@ -193,7 +263,7 @@ impl AuthenticationConfig {
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
),
password_weight,
);
@@ -267,6 +337,7 @@ async fn auth_quirks(
let secret = if let Some(secret) = secret {
config.check_rate_limit(
ctx,
config,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
@@ -342,79 +413,133 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl ControlPlaneBackend {
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::ControlPlane(_, user_info) => &user_info.user,
Self::ConsoleRedirect(_, ()) => "web",
Self::Local(_) => "local",
}
}
/// Authenticate the client via the requested backend, possibly using credentials.
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub(crate) async fn authenticate(
&self,
self,
ctx: &RequestMonitoring,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ControlPlaneComputeBackend> {
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
) -> auth::Result<Backend<'a, ComputeCredentials, NodeInfo>> {
let res = match self {
Self::ControlPlane(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let credentials = auth_quirks(
ctx,
self,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
let credentials = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Backend::ControlPlane(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Self::ConsoleRedirect(url, ()) => {
info!("performing web authentication");
let info = console_redirect::authenticate(ctx, config, &url, client).await?;
Backend::ConsoleRedirect(url, info)
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
}
};
info!("user successfully authenticated");
Ok(ControlPlaneComputeBackend {
api: self,
creds: credentials,
})
}
pub(crate) fn attach_to_credentials(
&self,
creds: ComputeCredentials,
) -> ControlPlaneComputeBackend {
ControlPlaneComputeBackend { api: self, creds }
Ok(res)
}
}
pub struct ControlPlaneComputeBackend<'a> {
api: &'a ControlPlaneBackend,
creds: ComputeCredentials,
impl Backend<'_, ComputeUserInfo, &()> {
pub(crate) async fn get_role_secret(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::ConsoleRedirect(_, ()) => Ok(Cached::new_uncached(None)),
Self::Local(_) => Ok(Cached::new_uncached(None)),
}
}
pub(crate) async fn get_allowed_ips_and_secret(
&self,
ctx: &RequestMonitoring,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_allowed_ips_and_secret(ctx, user_info).await
}
Self::ConsoleRedirect(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for ControlPlaneComputeBackend<'_> {
impl ComputeConnectBackend for Backend<'_, ComputeCredentials, NodeInfo> {
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
self.api.wake_compute(ctx, &self.creds.info).await
match self {
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::ConsoleRedirect(_, info) => Ok(Cached::new_uncached(info.clone())),
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&self.creds.keys
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::ConsoleRedirect(_, _) => &ComputeCredentialKeys::None,
Self::Local(_) => &ComputeCredentialKeys::None,
}
}
}
#[async_trait::async_trait]
impl ComputeConnectBackend for LocalBackend {
impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> {
async fn wake_compute(
&self,
_ctx: &RequestMonitoring,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
Ok(Cached::new_uncached(self.node_info.clone()))
match self {
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::ConsoleRedirect(_, ()) => {
unreachable!("web auth flow doesn't support waking the compute")
}
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&ComputeCredentialKeys::None
match self {
Self::ControlPlane(_, creds) => &creds.keys,
Self::ConsoleRedirect(_, ()) => &ComputeCredentialKeys::None,
Self::Local(_) => &ComputeCredentialKeys::None,
}
}
}
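
The generic `Backend<'a, T, D>` in this file borrows its `as_ref` / `map` / `transpose` vocabulary from `Option`. A minimal, self-contained sketch of the same pattern may make the shape easier to follow. Only the layout of `MaybeOwned` is taken from the diff; `Selector`, its variants, and the `String` payload are hypothetical stand-ins, not the proxy's real types.

```rust
use std::ops::Deref;

/// Owned-or-borrowed wrapper. Unlike `Cow`, it does not require `T: ToOwned`.
pub enum MaybeOwned<'a, T> {
    Owned(T),
    Borrowed(&'a T),
}

impl<T> Deref for MaybeOwned<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        match self {
            MaybeOwned::Owned(t) => t,
            MaybeOwned::Borrowed(t) => t,
        }
    }
}

/// Hypothetical selector: `Remote` carries a per-request payload `T`,
/// `Local` carries only its configuration.
pub enum Selector<'a, T> {
    Remote(MaybeOwned<'a, String>, T),
    Local(MaybeOwned<'a, String>),
}

impl<T> Selector<'_, T> {
    /// Like `Option::as_ref`: borrow both the configuration and the payload.
    pub fn as_ref(&self) -> Selector<'_, &T> {
        match self {
            Selector::Remote(name, x) => Selector::Remote(MaybeOwned::Borrowed(&**name), x),
            Selector::Local(name) => Selector::Local(MaybeOwned::Borrowed(&**name)),
        }
    }
}

impl<'a, T> Selector<'a, T> {
    /// Like `Option::map`: transform the payload, keep the configuration.
    pub fn map<R>(self, f: impl FnOnce(T) -> R) -> Selector<'a, R> {
        match self {
            Selector::Remote(name, x) => Selector::Remote(name, f(x)),
            Selector::Local(name) => Selector::Local(name),
        }
    }
}

impl<'a, T, E> Selector<'a, Result<T, E>> {
    /// Like `Option::transpose`: pull a payload error out of the selector.
    pub fn transpose(self) -> Result<Selector<'a, T>, E> {
        match self {
            Selector::Remote(name, x) => x.map(|x| Selector::Remote(name, x)),
            Selector::Local(name) => Ok(Selector::Local(name)),
        }
    }
}
```

Under these assumptions, `selector.map(|s| s.parse::<u32>()).transpose()` turns a `Selector<&str>` into a `Result<Selector<u32>, _>`: a parse failure in the `Remote` payload surfaces as `Err`, while `Local` always passes through as `Ok`.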


@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::ServerlessBackend;
pub use backend::Backend;
mod credentials;
pub(crate) use credentials::{


@@ -6,12 +6,9 @@ use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::future::Either;
use proxy::{
auth::{
self,
backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
@@ -135,7 +132,6 @@ async fn main() -> anyhow::Result<()> {
let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
// before we bind to any ports, write the process ID to a file
// so that compute-ctl can find our process later
@@ -197,7 +193,6 @@ async fn main() -> anyhow::Result<()> {
let task = serverless::task_main(
config,
auth::ServerlessBackend::Local(auth_backend),
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
@@ -262,6 +257,9 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
)),
metric_collection: None,
allow_self_signed_compute: false,
http_config,
@@ -288,13 +286,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
})))
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(args: &LocalProxyCliArgs) -> anyhow::Result<&'static LocalBackend> {
let auth_backend = LocalBackend::new(args.compute);
Ok(Box::leak(Box::new(auth_backend)))
}
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
loop {
rx.notified().await;
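
In this file the auth backend moves from a separately leaked `&'static LocalBackend` into the `ProxyConfig` that `build_config` already leaks with `Box::leak`. The sketch below illustrates that "build once at startup, keep forever" idiom in isolation. `AppConfig`, its fields, and `spawn_worker` are hypothetical names chosen for the example, not part of the proxy.

```rust
use std::thread::{self, JoinHandle};

/// Hypothetical stand-in for a process-wide configuration struct.
#[derive(Debug)]
pub struct AppConfig {
    pub listen_port: u16,
    pub backend: String,
}

/// Build the configuration once and leak it deliberately, so it can be
/// shared as `&'static` with tasks that outlive the caller's stack frame.
pub fn build_config(listen_port: u16) -> &'static AppConfig {
    Box::leak(Box::new(AppConfig {
        listen_port,
        backend: "local".to_string(),
    }))
}

/// A worker needs `'static` data; the leaked reference satisfies that
/// without reference counting or cloning the configuration.
pub fn spawn_worker(config: &'static AppConfig) -> JoinHandle<u16> {
    thread::spawn(move || config.listen_port)
}
```

The allocation is never freed, which is acceptable for a value that is meant to live until the process exits. Nesting the backend inside the leaked struct means only one such allocation has to be threaded through the task entry points.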


@@ -10,7 +10,7 @@ use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::ConsoleRedirectBackend;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::remote_storage_from_toml;
@@ -21,7 +21,6 @@ use proxy::config::ProjectInfoCacheOptions;
use proxy::config::ProxyProtocolV2;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::control_plane;
use proxy::control_plane::provider::ControlPlaneBackend;
use proxy::http;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
@@ -312,12 +311,8 @@ async fn main() -> anyhow::Result<()> {
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
match auth_backend {
Either::Left(auth_backend) => info!("Authentication backend: {auth_backend}"),
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
};
info!("Authentication backend: {}", config.auth_backend);
info!("Using region: {}", args.aws_region);
let region_provider =
@@ -464,41 +459,24 @@ async fn main() -> anyhow::Result<()> {
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth::ServerlessBackend::ControlPlane(auth_backend),
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::console_redirect_proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
));
}
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
client_tasks.spawn(proxy::context::parquet::worker(
@@ -528,38 +506,40 @@ async fn main() -> anyhow::Result<()> {
));
}
if let Either::Left(ControlPlaneBackend::Management(api)) = &auth_backend {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
if let auth::Backend::ControlPlane(api, _) = &config.auth_backend {
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
}
}
}
@@ -630,6 +610,73 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
bail!("dynamic rate limiter should be disabled");
}
let auth_backend = match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = control_plane::provider::neon::Api::new(
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let api = control_plane::provider::ControlPlaneBackend::Management(api);
auth::Backend::ControlPlane(MaybeOwned::Owned(api), ())
}
AuthBackendType::Web => {
let url = args.uri.parse()?;
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ())
}
#[cfg(feature = "testing")]
AuthBackendType::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
auth::Backend::ControlPlane(MaybeOwned::Owned(api), ())
}
};
let config::ConcurrencyLockOptions {
shards,
limiter,
@@ -679,8 +726,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
};
let config = ProxyConfig {
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
@@ -693,97 +741,13 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
};
let config = Box::leak(Box::new(config));
}));
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
Ok(config)
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,
) -> anyhow::Result<Either<&'static ControlPlaneBackend, &'static ConsoleRedirectBackend>> {
match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = control_plane::provider::neon::Api::new(
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let auth_backend = control_plane::provider::ControlPlaneBackend::Management(api);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
#[cfg(feature = "testing")]
AuthBackendType::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
let auth_backend = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
AuthBackendType::Web => {
let url = args.uri.parse()?;
let backend = ConsoleRedirectBackend::new(url);
let config = Box::leak(Box::new(backend));
Ok(Either::Right(config))
}
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
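
Review note on the hunks above: `build_config` and `build_auth_backend` both wrap their result in `Box::leak(Box::new(...))`, which gives up ever freeing one startup allocation in exchange for a `&'static` reference that spawned tasks can copy freely without `Arc`. A minimal std-only sketch of that pattern follows; `DemoConfig`, its fields, and both functions are hypothetical stand-ins, not the real `ProxyConfig` API, and threads are used here in place of tokio tasks.

```rust
use std::thread;

/// Hypothetical stand-in for `ProxyConfig`; the fields are illustrative only.
struct DemoConfig {
    region: String,
    handshake_timeout_secs: u64,
}

/// Build the config once and leak it, so the reference is valid for the
/// rest of the process. The allocation is intentionally never freed.
fn build_demo_config(region: &str) -> &'static DemoConfig {
    Box::leak(Box::new(DemoConfig {
        region: region.to_owned(),
        handshake_timeout_secs: 15,
    }))
}

/// Hand the same `&'static` reference to several threads.
/// `&'static T` is `Copy`, so each `move` closure gets its own copy of the pointer.
fn describe_from_workers(config: &'static DemoConfig, workers: usize) -> Vec<String> {
    let handles: Vec<_> = (0..workers)
        .map(|id| {
            thread::spawn(move || {
                format!(
                    "worker {id}: region={} timeout={}s",
                    config.region, config.handshake_timeout_secs
                )
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

The leak is harmless only because the value is created once at startup and is meant to live until the process exits; leaking per-request values the same way would grow memory without bound.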

View File

@@ -1,5 +1,8 @@
use crate::{
auth::backend::{jwt::JwkCache, AuthRateLimiter},
auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
control_plane::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
@@ -26,6 +29,7 @@ use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::Backend<'static, (), ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,

View File

@@ -1,161 +0,0 @@
use crate::auth::backend::ConsoleRedirectBackend;
use crate::config::ProxyConfig;
use crate::metrics::Protocol;
use crate::proxy::{prepare_client_connection, transition_connection, ClientRequestError};
use crate::{
cancellation::CancellationHandlerMain,
context::RequestMonitoring,
metrics::{Metrics, NumClientConnectionsGuard},
proxy::handshake::{handshake, HandshakeData},
};
use futures::TryFutureExt;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{info, Instrument};
use crate::proxy::{
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
};
pub async fn task_main(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
super::connection_loop(
config,
listener,
cancellation_token,
Protocol::Tcp,
C {
config,
backend,
cancellation_handler,
},
)
.await
}
#[derive(Clone)]
struct C {
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
cancellation_handler: Arc<CancellationHandlerMain>,
}
impl super::ConnHandler for C {
async fn handle(
self,
session_id: uuid::Uuid,
peer_addr: IpAddr,
socket: crate::protocol2::ChainRW<tokio::net::TcpStream>,
conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
) {
let ctx = RequestMonitoring::new(session_id, peer_addr, Protocol::Tcp, &self.config.region);
let span = ctx.span();
let startup = Box::pin(
handle_client(
self.config,
self.backend,
&ctx,
self.cancellation_handler,
socket,
conn_gauge,
)
.instrument(span.clone()),
);
let res = startup.await;
transition_connection(ctx, res).await;
}
}
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
info!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
.await
.map(|()| None)?)
}
};
drop(pause);
ctx.set_db_options(params.clone());
let user_info = match backend
.authenticate(ctx, &config.authentication_config, &mut stream)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
return stream.throw_error(e).await?;
}
};
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
params: &params,
locks: &config.connect_compute_locks,
},
&user_info,
config.allow_self_signed_compute,
config.wake_compute_retry_config,
config.connect_to_compute_retry_config,
)
.or_else(|e| stream.throw_error(e))
.await?;
let session = cancellation_handler.get_session();
prepare_client_connection(&node, &session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
compute: node,
_req: request_gauge,
_conn: conn_gauge,
_cancel: session,
}))
}

View File

@@ -1,5 +1,5 @@
use anyhow::{anyhow, bail};
use hyper::{header::CONTENT_TYPE, Body, Request, Response, StatusCode};
use hyper0::{header::CONTENT_TYPE, Body, Request, Response, StatusCode};
use measured::{text::BufferedTextEncoder, MetricGroup};
use metrics::NeonMetrics;
use std::{
@@ -21,7 +21,7 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, "")
}
fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper::Body, ApiError> {
fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper0::Body, ApiError> {
let state = Arc::new(Mutex::new(PrometheusHandler {
encoder: BufferedTextEncoder::new(),
metrics,
@@ -45,7 +45,7 @@ pub async fn task_main(
let service = || RouterService::new(make_router(metrics).build()?);
hyper::Server::from_tcp(http_listener)?
hyper0::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;

View File

@@ -9,7 +9,7 @@ use std::time::Duration;
use anyhow::bail;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper1::body::Body;
use hyper::body::Body;
use serde::de::DeserializeOwned;
pub(crate) use reqwest::{Request, Response};

View File

@@ -82,25 +82,19 @@
impl_trait_overcaptures,
)]
use std::{convert::Infallible, future::Future, net::IpAddr};
use std::convert::Infallible;
use anyhow::{bail, Context};
use intern::{EndpointIdInt, EndpointIdTag, InternId};
use protocol2::{get_client_conn_info, ChainRW};
use proxy::run_until_cancelled;
use tokio::{net::TcpStream, task::JoinError};
use tokio::task::JoinError;
use tokio_util::sync::CancellationToken;
use tracing::{error, warn};
use uuid::Uuid;
extern crate hyper0 as hyper;
use tracing::warn;
pub mod auth;
pub mod cache;
pub mod cancellation;
pub mod compute;
pub mod config;
pub mod console_redirect_proxy;
pub mod context;
pub mod control_plane;
pub mod error;
@@ -279,81 +273,3 @@ impl EndpointId {
ProjectId(self.0.clone())
}
}
pub(crate) trait ConnHandler: Clone + Send + 'static {
fn handle(
self,
session_id: Uuid,
peer_addr: IpAddr,
stream: ChainRW<TcpStream>,
conn_gauge: metrics::NumClientConnectionsGuard<'static>,
) -> impl Future<Output = ()> + Send;
}
/// Accept connections, parse the proxy-protocol v2 header and spawn a tracked connection task.
pub(crate) async fn connection_loop<C>(
config: &'static config::ProxyConfig,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
protocol: metrics::Protocol,
conn_handler: C,
) -> anyhow::Result<()>
where
C: ConnHandler,
{
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = metrics::Metrics::get()
.proxy
.client_connections
.guard(protocol);
let session_id = uuid::Uuid::new_v4();
let conn_handler = conn_handler.clone();
tracing::info!(protocol = protocol.as_str(), %session_id, "accepted new TCP connection");
connections.spawn(async move {
let (socket, peer_addr) = match get_client_conn_info(socket, config.proxy_protocol_v2).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
return;
}
Ok((socket, Some(addr))) => (socket, addr),
Ok((socket, None)) => (socket, peer_addr.ip()),
};
match socket.inner.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!("per-client task finished with an error: failed to set socket option: {e:#}");
return;
}
};
conn_handler.handle(
session_id,
peer_addr,
socket,
conn_gauge,
).await;
});
}
connections.close();
drop(listener);
// Drain connections
connections.wait().await;
Ok(())
}

View File

@@ -2,18 +2,15 @@
use std::{
io,
net::{IpAddr, SocketAddr},
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use crate::config::ProxyProtocolV2;
pin_project! {
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
pub(crate) struct ChainRW<T> {
@@ -63,23 +60,7 @@ const HEADER: [u8; 12] = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
pub(crate) async fn get_client_conn_info<T: AsyncRead + Unpin>(
socket: T,
proxy_protocol_v2: ProxyProtocolV2,
) -> anyhow::Result<(ChainRW<T>, Option<IpAddr>)> {
match read_proxy_protocol(socket).await? {
(_socket, None) if proxy_protocol_v2 == ProxyProtocolV2::Required => {
bail!("missing required proxy protocol header");
}
(_socket, Some(_)) if proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
bail!("proxy protocol header not supported");
}
(socket, Some(addr)) => Ok((socket, Some(addr.ip()))),
(socket, None) => Ok((socket, None)),
}
}
async fn read_proxy_protocol<T: AsyncRead + Unpin>(
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
mut read: T,
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
let mut buf = BytesMut::with_capacity(128);
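
Review note on the hunk above: `get_client_conn_info` combines the parsed PROXY v2 header with the configured `ProxyProtocolV2` policy, and the caller falls back to the socket's own peer address when no header is present. The decision table can be sketched without any I/O as below. The enum is redeclared locally with the three variant names that appear in this diff, and `resolve_peer_addr` is a hypothetical helper, not a function from the codebase.

```rust
use std::net::{IpAddr, SocketAddr};

/// Local copy of the three policy names used in the diff; the real enum
/// lives in `crate::config` and may carry more than is shown here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ProxyProtocolV2 {
    Required,
    Supported,
    Rejected,
}

/// Decide which peer IP to record for a new connection.
///
/// `header_addr` is the source address parsed from a PROXY v2 header, if one
/// was sent; `socket_addr` is the address reported by `accept()`.
fn resolve_peer_addr(
    policy: ProxyProtocolV2,
    header_addr: Option<SocketAddr>,
    socket_addr: SocketAddr,
) -> Result<IpAddr, &'static str> {
    match (policy, header_addr) {
        // A header is mandatory but the client did not send one.
        (ProxyProtocolV2::Required, None) => Err("missing required proxy protocol header"),
        // A header was sent although the listener refuses them.
        (ProxyProtocolV2::Rejected, Some(_)) => Err("proxy protocol header not supported"),
        // Header present and allowed: trust the address it carries.
        (_, Some(addr)) => Ok(addr.ip()),
        // No header and none required: use the TCP peer address.
        (_, None) => Ok(socket_addr.ip()),
    }
}
```

Keeping the policy check in one pure function makes the four outcomes easy to unit test, whichever side of this diff ends up owning the check.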

View File

@@ -10,16 +10,16 @@ pub(crate) mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use crate::control_plane::provider::ControlPlaneBackend;
use crate::metrics::Protocol;
use crate::config::ProxyProtocolV2;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain},
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
@@ -31,7 +31,6 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::net::IpAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -62,7 +61,6 @@ pub async fn run_until_cancelled<F: std::future::Future>(
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -72,91 +70,109 @@ pub async fn task_main(
info!("proxy has shut down");
}
super::connection_loop(
config,
listener,
cancellation_token,
Protocol::Tcp,
C {
config,
auth_backend,
cancellation_handler,
endpoint_rate_limiter,
},
)
.await
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
#[derive(Clone)]
struct C {
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
let connections = tokio_util::task::task_tracker::TaskTracker::new();
impl super::ConnHandler for C {
async fn handle(
self,
session_id: uuid::Uuid,
peer_addr: IpAddr,
socket: crate::protocol2::ChainRW<tokio::net::TcpStream>,
conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
) {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Tcp,
&self.config.region,
);
let span = ctx.span();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let startup = Box::pin(
handle_client(
self.config,
self.auth_backend,
&ctx,
self.cancellation_handler,
socket,
ClientMode::Tcp,
self.endpoint_rate_limiter,
conn_gauge,
)
.instrument(span.clone()),
);
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let res = startup.await;
transition_connection(ctx, res).await;
}
}
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
pub(crate) async fn transition_connection<S: AsyncRead + AsyncWrite + Unpin>(
ctx: RequestMonitoring,
res: Result<Option<ProxyPassthrough<S>>, ClientRequestError>,
) {
let span = ctx.span();
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
error!(parent: &span, "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass().instrument(span.clone()).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
return;
}
Err(ErrorSource::Compute(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
error!("missing required proxy protocol header");
return;
}
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
error!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
Ok((socket, None)) => (socket, peer_addr.ip()),
};
match socket.inner.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!("per-client task finished with an error: failed to set socket option: {e:#}");
return;
}
};
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span();
let startup = Box::pin(
handle_client(
config,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
)
.instrument(span.clone()),
);
let res = startup.await;
match res {
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
error!(parent: &span, "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass().instrument(span.clone()).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
}
Err(ErrorSource::Compute(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
}
}
}
}
}
});
}
connections.close();
drop(listener);
// Drain connections
connections.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
@@ -227,17 +243,15 @@ impl ReportableError for ClientRequestError {
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
@@ -271,17 +285,21 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names);
let result = config
.auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
};
let user = user_info.user.clone();
let user_info = match auth_backend
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
user_info,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
@@ -335,7 +353,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
pub(crate) async fn prepare_client_connection<P>(
async fn prepare_client_connection<P>(
node: &compute::PostgresConnection,
session: &cancellation::Session<P>,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,

View File

@@ -1,5 +1,5 @@
use crate::{
cancellation::{self, CancellationHandlerMainInternal},
cancellation,
compute::PostgresConnection,
control_plane::messages::MetricsAuxInfo,
metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard},
@@ -57,17 +57,17 @@ pub(crate) async fn proxy_pass(
Ok(())
}
pub(crate) struct ProxyPassthrough<S> {
pub(crate) struct ProxyPassthrough<P, S> {
pub(crate) client: Stream<S>,
pub(crate) compute: PostgresConnection,
pub(crate) aux: MetricsAuxInfo,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
pub(crate) _cancel: cancellation::Session<CancellationHandlerMainInternal>,
pub(crate) _cancel: cancellation::Session<P>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {

View File

@@ -8,20 +8,18 @@ use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, TestBackend,
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
};
use crate::config::{CertResolver, ProxyProtocolV2, RetryConfig};
use crate::config::{CertResolver, RetryConfig};
use crate::control_plane::messages::{ControlPlaneError, Details, MetricsAuxInfo, Status};
use crate::control_plane::provider::{
CachedAllowedIps, CachedRoleSecret, ControlPlaneBackend, NodeInfoCache,
};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::protocol2::get_client_conn_info;
use crate::{sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use auth::backend::ControlPlaneComputeBackend;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
@@ -178,7 +176,7 @@ async fn dummy_proxy(
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = get_client_conn_info(client, ProxyProtocolV2::Supported).await?;
let (client, _) = read_proxy_protocol(client).await?;
let mut stream =
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
@@ -554,19 +552,19 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> ControlPlaneComputeBackend<'static> {
let api = Box::leak(Box::new(ControlPlaneBackend::Test(Box::new(
mechanism.clone(),
))));
api.attach_to_credentials(ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
) -> auth::Backend<'static, ComputeCredentials, &()> {
let user_info = auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneBackend::Test(Box::new(mechanism.clone()))),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::Password("password".into()),
},
keys: ComputeCredentialKeys::Password("password".into()),
})
);
user_info
}
#[tokio::test]

View File

@@ -7,7 +7,7 @@ use crate::metrics::{
WakeupFailureKind,
};
use crate::proxy::retry::{retry_after, should_retry};
use hyper1::StatusCode;
use hyper::StatusCode;
use tracing::{error, info, warn};
use super::connect_compute::ComputeConnectBackend;

View File

@@ -8,16 +8,16 @@ use tracing::{field::display, info};
use crate::{
auth::{
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError, ServerlessBackend,
check_peer_addr_is_in_list, AuthError,
},
compute,
config::ProxyConfig,
config::{AuthenticationConfig, ProxyConfig},
context::RequestMonitoring,
control_plane::{
errors::{GetAuthInfoError, WakeComputeError},
locks::ApiLocks,
provider::ApiLockError,
Api, CachedNodeInfo,
CachedNodeInfo,
},
error::{ErrorKind, ReportableError, UserFacingError},
intern::EndpointIdInt,
@@ -38,7 +38,6 @@ pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: ServerlessBackend<'static>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
@@ -46,20 +45,18 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
password: &[u8],
) -> Result<ComputeCredentials, AuthError> {
let cplane = match &self.auth_backend {
ServerlessBackend::ControlPlane(cplane) => cplane,
ServerlessBackend::Local(_local) => {
return Err(AuthError::bad_auth_method(
"password authentication not supported by local_proxy",
))
}
};
let (allowed_ips, maybe_secret) = cplane.get_allowed_ips_and_secret(ctx, user_info).await?;
if self.config.authentication_config.ip_allowlist_check_enabled
let user_info = user_info.clone();
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
@@ -72,12 +69,13 @@ impl PoolingBackend {
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => cplane.get_role_secret(ctx, user_info).await?,
None => backend.get_role_secret(ctx).await?,
};
let secret = match cached_secret.value.clone() {
Some(secret) => self.config.authentication_config.check_rate_limit(
ctx,
config,
secret,
&user_info.endpoint,
true,
@@ -89,13 +87,9 @@ impl PoolingBackend {
}
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,
password,
secret,
)
.await?;
let auth_outcome =
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
@@ -107,7 +101,7 @@ impl PoolingBackend {
}
};
res.map(|key| ComputeCredentials {
info: user_info.clone(),
info: user_info,
keys: key,
})
}
@@ -115,13 +109,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<(), AuthError> {
match &self.auth_backend {
ServerlessBackend::ControlPlane(console) => {
self.config
.authentication_config
match &self.config.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
config
.jwks_cache
.check_jwt(
ctx,
@@ -135,9 +129,11 @@ impl PoolingBackend {
Ok(())
}
ServerlessBackend::Local(_) => {
self.config
.authentication_config
crate::auth::Backend::ConsoleRedirect(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
config
.jwks_cache
.check_jwt(
ctx,
@@ -180,41 +176,21 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
match &self.auth_backend {
ServerlessBackend::ControlPlane(cplane) => {
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&cplane.attach_to_credentials(keys),
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
ServerlessBackend::Local(local_proxy) => {
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&**local_proxy,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
}
let backend = self.config.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
// Wake up the destination if needed
@@ -224,13 +200,6 @@ impl PoolingBackend {
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
let cplane = match &self.auth_backend {
ServerlessBackend::Local(_) => {
panic!("connect to local_proxy should not be called if we are already local_proxy")
}
ServerlessBackend::ControlPlane(cplane) => cplane,
};
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
@@ -239,11 +208,14 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = cplane.attach_to_credentials(ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
@@ -285,7 +257,7 @@ pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper1::Error),
H2(#[from] hyper::Error),
}
impl ReportableError for HttpConnError {
@@ -509,7 +481,7 @@ async fn connect_http2(
};
};
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
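
Review note on the hunks above: several call sites use `self.config.auth_backend.as_ref().map(|()| ...)` to turn the startup-time backend (payload `()`) into a per-request value carrying credentials, without matching on the variant at each site. The definition of `auth::Backend` is not part of this diff, so the sketch below is an assumption about its general shape: a simplified two-variant enum with `as_ref` and `map` helpers modelled on `Option`. The names and the variant set are hypothetical.

```rust
/// Hypothetical, simplified stand-in for `auth::Backend`: each variant holds
/// a backend handle `B`, and one of them also holds a payload `T` that
/// changes as authentication progresses.
#[derive(Debug, PartialEq)]
enum Backend<B, T> {
    ControlPlane(B, T),
    Local(B),
}

impl<B, T> Backend<B, T> {
    /// Borrow the handle and payload without consuming `self`,
    /// in the spirit of `Option::as_ref`.
    fn as_ref(&self) -> Backend<&B, &T> {
        match self {
            Backend::ControlPlane(b, t) => Backend::ControlPlane(b, t),
            Backend::Local(b) => Backend::Local(b),
        }
    }

    /// Swap the payload for a new one while keeping the variant and handle,
    /// in the spirit of `Option::map`.
    fn map<U>(self, f: impl FnOnce(T) -> U) -> Backend<B, U> {
        match self {
            Backend::ControlPlane(b, t) => Backend::ControlPlane(b, f(t)),
            Backend::Local(b) => Backend::Local(b),
        }
    }
}
```

With this shape, `backend.as_ref().map(|()| creds)` borrows the long-lived handle and attaches request-specific data in a single expression, which is what lets `connect_to_compute` be called once instead of once per variant.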

View File

@@ -1,5 +1,5 @@
use dashmap::DashMap;
use hyper1::client::conn::http2;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
@@ -18,9 +18,9 @@ use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
http2::Connection<TokioIo<TcpStream>, hyper::body::Incoming, TokioExecutor>;
#[derive(Clone)]
struct ConnPoolEntry {

View File

@@ -11,7 +11,7 @@ use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
@@ -67,12 +67,12 @@ impl HttpErrorBody {
fn response_from_msg_and_status(
msg: String,
status: StatusCode,
) -> Response<BoxBody<Bytes, hyper1::Error>> {
) -> Response<BoxBody<Bytes, hyper::Error>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper::Error>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
@@ -90,7 +90,7 @@ impl HttpErrorBody {
pub(crate) fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;


@@ -16,12 +16,13 @@ use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper1::body::Incoming;
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
@@ -31,29 +32,28 @@ use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use crate::auth::ServerlessBackend;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::metrics::{Metrics, Protocol};
use crate::protocol2::ChainRW;
use crate::metrics::Metrics;
use crate::protocol2::{read_proxy_protocol, ChainRW};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use std::net::IpAddr;
use std::net::{IpAddr, SocketAddr};
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, instrument, warn, Instrument};
use tracing::{error, info, warn, Instrument};
use utils::http::error::ApiError;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: ServerlessBackend<'static>,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -107,7 +107,6 @@ pub async fn task_main(
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool),
config,
auth_backend,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
@@ -123,100 +122,81 @@ pub async fn task_main(
}
};
let requests = TaskTracker::new();
requests.close(); // allows `requests.wait` to complete
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait` to complete
crate::connection_loop(
config,
ws_listener,
cancellation_token.clone(),
Protocol::Http,
C {
config,
backend,
cancellation_handler,
endpoint_rate_limiter,
tls_acceptor,
requests: requests.clone(),
cancellation_token,
},
)
.await?;
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodelay: {e}");
continue;
}
let conn_id = uuid::Uuid::new_v4();
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
requests.wait().await;
Ok(())
}
#[derive(Clone)]
struct C {
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
requests: TaskTracker,
cancellation_token: CancellationToken,
}
impl super::ConnHandler for C {
#[instrument(name = "http_conn", skip_all, fields(conn_id))]
async fn handle(
self,
conn_id: uuid::Uuid,
peer_addr: IpAddr,
stream: ChainRW<TcpStream>,
conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
) {
// try and close an old HTTP connection.
// picked at random
let n_connections = Metrics::get()
.proxy
.client_connections
.sample(crate::metrics::Protocol::Http);
tracing::trace!(?n_connections, threshold = ?self.config.http_config.client_conn_threshold, "check");
if n_connections > self.config.http_config.client_conn_threshold {
tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
if n_connections > config.http_config.client_conn_threshold {
tracing::trace!("attempting to cancel a random connection");
if let Some(token) = self.config.http_config.cancel_set.take() {
if let Some(token) = config.http_config.cancel_set.take() {
tracing::debug!("cancelling a random connection");
token.cancel();
}
}
let conn_token = self.cancellation_token.child_token();
let _cancel_guard = self
.config
.http_config
.cancel_set
.insert(conn_id, conn_token.clone());
let conn_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
let startup_result = Box::pin(connection_startup(
self.config,
self.tls_acceptor,
conn_id,
stream,
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
return;
};
let session_id = uuid::Uuid::new_v4();
Box::pin(connection_handler(
self.config,
self.backend,
self.requests,
self.cancellation_handler,
self.endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
conn_id,
))
.await;
let _gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Http);
drop(conn_gauge);
let startup_result = Box::pin(connection_startup(
config,
tls_acceptor,
session_id,
conn,
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
return;
};
Box::pin(connection_handler(
config,
backend,
connections2,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
session_id,
))
.await;
}
.instrument(http_conn_span),
);
}
connections.wait().await;
Ok(())
}
pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
@@ -244,14 +224,26 @@ impl MaybeTlsAcceptor for NoTls {
}
}
/// Handles the TLS startup handshake.
/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
async fn connection_startup(
config: &ProxyConfig,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
session_id: uuid::Uuid,
conn: ChainRW<TcpStream>,
peer_addr: IpAddr,
conn: TcpStream,
peer_addr: SocketAddr,
) -> Option<(AsyncRW, IpAddr)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
Err(e) => {
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return None;
}
};
let peer_addr = peer.unwrap_or(peer_addr).ip();
let has_private_peer_addr = match peer_addr {
IpAddr::V4(ip) => ip.is_private(),
IpAddr::V6(_) => false,
@@ -310,7 +302,7 @@ async fn connection_handler(
let server = Builder::new(TokioExecutor::new());
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
// First HTTP request shares the same session ID
let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
@@ -363,7 +355,7 @@ async fn connection_handler(
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: hyper1::Request<Incoming>,
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
@@ -373,7 +365,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
.get("host")
@@ -385,10 +377,6 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ServerlessBackend::ControlPlane(auth_backend) = backend.auth_backend else {
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
};
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
@@ -406,7 +394,6 @@ async fn request_handler(
async move {
if let Err(e) = websocket::serve_websocket(
config,
auth_backend,
ctx,
websocket,
cancellation_handler,
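
The peer-address handling that `connection_startup` gains in this file can be sketched with the standard library alone. This is a minimal illustration, not the proxy's code: `effective_peer_ip` is a hypothetical helper name, and the `Option<SocketAddr>` argument stands in for whatever `read_proxy_protocol` returns when a PROXY protocol V2 header is present.

```rust
use std::net::{IpAddr, SocketAddr};

/// Hypothetical helper: prefer the address carried in a PROXY protocol
/// header, and fall back to the TCP socket's peer address otherwise.
fn effective_peer_ip(proxy_header_addr: Option<SocketAddr>, socket_addr: SocketAddr) -> IpAddr {
    proxy_header_addr.unwrap_or(socket_addr).ip()
}

/// Same shape as the match in the diff: only IPv4 addresses are ever
/// classified as private; every IPv6 address is treated as non-private.
fn has_private_peer_addr(peer_addr: IpAddr) -> bool {
    match peer_addr {
        IpAddr::V4(ip) => ip.is_private(),
        IpAddr::V6(_) => false,
    }
}
```

One consequence worth noting: an IPv6 unique-local address such as `fd00::1` is reported as non-private by this check, because the `V6` arm always returns `false`.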


@@ -12,14 +12,14 @@ use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
use hyper1::body::Incoming;
use hyper1::header;
use hyper1::http::HeaderName;
use hyper1::http::HeaderValue;
use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use hyper::body::Body;
use hyper::body::Incoming;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{HeaderMap, Request};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
@@ -45,7 +45,6 @@ use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::HttpConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -273,7 +272,7 @@ pub(crate) async fn handle(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
@@ -436,7 +435,7 @@ impl UserFacingError for SqlOverHttpError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper1::Error),
Read(#[from] hyper::Error),
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
@@ -477,7 +476,7 @@ struct HttpHeaders {
}
impl HttpHeaders {
fn try_parse(headers: &hyper1::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
@@ -530,7 +529,7 @@ async fn handle_inner(
ctx: &RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
@@ -553,7 +552,7 @@ async fn handle_inner(
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
@@ -578,7 +577,7 @@ async fn handle_db_inner(
conn_info: ConnInfo,
auth: AuthData,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
@@ -624,12 +623,22 @@ async fn handle_db_inner(
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await?;
ComputeCredentials {
@@ -671,7 +680,7 @@ async fn handle_db_inner(
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
stmt.process(config, cancel, &mut client, parsed_headers)
.await?
}
Payload::Batch(statements) => {
@@ -689,7 +698,7 @@ async fn handle_db_inner(
}
statements
.process(&config.http_config, cancel, &mut client, parsed_headers)
.process(config, cancel, &mut client, parsed_headers)
.await?
}
};
@@ -729,14 +738,20 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
];
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
@@ -774,7 +789,7 @@ async fn handle_auth_broker_inner(
impl QueryData {
async fn process(
self,
config: &'static HttpConfig,
config: &'static ProxyConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -848,7 +863,7 @@ impl QueryData {
impl BatchQueryData {
async fn process(
self,
config: &'static HttpConfig,
config: &'static ProxyConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -918,7 +933,7 @@ impl BatchQueryData {
}
async fn query_batch(
config: &'static HttpConfig,
config: &'static ProxyConfig,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
@@ -957,7 +972,7 @@ async fn query_batch(
}
async fn query_to_json<T: GenericClient>(
config: &'static HttpConfig,
config: &'static ProxyConfig,
client: &T,
data: QueryData,
current_size: &mut usize,
@@ -978,9 +993,9 @@ async fn query_to_json<T: GenericClient>(
rows.push(row);
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
if *current_size > config.max_response_size_bytes {
if *current_size > config.http_config.max_response_size_bytes {
return Err(SqlOverHttpError::ResponseTooLarge(
config.max_response_size_bytes,
config.http_config.max_response_size_bytes,
));
}
}
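
The size guard in `query_to_json` can be illustrated in isolation. The sketch below is an assumption-laden stand-in: `collect_rows`, `SizeError`, and the use of `&str` rows are hypothetical, and the way `current_size` is incremented is a guess, since that line is outside the hunk. Only the order of operations (push the row, then compare against `max_response_size_bytes`) follows the diff.

```rust
/// Hypothetical stand-in for `SqlOverHttpError::ResponseTooLarge`.
#[derive(Debug, PartialEq)]
enum SizeError {
    ResponseTooLarge(usize),
}

/// Accumulate rows while tracking the running response size. The counter is
/// shared across calls (as in batch queries), so it is passed by `&mut`.
fn collect_rows(
    rows: &[&str],
    current_size: &mut usize,
    max_response_size_bytes: usize,
) -> Result<Vec<String>, SizeError> {
    let mut out = Vec::new();
    for row in rows {
        // Assumption: each row contributes its byte length to the total.
        *current_size += row.len();
        out.push((*row).to_owned());
        // No streaming response, so stop before buffering an unbounded result.
        if *current_size > max_response_size_bytes {
            return Err(SizeError::ResponseTooLarge(max_response_size_bytes));
        }
    }
    Ok(out)
}
```

Because the comparison is strict (`>`), a response whose size equals the limit exactly is still accepted.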


@@ -1,4 +1,3 @@
use crate::control_plane::provider::ControlPlaneBackend;
use crate::proxy::ErrorSource;
use crate::{
cancellation::CancellationHandlerMain,
@@ -13,7 +12,7 @@ use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper1::upgrade::OnUpgrade;
use hyper::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
@@ -130,7 +129,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
ctx: RequestMonitoring,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -147,7 +145,6 @@ pub(crate) async fn serve_websocket(
let res = Box::pin(handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),


@@ -485,49 +485,51 @@ async fn upload_events_chunk(
#[cfg(test)]
mod tests {
use std::{
net::TcpListener,
sync::{Arc, Mutex},
};
use super::*;
use crate::{http, BranchId, EndpointId};
use anyhow::Error;
use chrono::Utc;
use consumption_metrics::{Event, EventChunk};
use hyper::{
service::{make_service_fn, service_fn},
Body, Response,
};
use http_body_util::BodyExt;
use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response};
use hyper_util::rt::TokioIo;
use std::sync::{Arc, Mutex};
use tokio::net::TcpListener;
use url::Url;
use super::*;
use crate::{http, BranchId, EndpointId};
#[tokio::test]
async fn metrics() {
let listener = TcpListener::bind("0.0.0.0:0").unwrap();
type Report = EventChunk<'static, Event<Ids, String>>;
let reports: Arc<Mutex<Vec<Report>>> = Arc::default();
let reports = Arc::new(Mutex::new(vec![]));
let reports2 = reports.clone();
let server = hyper::server::Server::from_tcp(listener)
.unwrap()
.serve(make_service_fn(move |_| {
let reports = reports.clone();
async move {
Ok::<_, Error>(service_fn(move |req| {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn({
let reports = reports.clone();
async move {
loop {
if let Ok((stream, _addr)) = listener.accept().await {
let reports = reports.clone();
async move {
let bytes = hyper::body::to_bytes(req.into_body()).await?;
let events: EventChunk<'static, Event<Ids, String>> =
serde_json::from_slice(&bytes)?;
reports.lock().unwrap().push(events);
Ok::<_, Error>(Response::new(Body::from(vec![])))
}
}))
http1::Builder::new()
.serve_connection(
TokioIo::new(stream),
service_fn(move |req: Request<Incoming>| {
let reports = reports.clone();
async move {
let bytes = req.into_body().collect().await?.to_bytes();
let events = serde_json::from_slice(&bytes)?;
reports.lock().unwrap().push(events);
Ok::<_, Error>(Response::new(String::new()))
}
}),
)
.await
.unwrap();
}
}
}));
let addr = server.local_addr();
tokio::spawn(server);
}
});
let metrics = Metrics::default();
let client = http::new_client();
@@ -536,7 +538,7 @@ mod tests {
// no counters have been registered
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
let r = std::mem::take(&mut *reports2.lock().unwrap());
let r = std::mem::take(&mut *reports.lock().unwrap());
assert!(r.is_empty());
// register a new counter
@@ -548,7 +550,7 @@ mod tests {
// the counter should be observed despite 0 egress
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
let r = std::mem::take(&mut *reports2.lock().unwrap());
let r = std::mem::take(&mut *reports.lock().unwrap());
assert_eq!(r.len(), 1);
assert_eq!(r[0].events.len(), 1);
assert_eq!(r[0].events[0].value, 0);
@@ -558,7 +560,7 @@ mod tests {
// egress should be observed
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
let r = std::mem::take(&mut *reports2.lock().unwrap());
let r = std::mem::take(&mut *reports.lock().unwrap());
assert_eq!(r.len(), 1);
assert_eq!(r[0].events.len(), 1);
assert_eq!(r[0].events[0].value, 1);
@@ -568,7 +570,7 @@ mod tests {
// we do not observe the counter
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
let r = std::mem::take(&mut *reports2.lock().unwrap());
let r = std::mem::take(&mut *reports.lock().unwrap());
assert!(r.is_empty());
// counter is unregistered


@@ -97,5 +97,8 @@ select = [
"I", # isort
"W", # pycodestyle
"B", # bugbear
"UP032", # f-string
"UP", # pyupgrade
]
[tool.ruff.lint.pyupgrade]
keep-runtime-typing = true # Remove this stanza when we require Python 3.10


@@ -23,6 +23,7 @@ crc32c.workspace = true
fail.workspace = true
hex.workspace = true
humantime.workspace = true
http.workspace = true
hyper0.workspace = true
futures.workspace = true
once_cell.workspace = true


@@ -12,8 +12,8 @@ use metrics::{
core::{AtomicU64, Collector, Desc, GenericCounter, GenericGaugeVec, Opts},
proto::MetricFamily,
register_histogram_vec, register_int_counter, register_int_counter_pair,
register_int_counter_pair_vec, register_int_counter_vec, Gauge, HistogramVec, IntCounter,
IntCounterPair, IntCounterPairVec, IntCounterVec, IntGaugeVec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, Gauge,
HistogramVec, IntCounter, IntCounterPair, IntCounterPairVec, IntCounterVec, IntGaugeVec,
};
use once_cell::sync::Lazy;
@@ -231,6 +231,14 @@ pub(crate) static EVICTION_EVENTS_COMPLETED: Lazy<IntCounterVec> = Lazy::new(||
.expect("Failed to register metric")
});
pub static NUM_EVICTED_TIMELINES: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"safekeeper_evicted_timelines",
"Number of currently evicted timelines"
)
.expect("Failed to register metric")
});
pub const LABEL_UNKNOWN: &str = "unknown";
/// Labels for traffic metrics.


@@ -631,13 +631,19 @@ impl Timeline {
return Err(e);
}
self.bootstrap(conf, broker_active_set, partial_backup_rate_limiter);
self.bootstrap(
shared_state,
conf,
broker_active_set,
partial_backup_rate_limiter,
);
Ok(())
}
/// Bootstrap new or existing timeline starting background tasks.
pub fn bootstrap(
self: &Arc<Timeline>,
_shared_state: &mut WriteGuardSharedState<'_>,
conf: &SafeKeeperConf,
broker_active_set: Arc<TimelinesSet>,
partial_backup_rate_limiter: RateLimiter,


@@ -15,7 +15,9 @@ use tracing::{debug, info, instrument, warn};
use utils::crashsafe::durable_rename;
use crate::{
metrics::{EvictionEvent, EVICTION_EVENTS_COMPLETED, EVICTION_EVENTS_STARTED},
metrics::{
EvictionEvent, EVICTION_EVENTS_COMPLETED, EVICTION_EVENTS_STARTED, NUM_EVICTED_TIMELINES,
},
rate_limit::rand_duration,
timeline_manager::{Manager, StateSnapshot},
wal_backup,
@@ -93,6 +95,7 @@ impl Manager {
}
info!("successfully evicted timeline");
NUM_EVICTED_TIMELINES.inc();
}
/// Attempt to restore evicted timeline from remote storage; it must be
@@ -128,6 +131,7 @@ impl Manager {
tokio::time::Instant::now() + rand_duration(&self.conf.eviction_min_resident);
info!("successfully restored evicted timeline");
NUM_EVICTED_TIMELINES.dec();
}
}


@@ -25,7 +25,10 @@ use utils::lsn::Lsn;
use crate::{
control_file::{FileStorage, Storage},
metrics::{MANAGER_ACTIVE_CHANGES, MANAGER_ITERATIONS_TOTAL, MISC_OPERATION_SECONDS},
metrics::{
MANAGER_ACTIVE_CHANGES, MANAGER_ITERATIONS_TOTAL, MISC_OPERATION_SECONDS,
NUM_EVICTED_TIMELINES,
},
rate_limit::{rand_duration, RateLimiter},
recovery::recovery_main,
remove_wal::calc_horizon_lsn,
@@ -251,6 +254,11 @@ pub async fn main_task(
mgr.recovery_task = Some(tokio::spawn(recovery_main(tli, mgr.conf.clone())));
}
// If timeline is evicted, reflect that in the metric.
if mgr.is_offloaded {
NUM_EVICTED_TIMELINES.inc();
}
let last_state = 'outer: loop {
MANAGER_ITERATIONS_TOTAL.inc();
@@ -367,6 +375,11 @@ pub async fn main_task(
mgr.update_wal_removal_end(res);
}
// If timeline is deleted while evicted decrement the gauge.
if mgr.tli.is_cancelled() && mgr.is_offloaded {
NUM_EVICTED_TIMELINES.dec();
}
mgr.set_status(Status::Finished);
}
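
The four places that touch `NUM_EVICTED_TIMELINES` (increment on eviction, decrement on restore, increment at manager start for an already-offloaded timeline, decrement when a timeline is deleted while evicted) are meant to keep the gauge equal to the number of currently evicted timelines. The sketch below models that bookkeeping with the standard library only. `Gauge` and `TimelineManager` are hypothetical stand-ins, not the safekeeper's types, and the guards inside `evict` and `restore` are an added assumption.

```rust
use std::sync::atomic::{AtomicI64, Ordering};

/// Hypothetical stand-in for a Prometheus integer gauge.
struct Gauge(AtomicI64);

impl Gauge {
    const fn new() -> Self {
        Gauge(AtomicI64::new(0))
    }
    fn inc(&self) {
        self.0.fetch_add(1, Ordering::Relaxed);
    }
    fn dec(&self) {
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
    fn get(&self) -> i64 {
        self.0.load(Ordering::Relaxed)
    }
}

/// Hypothetical per-timeline manager that owns the `is_offloaded` flag.
struct TimelineManager<'a> {
    evicted: &'a Gauge,
    is_offloaded: bool,
}

impl<'a> TimelineManager<'a> {
    /// On startup, a timeline that is already evicted must be counted.
    fn start(evicted: &'a Gauge, is_offloaded: bool) -> Self {
        if is_offloaded {
            evicted.inc();
        }
        Self { evicted, is_offloaded }
    }

    /// Successful eviction: count it once.
    fn evict(&mut self) {
        if !self.is_offloaded {
            self.is_offloaded = true;
            self.evicted.inc();
        }
    }

    /// Successful restore from remote storage: stop counting it.
    fn restore(&mut self) {
        if self.is_offloaded {
            self.is_offloaded = false;
            self.evicted.dec();
        }
    }

    /// Manager exit: only a timeline deleted while evicted leaves the count.
    fn finish(self, deleted: bool) {
        if deleted && self.is_offloaded {
            self.evicted.dec();
        }
    }
}
```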


@@ -165,12 +165,14 @@ impl GlobalTimelines {
match Timeline::load_timeline(&conf, ttid) {
Ok(timeline) => {
let tli = Arc::new(timeline);
let mut shared_state = tli.write_shared_state().await;
TIMELINES_STATE
.lock()
.unwrap()
.timelines
.insert(ttid, tli.clone());
tli.bootstrap(
&mut shared_state,
&conf,
broker_active_set.clone(),
partial_backup_rate_limiter.clone(),
@@ -213,6 +215,7 @@ impl GlobalTimelines {
match Timeline::load_timeline(&conf, ttid) {
Ok(timeline) => {
let tli = Arc::new(timeline);
let mut shared_state = tli.write_shared_state().await;
// TODO: prevent concurrent timeline creation/loading
{
@@ -227,8 +230,13 @@ impl GlobalTimelines {
state.timelines.insert(ttid, tli.clone());
}
tli.bootstrap(&conf, broker_active_set, partial_backup_rate_limiter);
tli.bootstrap(
&mut shared_state,
&conf,
broker_active_set,
partial_backup_rate_limiter,
);
drop(shared_state);
Ok(tli)
}
// If we can't load a timeline, it's bad. Caller will figure it out.


@@ -17,7 +17,9 @@ use std::time::Duration;
use postgres_ffi::v14::xlog_utils::XLogSegNoOffsetToRecPtr;
use postgres_ffi::XLogFileName;
use postgres_ffi::{XLogSegNo, PG_TLI};
use remote_storage::{GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata};
use remote_storage::{
DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata,
};
use tokio::fs::File;
use tokio::select;
@@ -503,8 +505,12 @@ pub async fn read_object(
let cancel = CancellationToken::new();
let opts = DownloadOpts {
byte_start: std::ops::Bound::Included(offset),
..Default::default()
};
let download = storage
.download_storage_object(Some((offset, None)), file_path, &cancel)
.download(file_path, &opts, &cancel)
.await
.with_context(|| {
format!("Failed to open WAL segment download stream for remote path {file_path:?}")
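
The switch from `Some((offset, None))` to a `DownloadOpts` value expresses the byte range with `std::ops::Bound`. The sketch below shows how such an options struct could be interpreted against an in-memory object. It is an illustration under assumptions: this `DownloadOpts` is a hypothetical two-field stand-in (the `byte_end` field and the `Unbounded` defaults are assumed, not taken from the diff), and `select_bytes` is an invented helper rather than anything in `remote_storage`.

```rust
use std::ops::Bound;

/// Hypothetical stand-in for the options struct used in the diff.
#[derive(Debug, Clone, PartialEq)]
struct DownloadOpts {
    byte_start: Bound<u64>,
    byte_end: Bound<u64>, // assumed field; not visible in the diff
}

impl Default for DownloadOpts {
    /// Assumed default: download the whole object.
    fn default() -> Self {
        Self {
            byte_start: Bound::Unbounded,
            byte_end: Bound::Unbounded,
        }
    }
}

/// Invented helper: apply the requested range to an in-memory object,
/// clamping to the object's length so out-of-range requests yield less data.
fn select_bytes<'a>(object: &'a [u8], opts: &DownloadOpts) -> &'a [u8] {
    let start = match opts.byte_start {
        Bound::Included(s) => s as usize,
        Bound::Excluded(s) => s as usize + 1,
        Bound::Unbounded => 0,
    };
    let end = match opts.byte_end {
        Bound::Included(e) => e as usize + 1,
        Bound::Excluded(e) => e as usize,
        Bound::Unbounded => object.len(),
    };
    let end = end.min(object.len());
    let start = start.min(end);
    &object[start..end]
}
```

With struct-update syntax, a caller that only cares about the starting offset writes `DownloadOpts { byte_start: Bound::Included(offset), ..Default::default() }`, which is the shape used in `read_object`.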


@@ -13,7 +13,7 @@ use desim::{
node_os::NodeOs,
proto::{AnyMessage, NetEvent, NodeEvent},
};
use hyper0::Uri;
use http::Uri;
use safekeeper::{
safekeeper::{ProposerAcceptorMessage, SafeKeeper, ServerInfo, UNKNOWN_SERVER_VERSION},
state::{TimelinePersistentState, TimelineState},


@@ -1,9 +1,10 @@
#! /usr/bin/env python3
from __future__ import annotations
import argparse
import json
import logging
from typing import Dict
import psycopg2
import psycopg2.extras
@@ -110,7 +111,7 @@ def main(args: argparse.Namespace):
output = args.output
percentile = args.percentile
res: Dict[str, float] = {}
res: dict[str, float] = {}
try:
logging.info("connecting to the database...")


@@ -4,6 +4,9 @@
#
# This can be useful in disaster recovery.
#
from __future__ import annotations
import argparse
import psycopg2


@@ -1,16 +1,21 @@
#! /usr/bin/env python3
from __future__ import annotations
import argparse
import json
import logging
import os
from collections import defaultdict
from typing import Any, DefaultDict, Dict, Optional
from typing import TYPE_CHECKING
import psycopg2
import psycopg2.extras
import toml
if TYPE_CHECKING:
from typing import Any, Optional
FLAKY_TESTS_QUERY = """
SELECT
DISTINCT parent_suite, suite, name
@@ -33,7 +38,7 @@ def main(args: argparse.Namespace):
build_type = args.build_type
pg_version = args.pg_version
res: DefaultDict[str, DefaultDict[str, Dict[str, bool]]]
res: defaultdict[str, defaultdict[str, dict[str, bool]]]
res = defaultdict(lambda: defaultdict(dict))
try:
@@ -60,7 +65,7 @@ def main(args: argparse.Namespace):
pageserver_virtual_file_io_engine_parameter = ""
# re-use existing records of flaky tests from before parametrization by compaction_algorithm
def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[Dict[str, Any]]:
def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict[str, Any]]:
"""Duplicated from parametrize.py"""
toml_table = os.getenv("PAGESERVER_DEFAULT_TENANT_CONFIG_COMPACTION_ALGORITHM")
if toml_table is None:


@@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import asyncio
import json
@@ -5,11 +7,15 @@ import logging
import signal
import sys
from collections import defaultdict
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Tuple
from typing import TYPE_CHECKING
import aiohttp
if TYPE_CHECKING:
from typing import Any
class ClientException(Exception):
pass
@@ -89,7 +95,7 @@ class Client:
class Completed:
"""The status dict returned by the API"""
status: Dict[str, Any]
status: dict[str, Any]
sigint_received = asyncio.Event()
@@ -179,7 +185,7 @@ async def main_impl(args, report_out, client: Client):
"""
Returns OS exit status.
"""
tenant_and_timline_ids: List[Tuple[str, str]] = []
tenant_and_timline_ids: list[tuple[str, str]] = []
# fill tenant_and_timline_ids based on spec
for spec in args.what:
comps = spec.split(":")
@@ -215,14 +221,14 @@ async def main_impl(args, report_out, client: Client):
tenant_and_timline_ids = tmp
logging.info("create tasks and process them at specified concurrency")
task_q: asyncio.Queue[Tuple[str, Awaitable[Any]]] = asyncio.Queue()
task_q: asyncio.Queue[tuple[str, Awaitable[Any]]] = asyncio.Queue()
tasks = {
f"{tid}:{tlid}": do_timeline(client, tid, tlid) for tid, tlid in tenant_and_timline_ids
}
for task in tasks.items():
task_q.put_nowait(task)
result_q: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue()
result_q: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()
taskq_handlers = []
for _ in range(0, args.concurrent_tasks):
taskq_handlers.append(taskq_handler(task_q, result_q))


@@ -1,4 +1,7 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import logging


@@ -1,5 +1,7 @@
#! /usr/bin/env python3
from __future__ import annotations
import argparse
import dataclasses
import json
@@ -11,7 +13,6 @@ from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Tuple
import backoff
import psycopg2
@@ -91,7 +92,7 @@ def create_table(cur):
cur.execute(CREATE_TABLE)
def parse_test_name(test_name: str) -> Tuple[str, int, str]:
def parse_test_name(test_name: str) -> tuple[str, int, str]:
build_type, pg_version = None, None
if match := TEST_NAME_RE.search(test_name):
found = match.groupdict()


@@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import logging
import os


@@ -10,13 +10,16 @@ bench = []
[dependencies]
anyhow.workspace = true
async-stream.workspace = true
bytes.workspace = true
clap = { workspace = true, features = ["derive"] }
const_format.workspace = true
futures.workspace = true
futures-core.workspace = true
futures-util.workspace = true
humantime.workspace = true
hyper0 = { workspace = true, features = ["full"] }
hyper = { workspace = true, features = ["full"] }
http-body-util.workspace = true
hyper-util = "0.1"
once_cell.workspace = true
parking_lot.workspace = true
prost.workspace = true


@@ -10,16 +10,15 @@
//!
//! Only safekeeper message is supported, but it is not hard to add something
//! else with generics.
extern crate hyper0 as hyper;
use clap::{command, Parser};
use futures_core::Stream;
use futures_util::StreamExt;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::header::CONTENT_TYPE;
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, StatusCode};
use hyper::service::service_fn;
use hyper::{Method, StatusCode};
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::convert::Infallible;
@@ -27,9 +26,11 @@ use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use tokio::time;
use tonic::body::{self, empty_body, BoxBody};
use tonic::codegen::Service;
use tonic::transport::server::Connected;
use tonic::Code;
@@ -48,9 +49,7 @@ use storage_broker::proto::{
FilterTenantTimelineId, MessageType, SafekeeperDiscoveryRequest, SafekeeperDiscoveryResponse,
SafekeeperTimelineInfo, SubscribeByFilterRequest, SubscribeSafekeeperInfoRequest, TypedMessage,
};
use storage_broker::{
parse_proto_ttid, EitherBody, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_LISTEN_ADDR,
};
use storage_broker::{parse_proto_ttid, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_LISTEN_ADDR};
use utils::id::TenantTimelineId;
use utils::logging::{self, LogFormat};
use utils::sentry_init::init_sentry;
@@ -602,8 +601,8 @@ impl BrokerService for Broker {
// We serve only metrics and healthcheck through http1.
async fn http1_handler(
req: hyper::Request<hyper::body::Body>,
) -> Result<hyper::Response<Body>, Infallible> {
req: hyper::Request<Incoming>,
) -> Result<hyper::Response<BoxBody>, Infallible> {
let resp = match (req.method(), req.uri().path()) {
(&Method::GET, "/metrics") => {
let mut buffer = vec![];
@@ -614,16 +613,16 @@ async fn http1_handler(
hyper::Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.body(body::boxed(Full::new(bytes::Bytes::from(buffer))))
.unwrap()
}
(&Method::GET, "/status") => hyper::Response::builder()
.status(StatusCode::OK)
.body(Body::empty())
.body(empty_body())
.unwrap(),
_ => hyper::Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.body(empty_body())
.unwrap(),
};
Ok(resp)
@@ -665,52 +664,76 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
};
let storage_broker_server = BrokerServiceServer::new(storage_broker_impl);
info!("listening on {}", &args.listen_addr);
// grpc is served along with http1 for metrics on a single port, hence we
// don't use tonic's Server.
hyper::Server::bind(&args.listen_addr)
.http2_keep_alive_interval(Some(args.http2_keepalive_interval))
.serve(make_service_fn(move |conn: &AddrStream| {
let storage_broker_server_cloned = storage_broker_server.clone();
let connect_info = conn.connect_info();
async move {
Ok::<_, Infallible>(service_fn(move |mut req| {
// That's what tonic's MakeSvc.call does to pass conninfo to
// the request handler (and where its request.remote_addr()
// expects to find it).
req.extensions_mut().insert(connect_info.clone());
// Technically this second clone is not needed, but having the
// async block consume the value is apparently unavoidable. BTW,
// the error message is enigmatic, see
// https://github.com/rust-lang/rust/issues/68119
//
// We could get away without the async block at all, but then we
// would need to resort to futures::Either to merge the result,
// which is not as easy on the eye.
let mut storage_broker_server_svc = storage_broker_server_cloned.clone();
async move {
if req.headers().get("content-type").map(|x| x.as_bytes())
== Some(b"application/grpc")
{
let res_resp = storage_broker_server_svc.call(req).await;
// Grpc and http1 handlers have slightly different
// Response types: it is UnsyncBoxBody for the
// former one (not sure why) and plain hyper::Body
// for the latter. Both implement HttpBody though,
// and EitherBody is used to merge them.
res_resp.map(|resp| resp.map(EitherBody::Left))
} else {
let res_resp = http1_handler(req).await;
res_resp.map(|resp| resp.map(EitherBody::Right))
}
}
}))
let tcp_listener = TcpListener::bind(&args.listen_addr).await?;
info!("listening on {}", &args.listen_addr);
loop {
let (stream, addr) = match tcp_listener.accept().await {
Ok(v) => v,
Err(e) => {
info!("couldn't accept connection: {e}");
continue;
}
}))
.await?;
Ok(())
};
let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
builder.http1().timer(TokioTimer::new());
builder
.http2()
.timer(TokioTimer::new())
.keep_alive_interval(Some(args.http2_keepalive_interval))
// This matches the tonic server default. It allows us to support production-like workloads.
.max_concurrent_streams(None);
let storage_broker_server_cloned = storage_broker_server.clone();
let connect_info = stream.connect_info();
let service_fn_ = async move {
service_fn(move |mut req| {
// That's what tonic's MakeSvc.call does to pass conninfo to
// the request handler (and where its request.remote_addr()
// expects to find it).
req.extensions_mut().insert(connect_info.clone());
// Technically this second clone is not needed, but having the
// async block consume the value is apparently unavoidable. BTW,
// the error message is enigmatic, see
// https://github.com/rust-lang/rust/issues/68119
//
// We could get away without the async block at all, but then we
// would need to resort to futures::Either to merge the result,
// which is not as easy on the eye.
let mut storage_broker_server_svc = storage_broker_server_cloned.clone();
async move {
if req.headers().get("content-type").map(|x| x.as_bytes())
== Some(b"application/grpc")
{
let res_resp = storage_broker_server_svc.call(req).await;
// Grpc and http1 handlers have slightly different
// Response types: it is UnsyncBoxBody for the
// former one (not sure why) and plain hyper::Body
// for the latter. Both implement HttpBody though,
// and `Either` is used to merge them.
res_resp.map(|resp| resp.map(http_body_util::Either::Left))
} else {
let res_resp = http1_handler(req).await;
res_resp.map(|resp| resp.map(http_body_util::Either::Right))
}
}
})
}
.await;
tokio::task::spawn(async move {
let res = builder
.serve_connection(TokioIo::new(stream), service_fn_)
.await;
if let Err(e) = res {
info!("error serving connection from {addr}: {e}");
}
});
}
}
#[cfg(test)]
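
The hunk above serves gRPC and plain HTTP/1 on one port by checking the `content-type` header of each request. The dispatch rule can be looked at in isolation with the minimal, std-only sketch below. The names `Route` and `route_for` are hypothetical and do not appear in the change; the sketch only assumes the same exact-match comparison against `application/grpc` that the diff uses.

```rust
/// Hypothetical enum naming the two handlers the diff dispatches to.
#[derive(Debug, PartialEq, Eq)]
enum Route {
    /// Forward to the tonic-generated gRPC service.
    Grpc,
    /// Forward to the plain HTTP/1 handler (metrics, status).
    Http1,
}

/// Pick a route from the raw `content-type` header bytes, if present.
///
/// Mirrors the comparison in the diff: only an exact match on
/// `application/grpc` selects the gRPC branch; everything else,
/// including a missing header, falls through to HTTP/1.
fn route_for(content_type: Option<&[u8]>) -> Route {
    if content_type == Some("application/grpc".as_bytes()) {
        Route::Grpc
    } else {
        Route::Http1
    }
}
```

Because the comparison is an exact match, a value such as `application/grpc+proto` would take the HTTP/1 branch in this sketch. Whether that matters depends on which content types the broker's clients actually send, which the diff does not state.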
