Compare commits

...

53 Commits

Author SHA1 Message Date
Conrad Ludgate
4673dd6d29 set thp:never during build 2024-10-03 14:54:22 +01:00
Conrad Ludgate
37f5a6434b disable thp 2024-10-03 14:29:52 +01:00
Conrad Ludgate
6588edd693 log 2024-10-01 16:11:54 +01:00
Conrad Ludgate
973eb69cd3 try reduce memory usage of returned data 2024-10-01 13:17:17 +01:00
Conrad Ludgate
8bb2127a19 fix str 2024-10-01 13:06:24 +01:00
Conrad Ludgate
b5ad693a87 some start to using arenas 2024-10-01 13:06:24 +01:00
Conrad Ludgate
5a9138a764 support seeded deser 2024-10-01 12:52:45 +01:00
Conrad Ludgate
1466767571 share json parse fn 2024-10-01 12:49:03 +01:00
Conrad Ludgate
f11254f2c5 deduplicate even more 2024-10-01 12:44:41 +01:00
Conrad Ludgate
4529b463b5 array parsing by value 2024-10-01 12:44:41 +01:00
Conrad Ludgate
a8d4634191 move by value 2024-10-01 12:44:41 +01:00
Conrad Ludgate
53de382533 fold in the json params parsing 2024-10-01 12:44:41 +01:00
Conrad Ludgate
05f7fc4a06 split out 2024-10-01 12:44:39 +01:00
Conrad Ludgate
6946325596 remove duplication 2024-10-01 12:43:07 +01:00
Conrad Ludgate
b41070ba53 proxy: refactor untagged enum parsing with manually implemented deserialize 2024-10-01 12:43:07 +01:00
Conrad Ludgate
4391b25d01 proxy: ignore typ and use jwt.alg rather than jwk.alg (#9215)
Microsoft exposes JWKs without the alg header. It's only included on the
tokens. Not a problem.

Also noticed that wrt the `typ` header:
> It will typically not be used by applications when it is already known
that the object is a JWT. This parameter is ignored by JWT
implementations; any processing of this parameter is performed by the
JWT application.

Since we know we are expecting JWTs only, I've followed the guidance and
removed the validation.
2024-10-01 10:36:49 +01:00
John Spray
40b10b878a storage_scrubber: retry on index deletion failures (#9204)
## Problem

In automated tests running on AWS S3, we frequently see scrubber
failures when it can't delete an index.

`location_conf_churn`:

https://neon-github-public-dev.s3.amazonaws.com/reports/main/11076221056/index.html#/testresult/f89b1916b6a693e2

`scrubber_physical_gc`:

https://neon-github-public-dev.s3.amazonaws.com/reports/pr-9178/11074269153/index.html#/testresult/9885ed5aa0fe38b6

## Summary of changes

Wrap index deletion in a backoff::retry

---------

Co-authored-by: Arpad Müller <arpad-m@users.noreply.github.com>
2024-10-01 10:34:39 +01:00
David Gomes
d6c6b0a509 feat(compute): adds pg_session_jwt extension to compute image (#8888)
## Problem

We need the
[pg_session_jwt](https://github.com/neondatabase/pg_session_jwt/)
extension in the compute image. This PR adds it.

## Summary of changes

I added the `pg_session_jwt` extension in a very similar way to how the
pggraphql and pgtiktoken extensions were added (since they're all
written with pgrx). Then I tested this.

```
$ cd docker-compose/
$ PG_VERSION=16 TAG=10667533475 docker-compose up --build -d
$ psql postgresql://cloud_admin:cloud_admin@localhost:55433/postgres

cloud_admin@postgres=# create extension pg_session_jwt;
CREATE EXTENSION
Time: 43.048 ms

cloud_admin@postgres=# \df auth.*;
                              List of functions
┌────────┬──────────────────┬──────────────────┬─────────────────────┬──────┐
│ Schema │       Name       │ Result data type │ Argument data types │ Type │
├────────┼──────────────────┼──────────────────┼─────────────────────┼──────┤
│ auth   │ get              │ jsonb            │ s text              │ func │
│ auth   │ init             │ void             │ kid bigint, s jsonb │ func │
│ auth   │ jwt_session_init │ void             │ s text              │ func │
│ auth   │ user_id          │ text             │                     │ func │
└────────┴──────────────────┴──────────────────┴─────────────────────┴──────┘
(4 rows)

cloud_admin@postgres=# select auth.init(cast('1' as bigint), to_jsonb(TEXT '{ "kty": "EC", "kid": "571683be-33cf-4e67-bccc-8905c0ebb862", "crv": "P-521", "alg": "ES512", "x": "AM_GsnQvKML2yXdn_OsN8PdgO1Sf9XMXih5vQMKLmJkp-Iz_FFWJUt6uyR_qp4brr8Ji2kjGJgN4cQJpg2kskH7V", "y": "AZg-salw24lCmsBP-BCBa5jT6INkTwLtCOC7o0BIxDVvmIEH1-PQAJVYVJPTFvPMi_PLa0QlOm-ufJYkynwa2Mau" }'));
ERROR:  called `Result::unwrap()` on an `Err` value: Error("invalid type: string \"{ \\\"kty\\\": \\\"EC\\\", \\\"kid\\\": \\\"571683be-33cf-4e67-bccc-8905c0ebb862\\\", \\\"crv\\\": \\\"P-521\\\", \\\"alg\\\": \\\"ES512\\\", \\\"x\\\": \\\"AM_GsnQvKML2yXdn_OsN8PdgO1Sf9XMXih5vQMKLmJkp-Iz_FFWJUt6uyR_qp4brr8Ji2kjGJgN4cQJpg2kskH7V\\\", \\\"y\\\": \\\"AZg-salw24lCmsBP-BCBa5jT6INkTwLtCOC7o0BIxDVvmIEH1-PQAJVYVJPTFvPMi_PLa0QlOm-ufJYkynwa2Mau\\\" }\", expected struct JwkEcKey", line: 0, column: 0)
Time: 6.991 ms
```

## Checklist before requesting a review

- [x] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Move the download location to a proper URL
2024-10-01 10:29:56 +01:00
John Spray
d515727e94 tests: make test_multi_attach more stable (#9202)
## Problem

`test_multi_attach` is sometimes failing with `invalid compute status
for configuration request: Configuration`. This is likely a result of
the test attempting to reconfigure the compute at the same time as the
storage controller is doing so.

This test was originally written before the storage controller existed,
and is not expecting anything else to be reconfiguring computes at the
same time.

## Summary of changes

- Configure the tenant into scheduling policy `Stop` in the storage
controller at the start of the test, so that it won't try to do anything
to the tenant while the test is running.
2024-10-01 10:15:18 +01:00
Folke Behrens
2e508b1ff9 Upgrade OpenTelemetry and other tracing crates (#9200)
* tracing-utils now returns a `Layer` impl. Removes the need for crates
to
  import OTel crates.
* Drop the /v1/traces URI check. Verified that the code does the right
thing.
* Leave a TODO to hook in an error handler for OTel to log errors to
when it
  assumes the regular pipeline cannot be used/is broken.
2024-10-01 11:02:54 +02:00
John Spray
651ae44569 storage controller: drop out of blocking compute notification loop if migration origin becomes unavailable (#9147)
## Problem

The live migration code waits forever for the compute notification hook,
on the basis that until it succeeds, the compute is probably using the
old location and we shouldn't detach it.

However, if a pageserver stops or restarts in the background, then this
original location might no longer be available, so there is no point
waiting. Waiting is also actively harmful, because it prevents other
reconciliations happening for the tenant shard, such as during an
upgrade where a stuck "drain" migration might prevent the later "fill"
migration from moving the shard back to its original location.

## Summary of changes

- Refactor the notification wait loop into a function
- Add a checks during the loop, for the origin node's cancellation token
and an explicit HTTP request to the origin node to confirm the shard is
still attached there.

Closes: https://github.com/neondatabase/neon/issues/8901
2024-10-01 07:57:22 +00:00
Heikki Linnakangas
65bda19051 Remove unnecessary dev package from compute image (#9210)
libcurl4-openssl-dev is needed to build pgxn/, but libcurl4 is enough at
runtime.
2024-10-01 01:07:43 +03:00
Conrad Ludgate
94a5ca2817 proxy: auth broker (#8855)
Opens http2 connection to local-proxy and forwards requests over with
all headers and body

closes https://github.com/neondatabase/cloud/issues/16039
2024-09-30 20:43:45 +01:00
Arthur Petukhovsky
c07cea80bd Bump vm-builder v0.29.3 -> v0.35.0 (#9208)
We haven't updated it for a while. Now I need the update to add quotas
support to compute images
(https://github.com/neondatabase/cloud/issues/13127).

Previous update: https://github.com/neondatabase/neon/pull/7849
2024-09-30 19:18:42 +01:00
Conrad Ludgate
a2e2362ee9 add proxy-protocol header disable option (#9203)
resolves https://github.com/neondatabase/cloud/issues/18026
2024-09-30 18:11:50 +00:00
Heikki Linnakangas
0a567acdb9 tests: Move comment to more appropriate place
There is no 'pg_bin' in NeonEnv.
2024-09-30 17:56:43 +03:00
Heikki Linnakangas
69ea2776e9 tests: Remove creation of extra timelines in some tests
neon_cli.create_tenant() creates a new tenant *and* a timeline on the
tenant, with name "main". In most tests, there's no need to create
another timeline on the same tenant.

There are some more tests that do that, but in the remaining cases, I
wasn't be 100% if the presence of extra root timelines affect what the
tests test, so I left them alone.
2024-09-30 17:56:40 +03:00
Heikki Linnakangas
4dc9cb7cf9 tests: Remove some spurious list_timelines calls
These calls seem really out of place. We know what the initial tenant
and branch are in these tests, just like in all other tests.
2024-09-30 17:56:37 +03:00
John Spray
7424e7269c tests: longer timeout in test_delete_timeline_client_hangup (#9161)
## Problem

This test waits for a request to finish, and then expects deletion to
complete almost immediately. The request completes, but it's a 202, the
timeline is still deleting in the background: we need to be more
patient.

## Summary of changes

- Adjust iterations from 2 to 10 when waiting for deletion
2024-09-30 15:46:07 +01:00
a-masterov
5dc68e4e6a test_compatibility: fix the regexes detecting the version (#9205)
## Problem
The Neon components, built locally and by the GitHub workflow have
slightly different version prefixes (git: vs git-env:)
This does not allow running tests against local builds correctly.

## Summary of changes
The regular expressions were changed to work with both
prefixes.
2024-09-30 16:37:14 +02:00
John Spray
7cfd116856 pageserver: refactor immediate_gc into TenantManager (#9183)
## Problem

Legacy functions that were called as `mgr::` and relied on the static
TENANTS, see #5796

## Summary of changes

- Move the last stray function (immediate_gc) into TenantManager

Closes: https://github.com/neondatabase/neon/issues/5796
2024-09-30 09:27:28 +01:00
Heikki Linnakangas
d696c41807 Bump default neon extension version to 1.5 (#9188)
Commit 263dfba6ee introduced neon extension version 1.5, which included
some new functions and views for metrics. It didn't bump the default
neon extension number yet, so that we could still safely roll back to
the old binary if necessary. This bumps the default version.
2024-09-30 09:20:52 +03:00
Alexander Bayandin
3c72192065 CI(benchmarking): fix setting LD_LIBRARY_PATH (#9191)
## Problem

`pgbench-pgvector` job from Nightly Benchmarks fails with the error:

```
/__w/_temp/f45bc2eb-4c4c-4f0a-8030-99079303fa65.sh: line 17: LD_LIBRARY_PATH: unbound variable
```

## Summary of changes
- Fix `LD_LIBRARY_PATH: unbound variable` error in benchmarks
2024-09-29 22:27:53 +00:00
Alexander Bayandin
d2d9921761 CI(benchmarking): fix Nightly Benchmarks (#9178)
## Problem

Nightly Benchmarks have been broken for some time due to various
reasons, this PR fixes it

## Summary of changes
- Pull `build-tools` image from dockerhub for `benchmarking` workflow
- Use `aws-actions/configure-aws-credentials` to upload/download
artifacts from S3
- Fix Postgres 16 installation (for pgbench)
2024-09-28 02:44:22 +01:00
Arthur Petukhovsky
ba498a630a Set disk quotas on bind in compute_ctl (#8936)
Part of https://github.com/neondatabase/cloud/issues/13127. Resolves
#9153

What changed in this PR:
1. Adds `ComputeSpec.disk_quota_bytes: Option<u64>`
2. Adds new arg to compute_ctl: `--set-disk-quota-for-fs <mountpoint>`
3. Implements running `/neonvm/bin/set-disk-quota` with the right value
if both cmdline arg AND field in the spec are specified
4. Patches `/etc/sudoers.d` to allow `compute_ctl` to set quota with
sudo

This PR is very similar to the swap support added earlier, you can take
a look at it as prior art: #7434

In theory, it can be implemented outside of compute_ctl when we will
have a separate neonvm daemon, but we are not there yet. Current
implementation is the simplest possible to unblock computes with larger
disks.

All code related to usage of `/neonvm/bin/set-disk-quota` is located in
`disk_quota.rs`. We need to call this script with the following
arguments: `/neonvm/bin/set-disk-quota {size_kb} {mountpoint}`. Quotas
are set on the filesystem level, so we need to provide path to the
directory that filesystem was mounted to.

I tested this change locally with
https://github.com/neondatabase/cloud/pull/17270. It should be safe to
merge, because this feature is gated by both cmdline arg and field in
the spec. If control-plane doesn't set values in both places,
compute_ctl won't be affected by this change.
2024-09-27 20:52:22 +01:00
Heikki Linnakangas
e989a5e4a2 neon_local: Use clap derive macros to parse the CLI args (#9103)
This is easier to work with.
2024-09-27 22:08:46 +03:00
Alex Chi Z.
cde1654d7b fix(pageserver): abort process if fsync fails (#9108)
close https://github.com/neondatabase/neon/issues/8140

The original issue is rather vague on what we should do. After
discussion w/ @problame we decided to narrow down the problems we want
to solve in that issue.

* read path -- do not panic for now.
* write path -- panic only on write errors (i.e., device error, fsync
error), but not on no-space for now.

The guideline is that if the pageserver behavior could lead to violation
of persistent constraints (i.e., return an operation as successful but
not actually persisting things), we should panic. Fsync is the place
where both of us agree that we should panic, because if fsync fails, the
kernel will mark dirty pages as clean, and the next fsync will not
necessarily return false. This would make the storage client assume the
operation is successful.

## Summary of changes

Make fsync panic on fatal errors.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2024-09-27 19:58:50 +01:00
Heikki Linnakangas
cf6a776fcf tests: Reduce the # of iterations in safekeeper::test_random_schedules (#9182)
To make it faster. On my laptop, it takes about 30 before this commit.
In the arm64 debug variant in CI, it takes about 120 s. Reduce it by
factor of 4.
2024-09-27 16:25:35 +00:00
Matthias van de Meent
5c5871111a WalProposer: Read WAL directly from WAL buffers in PG17 (#9171)
This reduces the overhead of the WalProposer when it is not being
throttled by SK WAL acceptance rate
2024-09-27 17:47:05 +02:00
Yuchen Liang
d56c4e7a38 pageserver: remove AdjacentVectoredReadBuilder and bump minmimum io_buffer_alignment to 512 (#9175)
Part of #8130

## Problem

After deploying https://github.com/neondatabase/infra/pull/1927, we
shipped `io_buffer_alignment=512` to all prod region. The
`AdjacentVectoredReadBuilder` code path is no longer taken and we are
running pageserver unit tests 6 times in the CI. Removing it would
reduce the test duration by 30-60s.

## Summary of changes

- Remove `AdjacentVectoredReadBuilder` code.
- Bump the minimum `io_buffer_alignment` requirement to at least 512
bytes.
- Use default `io_buffer_alignment` for Rust unit tests.

Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-09-27 16:41:42 +01:00
Conrad Ludgate
43b2445d0b proxy: add jwks endpoint to control plane and mock providers (#9165) 2024-09-27 16:08:43 +01:00
Yuchen Liang
42ef08db47 fix(pageserver): LSN lease edge cases around restarts/migrations (#9055)
Part of #7497, closes #8817.

## Problem

See #8817. 

## Summary of changes

**compute_ctl**

- Renew lsn lease as soon as `/configure` updates pageserver_connstr,
use `state_changed` Condvar for synchronization.

**pageserver**

As mentioned in
https://github.com/neondatabase/neon/issues/8817#issuecomment-2315768076,
we still want some permanent error reported if a lease cannot be
granted. By considering attachment mode and the added
`lsn_lease_deadline` when processing lease requests, we can also bound
the case of bad requests to a very short period after migration/restart.

- Refactor https://github.com/neondatabase/neon/pull/9024 and move
`lsn_lease_deadline` to `AttachedTenantConf` so timeline can easily
access it.
- Have separate HTTP `init_lsn_lease` and  libpq `renew_lsn_lease` API.
  - Always do LSN verification for the initial HTTP lease request.
- LSN verification for the renewal is **still done** when tenants are
not in `AttachedSingle` and we have pass the `lsn_lease_deadline`, which
give plenty of time for compute to renew the lease.
 
**neon_local**

- add and call `timeline_init_lsn_lease` mgmt_api at static endpoint
start. The initial lsn lease http request is sent when we run `cargo
neon endpoint start <static endpoint>`.


## Testing

- Extend `test_readonly_node_gc` to do pageserver restarts and
migration.

## Future Work

- The control plane should make the initial lease request through HTTP
when creating a static endpoint. This is currently only done in
`neon_local`.

Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-09-27 09:56:52 -04:00
Tristan Partin
fc962c9605 Use long options when calling initdb
Verbosity in this case is good when reading the code. Short options are
better when operating in an interactive shell.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-09-27 08:22:16 -05:00
Heikki Linnakangas
357fa070a3 Add gdb to build-tools (#9125)
So that compute_ctl can use it to print backtrace on core dumps

See issue #2800.
2024-09-27 15:36:24 +03:00
Heikki Linnakangas
02cdd37b56 Dump backtrace if a core dump is called just "core" (#9125)
I hope this lets us capture backtraces in CI. At least it makes it
work on my laptop, which is valuable even if we need to do more for
CI.

See issue #2800.
2024-09-27 15:36:24 +03:00
Vlad Lazar
fa354a65ab libs: improve logging on PG connection errors (#9130)
## Problem
We get some unexpected errors, but don't know who they're happening for.

## Summary of change
Add tenant id and peer address to PG connection error logs.

Related https://github.com/neondatabase/cloud/issues/17336
2024-09-27 12:36:43 +01:00
Arseny Sher
40f7930a7d safekeeper: skip syncfs on start if --no-sync is specified. (#9166)
https://neondb.slack.com/archives/C059ZC138NR/p1727350911890989?thread_ts=1727350211.370869&cid=C059ZC138NR
2024-09-27 09:59:38 +03:00
Conrad Ludgate
ec07a1ecc9 proxy: make local-proxy config by signal with PID, refine JWKS apis with role caching (#9164) 2024-09-26 19:01:48 +01:00
Arseny Sher
c4cdfe66ac Fix flakiness of test_timeline_copy.
Timeline might be not initialized when timeline_start_lsn is
queried. Spotted by CI.
2024-09-26 19:01:45 +03:00
Alex Chi Z.
42e19e952f fix(pageserver): categorize client error in basebackup metrics (#9110)
We separated client error from basebackup error log lines in
https://github.com/neondatabase/neon/pull/7523, but we didn't do
anything for the metrics. In this patch, we fixed it.

ref https://github.com/neondatabase/neon/issues/8970

## Summary of changes

We use the same criteria as in `log_query_error` producing an info line
(instead of error) for the metrics. We added a `client_error` category
for the basebackup query time metrics.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2024-09-26 11:38:19 -04:00
John Spray
3d255d601b pageserver: rename control plane client & chunk validation requests (#8997)
## Problem

- In https://github.com/neondatabase/neon/pull/8784, the validate
controller API is modified to check generations directly in the
database. It batches tenants into separate queries to avoid generating a
huge statement, but
- While updating this, I realized that "control_plane_client" is a kind
of confusing name for the client code now that it primarily talks to the
storage controller (the case of talking to the control plane will go
away in a few months).

## Summary of changes

- Big rename to "ControllerUpcallClient" -- this reflects the storage
controller's api naming, where the paths used by the pageserver are in
`/upcall/`
- When sending validate requests, break them up into chunks so that we
avoid possible edge cases of generating any HTTP requests that require
database I/O across many thousands of tenants.

This PR mixes a functional change with a refactor, but the commits are
cleanly separated -- only the last commit is a functional change.

---------

Co-authored-by: Christian Schwarz <christian@neon.tech>
2024-09-26 16:06:34 +01:00
Arthur Petukhovsky
80e974d05b fix(compute_ctl): race condition in configurator (#9162)
There was a tricky race condition in compute_ctl, that sometimes makes
configurator skip updates. It makes a deadlock because:
- control-plane cannot configure compute, because it's in
ConfigurationPending state
- compute_ctl doesn't do any reconfiguration because
`configurator_main_loop` missed notification for it

Full sequence that reproduces the issue:
1. `start_compute` finishes works and changes status
`self.set_status(ComputeStatus::Running);`
2. configurator received update about `Running` state and dropped the
mutex lock in the iteration
3. `/configure` request was triggered at the same time as step 1, and
got the mutex lock
4. same `/configure` request set the spec and updated the state to
`ConfigurationPending`, also sent a notification
5. next iteration in configurator got the mutex lock, but missed the
notification

There are more details in this slack thread:
https://neondb.slack.com/archives/C03438W3FLZ/p1727281028478689?thread_ts=1727261220.483799&cid=C03438W3FLZ

---------

Co-authored-by: Alexey Kondratov <kondratov.aleksey@gmail.com>
2024-09-26 15:42:17 +01:00
Alexander Bayandin
7fdf1ab5b6 CI: run compatibility tests on Postgres 17 (#9145)
## Problem

The latest storage release has generated artifacts for Postgres 17,
so we can enable compatibility tests this version

## Summary of changes
- Unskip `test_backward_compatibility` / `test_forward_compatibility` on
Postgres 17
2024-09-26 15:17:01 +01:00
98 changed files with 4405 additions and 2153 deletions

View File

@@ -3,19 +3,23 @@ name: Prepare benchmarking databases by restoring dumps
on:
workflow_call:
# no inputs needed
defaults:
run:
shell: bash -euxo pipefail {0}
jobs:
setup-databases:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
strategy:
fail-fast: false
matrix:
platform: [ aws-rds-postgres, aws-aurora-serverless-v2-postgres, neon ]
platform: [ aws-rds-postgres, aws-aurora-serverless-v2-postgres, neon ]
database: [ clickbench, tpch, userexample ]
env:
LD_LIBRARY_PATH: /tmp/neon/pg_install/v16/lib
PLATFORM: ${{ matrix.platform }}
@@ -23,7 +27,10 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
options: --init
steps:
@@ -32,13 +39,13 @@ jobs:
run: |
case "${PLATFORM}" in
neon)
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_CONNSTR }}
;;
aws-rds-postgres)
CONNSTR=${{ secrets.BENCHMARK_RDS_POSTGRES_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_RDS_POSTGRES_CONNSTR }}
;;
aws-aurora-serverless-v2-postgres)
CONNSTR=${{ secrets.BENCHMARK_RDS_AURORA_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_RDS_AURORA_CONNSTR }}
;;
*)
echo >&2 "Unknown PLATFORM=${PLATFORM}"
@@ -46,10 +53,17 @@ jobs:
;;
esac
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -57,23 +71,23 @@ jobs:
path: /tmp/neon/
prefix: latest
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
- name: Create benchmark_restore_status table if it does not exist
env:
BENCHMARK_CONNSTR: ${{ steps.set-up-prep-connstr.outputs.connstr }}
DATABASE_NAME: ${{ matrix.database }}
# to avoid a race condition of multiple jobs trying to create the table at the same time,
# to avoid a race condition of multiple jobs trying to create the table at the same time,
# we use an advisory lock
run: |
${PG_BINARIES}/psql "${{ env.BENCHMARK_CONNSTR }}" -c "
SELECT pg_advisory_lock(4711);
SELECT pg_advisory_lock(4711);
CREATE TABLE IF NOT EXISTS benchmark_restore_status (
databasename text primary key,
restore_done boolean
);
SELECT pg_advisory_unlock(4711);
"
- name: Check if restore is already done
id: check-restore-done
env:
@@ -107,7 +121,7 @@ jobs:
DATABASE_NAME: ${{ matrix.database }}
run: |
mkdir -p /tmp/dumps
aws s3 cp s3://neon-github-dev/performance/pgdumps/$DATABASE_NAME/$DATABASE_NAME.pg_dump /tmp/dumps/
aws s3 cp s3://neon-github-dev/performance/pgdumps/$DATABASE_NAME/$DATABASE_NAME.pg_dump /tmp/dumps/
- name: Replace database name in connection string
if: steps.check-restore-done.outputs.skip != 'true'
@@ -126,17 +140,17 @@ jobs:
else
new_connstr="${base_connstr}/${DATABASE_NAME}"
fi
echo "database_connstr=${new_connstr}" >> $GITHUB_OUTPUT
echo "database_connstr=${new_connstr}" >> $GITHUB_OUTPUT
- name: Restore dump
if: steps.check-restore-done.outputs.skip != 'true'
env:
DATABASE_NAME: ${{ matrix.database }}
DATABASE_CONNSTR: ${{ steps.replace-dbname.outputs.database_connstr }}
# the following works only with larger computes:
# the following works only with larger computes:
# PGOPTIONS: "-c maintenance_work_mem=8388608 -c max_parallel_maintenance_workers=7"
# we add the || true because:
# the dumps were created with Neon and contain neon extensions that are not
# the dumps were created with Neon and contain neon extensions that are not
# available in RDS, so we will always report an error, but we can ignore it
run: |
${PG_BINARIES}/pg_restore --clean --if-exists --no-owner --jobs=4 \

View File

@@ -236,9 +236,7 @@ jobs:
# run pageserver tests with different settings
for io_engine in std-fs tokio-epoll-uring ; do
for io_buffer_alignment in 0 1 512 ; do
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT=$io_buffer_alignment ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
done
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
done
# Run separate tests for real S3

View File

@@ -12,7 +12,6 @@ on:
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
- cron: '0 3 * * *' # run once a day, timezone is utc
workflow_dispatch: # adds ability to run this manually
inputs:
region_id:
@@ -59,7 +58,7 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # Required for OIDC authentication in azure runners
id-token: write # aws-actions/configure-aws-credentials
strategy:
fail-fast: false
matrix:
@@ -68,12 +67,10 @@ jobs:
PLATFORM: "neon-staging"
region_id: ${{ github.event.inputs.region_id || 'aws-us-east-2' }}
RUNNER: [ self-hosted, us-east-2, x64 ]
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
- DEFAULT_PG_VERSION: 16
PLATFORM: "azure-staging"
region_id: 'azure-eastus2'
RUNNER: [ self-hosted, eastus2, x64 ]
IMAGE: neondatabase/build-tools:pinned
env:
TEST_PG_BENCH_DURATIONS_MATRIX: "300"
TEST_PG_BENCH_SCALES_MATRIX: "10,100"
@@ -86,7 +83,10 @@ jobs:
runs-on: ${{ matrix.RUNNER }}
container:
image: ${{ matrix.IMAGE }}
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
options: --init
steps:
@@ -164,6 +164,10 @@ jobs:
replication-tests:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
env:
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 16
@@ -174,12 +178,21 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
@@ -267,7 +280,7 @@ jobs:
region_id_default=${{ env.DEFAULT_REGION_ID }}
runner_default='["self-hosted", "us-east-2", "x64"]'
runner_azure='["self-hosted", "eastus2", "x64"]'
image_default="369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned"
image_default="neondatabase/build-tools:pinned"
matrix='{
"pg_version" : [
16
@@ -344,7 +357,7 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # Required for OIDC authentication in azure runners
id-token: write # aws-actions/configure-aws-credentials
strategy:
fail-fast: false
@@ -371,7 +384,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials # necessary on Azure runners
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
@@ -492,17 +505,15 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # Required for OIDC authentication in azure runners
id-token: write # aws-actions/configure-aws-credentials
strategy:
fail-fast: false
matrix:
include:
- PLATFORM: "neonvm-captest-pgvector"
RUNNER: [ self-hosted, us-east-2, x64 ]
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
- PLATFORM: "azure-captest-pgvector"
RUNNER: [ self-hosted, eastus2, x64 ]
IMAGE: neondatabase/build-tools:pinned
env:
TEST_PG_BENCH_DURATIONS_MATRIX: "15m"
@@ -511,13 +522,16 @@ jobs:
DEFAULT_PG_VERSION: 16
TEST_OUTPUT: /tmp/test_output
BUILD_TYPE: remote
LD_LIBRARY_PATH: /home/nonroot/pg/usr/lib/x86_64-linux-gnu
SAVE_PERF_REPORT: ${{ github.event.inputs.save_perf_report || ( github.ref_name == 'main' ) }}
PLATFORM: ${{ matrix.PLATFORM }}
runs-on: ${{ matrix.RUNNER }}
container:
image: ${{ matrix.IMAGE }}
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
options: --init
steps:
@@ -527,17 +541,26 @@ jobs:
# instead of using Neon artifacts containing pgbench
- name: Install postgresql-16 where pytest expects it
run: |
# Just to make it easier to test things locally on macOS (with arm64)
arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g')
cd /home/nonroot
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/libpq5_16.4-1.pgdg110%2B1_amd64.deb
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.4-1.pgdg110%2B1_amd64.deb
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.4-1.pgdg110%2B1_amd64.deb
dpkg -x libpq5_16.4-1.pgdg110+1_amd64.deb pg
dpkg -x postgresql-client-16_16.4-1.pgdg110+1_amd64.deb pg
dpkg -x postgresql-16_16.4-1.pgdg110+1_amd64.deb pg
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.0-1.pgdg110+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.4-1.pgdg110+2_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.4-1.pgdg110+2_${arch}.deb"
dpkg -x libpq5_17.0-1.pgdg110+1_${arch}.deb pg
dpkg -x postgresql-16_16.4-1.pgdg110+2_${arch}.deb pg
dpkg -x postgresql-client-16_16.4-1.pgdg110+2_${arch}.deb pg
mkdir -p /tmp/neon/pg_install/v16/bin
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/psql /tmp/neon/pg_install/v16/bin/psql
ln -s /home/nonroot/pg/usr/lib/x86_64-linux-gnu /tmp/neon/pg_install/v16/lib
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/psql /tmp/neon/pg_install/v16/bin/psql
ln -s /home/nonroot/pg/usr/lib/$(uname -m)-linux-gnu /tmp/neon/pg_install/v16/lib
LD_LIBRARY_PATH="/home/nonroot/pg/usr/lib/$(uname -m)-linux-gnu:${LD_LIBRARY_PATH:-}"
export LD_LIBRARY_PATH
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" >> ${GITHUB_ENV}
/tmp/neon/pg_install/v16/bin/pgbench --version
/tmp/neon/pg_install/v16/bin/psql --version
@@ -559,7 +582,7 @@ jobs:
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
- name: Configure AWS credentials # necessary on Azure runners to read/write from/to S3
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
@@ -620,6 +643,10 @@ jobs:
# *_CLICKBENCH_CONNSTR: Genuine ClickBench DB with ~100M rows
# *_CLICKBENCH_10M_CONNSTR: DB with the first 10M rows of ClickBench DB
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, pgbench-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -638,12 +665,22 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -714,6 +751,10 @@ jobs:
#
# *_TPCH_S10_CONNSTR: DB generated with scale factor 10 (~10 GB)
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, clickbench-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -731,12 +772,22 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -806,6 +857,10 @@ jobs:
user-examples-compare:
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, tpch-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -822,12 +877,22 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:

View File

@@ -773,7 +773,7 @@ jobs:
matrix:
version: [ v14, v15, v16, v17 ]
env:
VM_BUILDER_VERSION: v0.29.3
VM_BUILDER_VERSION: v0.35.0
steps:
- uses: actions/checkout@v4
@@ -1190,10 +1190,9 @@ jobs:
files_to_promote+=("s3://${BUCKET}/${s3_key}")
# TODO Add v17
for pg_version in v14 v15 v16; do
for pg_version in v14 v15 v16 v17; do
# We run less tests for debug builds, so we don't need to promote them
if [ "${build_type}" == "debug" ] && { [ "${arch}" == "ARM64" ] || [ "${pg_version}" != "v16" ] ; }; then
if [ "${build_type}" == "debug" ] && { [ "${arch}" == "ARM64" ] || [ "${pg_version}" != "v17" ] ; }; then
continue
fi

272
Cargo.lock generated
View File

@@ -90,9 +90,9 @@ dependencies = [
[[package]]
name = "anstyle"
version = "1.0.0"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d"
checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
[[package]]
name = "anstyle-parse"
@@ -1223,6 +1223,7 @@ dependencies = [
"notify",
"num_cpus",
"opentelemetry",
"opentelemetry_sdk",
"postgres",
"regex",
"remote_storage",
@@ -1321,6 +1322,7 @@ dependencies = [
"clap",
"comfy-table",
"compute_api",
"futures",
"humantime",
"humantime-serde",
"hyper 0.14.30",
@@ -1875,9 +1877,9 @@ dependencies = [
[[package]]
name = "env_logger"
version = "0.10.0"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0"
checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580"
dependencies = [
"humantime",
"is-terminal",
@@ -3367,102 +3369,82 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "opentelemetry"
version = "0.20.0"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9591d937bc0e6d2feb6f71a559540ab300ea49955229c347a517a28d27784c54"
checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
dependencies = [
"opentelemetry_api",
"opentelemetry_sdk",
]
[[package]]
name = "opentelemetry-http"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7594ec0e11d8e33faf03530a4c49af7064ebba81c1480e01be67d90b356508b"
dependencies = [
"async-trait",
"bytes",
"http 0.2.9",
"opentelemetry_api",
"reqwest 0.11.19",
]
[[package]]
name = "opentelemetry-otlp"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275"
dependencies = [
"async-trait",
"futures-core",
"http 0.2.9",
"opentelemetry-http",
"opentelemetry-proto",
"opentelemetry-semantic-conventions",
"opentelemetry_api",
"opentelemetry_sdk",
"prost",
"reqwest 0.11.19",
"thiserror",
"tokio",
"tonic",
]
[[package]]
name = "opentelemetry-proto"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb"
dependencies = [
"opentelemetry_api",
"opentelemetry_sdk",
"prost",
"tonic",
]
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73c9f9340ad135068800e7f1b24e9e09ed9e7143f5bf8518ded3d3ec69789269"
dependencies = [
"opentelemetry",
]
[[package]]
name = "opentelemetry_api"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a81f725323db1b1206ca3da8bb19874bbd3f57c3bcd59471bfb04525b265b9b"
dependencies = [
"futures-channel",
"futures-util",
"indexmap 1.9.3",
"futures-sink",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
"urlencoding",
]
[[package]]
name = "opentelemetry_sdk"
version = "0.20.0"
name = "opentelemetry-http"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa8e705a0612d48139799fcbaba0d4a90f06277153e43dd2bdc16c6f0edd8026"
checksum = "ad31e9de44ee3538fb9d64fe3376c1362f406162434609e79aea2a41a0af78ab"
dependencies = [
"async-trait",
"bytes",
"http 1.1.0",
"opentelemetry",
"reqwest 0.12.4",
]
[[package]]
name = "opentelemetry-otlp"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b925a602ffb916fb7421276b86756027b37ee708f9dce2dbdcc51739f07e727"
dependencies = [
"async-trait",
"futures-core",
"http 1.1.0",
"opentelemetry",
"opentelemetry-http",
"opentelemetry-proto",
"opentelemetry_sdk",
"prost 0.13.3",
"reqwest 0.12.4",
"thiserror",
]
[[package]]
name = "opentelemetry-proto"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30ee9f20bff9c984511a02f082dc8ede839e4a9bf15cc2487c8d6fea5ad850d9"
dependencies = [
"opentelemetry",
"opentelemetry_sdk",
"prost 0.13.3",
"tonic 0.12.2",
]
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cefe0543875379e47eb5f1e68ff83f45cc41366a92dfd0d073d513bf68e9a05"
[[package]]
name = "opentelemetry_sdk"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
dependencies = [
"async-trait",
"crossbeam-channel",
"futures-channel",
"futures-executor",
"futures-util",
"glob",
"once_cell",
"opentelemetry_api",
"ordered-float 3.9.2",
"opentelemetry",
"percent-encoding",
"rand 0.8.5",
"regex",
"serde_json",
"thiserror",
"tokio",
@@ -3478,15 +3460,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "ordered-float"
version = "3.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1e1c390732d15f1d48471625cd92d154e66db2c56645e29a9cd26f4699f72dc"
dependencies = [
"num-traits",
]
[[package]]
name = "ordered-multimap"
version = "0.7.3"
@@ -4229,7 +4202,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
dependencies = [
"bytes",
"prost-derive",
"prost-derive 0.11.9",
]
[[package]]
name = "prost"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f"
dependencies = [
"bytes",
"prost-derive 0.13.3",
]
[[package]]
@@ -4246,7 +4229,7 @@ dependencies = [
"multimap",
"petgraph",
"prettyplease 0.1.25",
"prost",
"prost 0.11.9",
"prost-types",
"regex",
"syn 1.0.109",
@@ -4267,13 +4250,26 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "prost-derive"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5"
dependencies = [
"anyhow",
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "prost-types"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
dependencies = [
"prost",
"prost 0.11.9",
]
[[package]]
@@ -4296,6 +4292,7 @@ dependencies = [
"camino-tempfile",
"chrono",
"clap",
"compute_api",
"consumption_metrics",
"dashmap",
"ecdsa 0.16.9",
@@ -4369,7 +4366,6 @@ dependencies = [
"tokio-tungstenite",
"tokio-util",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"tracing-utils",
"try-lock",
@@ -4814,9 +4810,9 @@ dependencies = [
[[package]]
name = "reqwest-tracing"
version = "0.5.0"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b253954a1979e02eabccd7e9c3d61d8f86576108baa160775e7f160bb4e800a3"
checksum = "bfdd9bfa64c72233d8dd99ab7883efcdefe9e16d46488ecb9228b71a2e2ceb45"
dependencies = [
"anyhow",
"async-trait",
@@ -5701,9 +5697,9 @@ dependencies = [
"metrics",
"once_cell",
"parking_lot 0.12.1",
"prost",
"prost 0.11.9",
"tokio",
"tonic",
"tonic 0.9.2",
"tonic-build",
"tracing",
"utils",
@@ -6027,7 +6023,7 @@ checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09"
dependencies = [
"byteorder",
"integer-encoding",
"ordered-float 2.10.1",
"ordered-float",
]
[[package]]
@@ -6129,9 +6125,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.37.0"
version = "1.38.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df"
dependencies = [
"backtrace",
"bytes",
@@ -6173,9 +6169,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.2.0"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a"
dependencies = [
"proc-macro2",
"quote",
@@ -6350,7 +6346,7 @@ dependencies = [
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost",
"prost 0.11.9",
"rustls-native-certs 0.6.2",
"rustls-pemfile 1.0.2",
"tokio",
@@ -6362,6 +6358,27 @@ dependencies = [
"tracing",
]
[[package]]
name = "tonic"
version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6f6ba989e4b2c58ae83d862d3a3e27690b6e3ae630d0deb59f3697f32aa88ad"
dependencies = [
"async-trait",
"base64 0.22.1",
"bytes",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"percent-encoding",
"pin-project",
"prost 0.13.3",
"tokio-stream",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tonic-build"
version = "0.9.2"
@@ -6409,11 +6426,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tracing"
version = "0.1.37"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"cfg-if",
"log",
"pin-project-lite",
"tracing-attributes",
@@ -6433,9 +6449,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
version = "0.1.24"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
@@ -6444,9 +6460,9 @@ dependencies = [
[[package]]
name = "tracing-core"
version = "0.1.31"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
@@ -6464,21 +6480,22 @@ dependencies = [
[[package]]
name = "tracing-log"
version = "0.1.3"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"lazy_static",
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-opentelemetry"
version = "0.21.0"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75327c6b667828ddc28f5e3f169036cb793c3f588d83bf0f262a7f062ffed3c8"
checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
dependencies = [
"js-sys",
"once_cell",
"opentelemetry",
"opentelemetry_sdk",
@@ -6487,6 +6504,7 @@ dependencies = [
"tracing-core",
"tracing-log",
"tracing-subscriber",
"web-time",
]
[[package]]
@@ -6501,9 +6519,9 @@ dependencies = [
[[package]]
name = "tracing-subscriber"
version = "0.3.17"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"matchers",
"once_cell",
@@ -6527,6 +6545,7 @@ dependencies = [
"opentelemetry",
"opentelemetry-otlp",
"opentelemetry-semantic-conventions",
"opentelemetry_sdk",
"tokio",
"tracing",
"tracing-opentelemetry",
@@ -6982,6 +7001,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.25.2"
@@ -7251,7 +7280,6 @@ dependencies = [
"chrono",
"clap",
"clap_builder",
"crossbeam-utils",
"crypto-bigint 0.5.5",
"der 0.7.8",
"deranged",
@@ -7284,13 +7312,12 @@ dependencies = [
"once_cell",
"parquet",
"proc-macro2",
"prost",
"prost 0.11.9",
"quote",
"rand 0.8.5",
"regex",
"regex-automata 0.4.3",
"regex-syntax 0.8.2",
"reqwest 0.11.19",
"reqwest 0.12.4",
"rustls 0.21.11",
"scopeguard",
@@ -7311,12 +7338,9 @@ dependencies = [
"tokio-rustls 0.24.0",
"tokio-util",
"toml_edit",
"tonic",
"tower",
"tracing",
"tracing-core",
"tracing-log",
"tracing-subscriber",
"url",
"uuid",
"zeroize",

View File

@@ -116,9 +116,10 @@ notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.20.0"
opentelemetry-otlp = { version = "0.13.0", default-features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.12.0"
opentelemetry = "0.24"
opentelemetry_sdk = "0.24"
opentelemetry-otlp = { version = "0.17", default-features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.16"
parking_lot = "0.12"
parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
@@ -131,7 +132,7 @@ rand = "0.8"
redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_20"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_24"] }
reqwest-middleware = "0.3.0"
reqwest-retry = "0.5"
routerify = "3"
@@ -177,8 +178,8 @@ toml_edit = "0.22"
tonic = {version = "0.9", features = ["tls", "tls-roots"]}
tower-service = "0.3.2"
tracing = "0.1"
tracing-error = "0.2.0"
tracing-opentelemetry = "0.21.0"
tracing-error = "0.2"
tracing-opentelemetry = "0.25"
tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
try-lock = "0.2.5"
twox-hash = { version = "1.6.3", default-features = false }

View File

@@ -42,6 +42,7 @@ COPY --from=pg-build /home/nonroot/pg_install/v17/lib pg_i
COPY --chown=nonroot . .
ARG ADDITIONAL_RUSTFLAGS
ENV _RJEM_MALLOC_CONF="thp:never"
RUN set -e \
&& PQ_LIB_DIR=$(pwd)/pg_install/v${STABLE_PG_VERSION}/lib RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment ${ADDITIONAL_RUSTFLAGS}" cargo build \
--bin pg_sni_router \

View File

@@ -13,6 +13,9 @@ RUN useradd -ms /bin/bash nonroot -b /home
SHELL ["/bin/bash", "-c"]
# System deps
#
# 'gdb' is included so that we get backtraces of core dumps produced in
# regression tests
RUN set -e \
&& apt update \
&& apt install -y \
@@ -24,6 +27,7 @@ RUN set -e \
cmake \
curl \
flex \
gdb \
git \
gnupg \
gzip \

View File

@@ -871,6 +871,28 @@ RUN case "${PG_VERSION}" in "v17") \
cargo pgrx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/ulid.control
#########################################################################################
#
# Layer "pg-session-jwt-build"
# Compile "pg_session_jwt" extension
#
#########################################################################################
FROM rust-extensions-build AS pg-session-jwt-build
ARG PG_VERSION
RUN case "${PG_VERSION}" in "v17") \
echo "pg_session_jwt does not yet have a release that supports pg17" && exit 0;; \
esac && \
wget https://github.com/neondatabase/pg_session_jwt/archive/ff0a72440e8ff584dab24b3f9b7c00c56c660b8e.tar.gz -O pg_session_jwt.tar.gz && \
echo "1fbb2b5a339263bcf6daa847fad8bccbc0b451cea6a62e6d3bf232b0087f05cb pg_session_jwt.tar.gz" | sha256sum --check && \
mkdir pg_session_jwt-src && cd pg_session_jwt-src && tar xzf ../pg_session_jwt.tar.gz --strip-components=1 -C . && \
sed -i 's/pgrx = "=0.11.3"/pgrx = { version = "=0.11.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
cargo pgrx install --release
# it's needed to enable extension because it uses untrusted C language
# sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_session_jwt.control && \
# echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_session_jwt.control
#########################################################################################
#
# Layer "wal2json-build"
@@ -967,6 +989,7 @@ COPY --from=timescaledb-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-hint-plan-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-cron-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-pgx-ulid-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-session-jwt-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
@@ -1258,7 +1281,7 @@ RUN apt update && \
libxml2 \
libxslt1.1 \
libzstd1 \
libcurl4-openssl-dev \
libcurl4 \
locales \
procps \
ca-certificates \

View File

@@ -11,6 +11,10 @@ commands:
user: root
sysvInitAction: sysinit
shell: 'chmod 711 /neonvm/bin/resize-swap'
- name: chmod-set-disk-quota
user: root
sysvInitAction: sysinit
shell: 'chmod 711 /neonvm/bin/set-disk-quota'
- name: pgbouncer
user: postgres
sysvInitAction: respawn
@@ -30,11 +34,12 @@ commands:
shutdownHook: |
su -p postgres --session-command '/usr/local/bin/pg_ctl stop -D /var/db/postgres/compute/pgdata -m fast --wait -t 10'
files:
- filename: compute_ctl-resize-swap
- filename: compute_ctl-sudoers
content: |
# Allow postgres user (which is what compute_ctl runs as) to run /neonvm/bin/resize-swap
# as root without requiring entering a password (NOPASSWD), regardless of hostname (ALL)
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap
# and /neonvm/bin/set-disk-quota as root without requiring entering a password (NOPASSWD),
# regardless of hostname (ALL)
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap, /neonvm/bin/set-disk-quota
- filename: cgconfig.conf
content: |
# Configuration for cgroups in VM compute nodes
@@ -100,7 +105,7 @@ merge: |
&& apt install --no-install-recommends -y \
sudo \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
COPY compute_ctl-resize-swap /etc/sudoers.d/compute_ctl-resize-swap
COPY compute_ctl-sudoers /etc/sudoers.d/compute_ctl-sudoers
COPY cgconfig.conf /etc/cgconfig.conf

View File

@@ -21,6 +21,7 @@ nix.workspace = true
notify.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
opentelemetry_sdk.workspace = true
postgres.workspace = true
regex.workspace = true
serde_json.workspace = true

View File

@@ -44,6 +44,7 @@ use std::{thread, time::Duration};
use anyhow::{Context, Result};
use chrono::Utc;
use clap::Arg;
use compute_tools::disk_quota::set_disk_quota;
use compute_tools::lsn_lease::launch_lsn_lease_bg_task_for_static;
use signal_hook::consts::{SIGQUIT, SIGTERM};
use signal_hook::{consts::SIGINT, iterator::Signals};
@@ -151,6 +152,7 @@ fn process_cli(matches: &clap::ArgMatches) -> Result<ProcessCliResult> {
let spec_json = matches.get_one::<String>("spec");
let spec_path = matches.get_one::<String>("spec-path");
let resize_swap_on_bind = matches.get_flag("resize-swap-on-bind");
let set_disk_quota_for_fs = matches.get_one::<String>("set-disk-quota-for-fs");
Ok(ProcessCliResult {
connstr,
@@ -161,6 +163,7 @@ fn process_cli(matches: &clap::ArgMatches) -> Result<ProcessCliResult> {
spec_json,
spec_path,
resize_swap_on_bind,
set_disk_quota_for_fs,
})
}
@@ -173,6 +176,7 @@ struct ProcessCliResult<'clap> {
spec_json: Option<&'clap String>,
spec_path: Option<&'clap String>,
resize_swap_on_bind: bool,
set_disk_quota_for_fs: Option<&'clap String>,
}
fn startup_context_from_env() -> Option<opentelemetry::ContextGuard> {
@@ -214,7 +218,7 @@ fn startup_context_from_env() -> Option<opentelemetry::ContextGuard> {
}
if !startup_tracing_carrier.is_empty() {
use opentelemetry::propagation::TextMapPropagator;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry_sdk::propagation::TraceContextPropagator;
let guard = TraceContextPropagator::new()
.extract(&startup_tracing_carrier)
.attach();
@@ -293,6 +297,7 @@ fn wait_spec(
pgbin,
ext_remote_storage,
resize_swap_on_bind,
set_disk_quota_for_fs,
http_port,
..
}: ProcessCliResult,
@@ -373,6 +378,7 @@ fn wait_spec(
compute,
http_port,
resize_swap_on_bind,
set_disk_quota_for_fs: set_disk_quota_for_fs.cloned(),
})
}
@@ -381,6 +387,7 @@ struct WaitSpecResult {
// passed through from ProcessCliResult
http_port: u16,
resize_swap_on_bind: bool,
set_disk_quota_for_fs: Option<String>,
}
fn start_postgres(
@@ -390,6 +397,7 @@ fn start_postgres(
compute,
http_port,
resize_swap_on_bind,
set_disk_quota_for_fs,
}: WaitSpecResult,
) -> Result<(Option<PostgresHandle>, StartPostgresResult)> {
// We got all we need, update the state.
@@ -403,6 +411,7 @@ fn start_postgres(
);
// before we release the mutex, fetch the swap size (if any) for later.
let swap_size_bytes = state.pspec.as_ref().unwrap().spec.swap_size_bytes;
let disk_quota_bytes = state.pspec.as_ref().unwrap().spec.disk_quota_bytes;
drop(state);
// Launch remaining service threads
@@ -422,8 +431,8 @@ fn start_postgres(
// OOM-killed during startup because swap wasn't available yet.
match resize_swap(size_bytes) {
Ok(()) => {
let size_gib = size_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%size_bytes, %size_gib, "resized swap");
let size_mib = size_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%size_bytes, %size_mib, "resized swap");
}
Err(err) => {
let err = err.context("failed to resize swap");
@@ -432,10 +441,29 @@ fn start_postgres(
// Mark compute startup as failed; don't try to start postgres, and report this
// error to the control plane when it next asks.
prestartup_failed = true;
let mut state = compute.state.lock().unwrap();
state.error = Some(format!("{err:?}"));
state.status = ComputeStatus::Failed;
compute.state_changed.notify_all();
compute.set_failed_status(err);
delay_exit = true;
}
}
}
// Set disk quota if the compute spec says so
if let (Some(disk_quota_bytes), Some(disk_quota_fs_mountpoint)) =
(disk_quota_bytes, set_disk_quota_for_fs)
{
match set_disk_quota(disk_quota_bytes, &disk_quota_fs_mountpoint) {
Ok(()) => {
let size_mib = disk_quota_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%disk_quota_bytes, %size_mib, "set disk quota");
}
Err(err) => {
let err = err.context("failed to set disk quota");
error!("{err:#}");
// Mark compute startup as failed; don't try to start postgres, and report this
// error to the control plane when it next asks.
prestartup_failed = true;
compute.set_failed_status(err);
delay_exit = true;
}
}
@@ -450,16 +478,7 @@ fn start_postgres(
Ok(pg) => Some(pg),
Err(err) => {
error!("could not start the compute node: {:#}", err);
let mut state = compute.state.lock().unwrap();
state.error = Some(format!("{:?}", err));
state.status = ComputeStatus::Failed;
// Notify others that Postgres failed to start. In case of configuring the
// empty compute, it's likely that API handler is still waiting for compute
// state change. With this we will notify it that compute is in Failed state,
// so control plane will know about it earlier and record proper error instead
// of timeout.
compute.state_changed.notify_all();
drop(state); // unlock
compute.set_failed_status(err);
delay_exit = true;
None
}
@@ -750,6 +769,11 @@ fn cli() -> clap::Command {
.long("resize-swap-on-bind")
.action(clap::ArgAction::SetTrue),
)
.arg(
Arg::new("set-disk-quota-for-fs")
.long("set-disk-quota-for-fs")
.value_name("SET_DISK_QUOTA_FOR_FS")
)
}
/// When compute_ctl is killed, send also termination signal to sync-safekeepers

View File

@@ -10,6 +10,7 @@ use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::{Condvar, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use std::time::Instant;
use anyhow::{Context, Result};
@@ -305,6 +306,13 @@ impl ComputeNode {
self.state_changed.notify_all();
}
pub fn set_failed_status(&self, err: anyhow::Error) {
let mut state = self.state.lock().unwrap();
state.error = Some(format!("{err:?}"));
state.status = ComputeStatus::Failed;
self.state_changed.notify_all();
}
pub fn get_status(&self) -> ComputeStatus {
self.state.lock().unwrap().status
}
@@ -710,7 +718,7 @@ impl ComputeNode {
info!("running initdb");
let initdb_bin = Path::new(&self.pgbin).parent().unwrap().join("initdb");
Command::new(initdb_bin)
.args(["-D", pgdata])
.args(["--pgdata", pgdata])
.output()
.expect("cannot start initdb process");
@@ -1123,6 +1131,9 @@ impl ComputeNode {
//
// Use that as a default location and pattern, except macos where core dumps are written
// to /cores/ directory by default.
//
// With default Linux settings, the core dump file is called just "core", so check for
// that too.
pub fn check_for_core_dumps(&self) -> Result<()> {
let core_dump_dir = match std::env::consts::OS {
"macos" => Path::new("/cores/"),
@@ -1134,8 +1145,17 @@ impl ComputeNode {
let files = fs::read_dir(core_dump_dir)?;
let cores = files.filter_map(|entry| {
let entry = entry.ok()?;
let _ = entry.file_name().to_str()?.strip_prefix("core.")?;
Some(entry.path())
let is_core_dump = match entry.file_name().to_str()? {
n if n.starts_with("core.") => true,
"core" => true,
_ => false,
};
if is_core_dump {
Some(entry.path())
} else {
None
}
});
// Print backtrace for each core dump
@@ -1386,6 +1406,36 @@ LIMIT 100",
}
Ok(remote_ext_metrics)
}
/// Waits until current thread receives a state changed notification and
/// the pageserver connection strings has changed.
///
/// The operation will time out after a specified duration.
pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
let state = self.state.lock().unwrap();
let old_pageserver_connstr = state
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_connstr
.clone();
let mut unchanged = true;
let _ = self
.state_changed
.wait_timeout_while(state, duration, |s| {
let pageserver_connstr = &s
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_connstr;
unchanged = pageserver_connstr == &old_pageserver_connstr;
unchanged
})
.unwrap();
if !unchanged {
info!("Pageserver config changed");
}
}
}
pub fn forward_termination_signal() {

View File

@@ -11,9 +11,17 @@ use crate::compute::ComputeNode;
fn configurator_main_loop(compute: &Arc<ComputeNode>) {
info!("waiting for reconfiguration requests");
loop {
let state = compute.state.lock().unwrap();
let mut state = compute.state_changed.wait(state).unwrap();
let mut state = compute.state.lock().unwrap();
// We have to re-check the status after re-acquiring the lock because it could be that
// the status has changed while we were waiting for the lock, and we might not need to
// wait on the condition variable. Otherwise, we might end up in some soft-/deadlock, i.e.
// we are waiting for a condition variable that will never be signaled.
if state.status != ComputeStatus::ConfigurationPending {
state = compute.state_changed.wait(state).unwrap();
}
// Re-check the status after waking up
if state.status == ComputeStatus::ConfigurationPending {
info!("got configuration request");
state.status = ComputeStatus::Configuration;

View File

@@ -0,0 +1,25 @@
use anyhow::Context;
pub const DISK_QUOTA_BIN: &str = "/neonvm/bin/set-disk-quota";
/// If size_bytes is 0, it disables the quota. Otherwise, it sets filesystem quota to size_bytes.
/// `fs_mountpoint` should point to the mountpoint of the filesystem where the quota should be set.
pub fn set_disk_quota(size_bytes: u64, fs_mountpoint: &str) -> anyhow::Result<()> {
let size_kb = size_bytes / 1024;
// run `/neonvm/bin/set-disk-quota {size_kb} {mountpoint}`
let child_result = std::process::Command::new("/usr/bin/sudo")
.arg(DISK_QUOTA_BIN)
.arg(size_kb.to_string())
.arg(fs_mountpoint)
.spawn();
child_result
.context("spawn() failed")
.and_then(|mut child| child.wait().context("wait() failed"))
.and_then(|status| match status.success() {
true => Ok(()),
false => Err(anyhow::anyhow!("process exited with {status}")),
})
// wrap any prior error with the overall context that we couldn't run the command
.with_context(|| format!("could not run `/usr/bin/sudo {DISK_QUOTA_BIN}`"))
}

View File

@@ -10,6 +10,7 @@ pub mod http;
pub mod logger;
pub mod catalog;
pub mod compute;
pub mod disk_quota;
pub mod extension_server;
pub mod lsn_lease;
mod migration;

View File

@@ -1,4 +1,3 @@
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::prelude::*;
@@ -23,8 +22,7 @@ pub fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {
.with_writer(std::io::stderr);
// Initialize OpenTelemetry
let otlp_layer =
tracing_utils::init_tracing_without_runtime("compute_ctl").map(OpenTelemetryLayer::new);
let otlp_layer = tracing_utils::init_tracing_without_runtime("compute_ctl");
// Put it all together
tracing_subscriber::registry()

View File

@@ -57,10 +57,10 @@ fn lsn_lease_bg_task(
.max(valid_duration / 2);
info!(
"Succeeded, sleeping for {} seconds",
"Request succeeded, sleeping for {} seconds",
sleep_duration.as_secs()
);
thread::sleep(sleep_duration);
compute.wait_timeout_while_pageserver_connstr_unchanged(sleep_duration);
}
}
@@ -89,10 +89,7 @@ fn acquire_lsn_lease_with_retry(
.map(|connstr| {
let mut config = postgres::Config::from_str(connstr).expect("Invalid connstr");
if let Some(storage_auth_token) = &spec.storage_auth_token {
info!("Got storage auth token from spec file");
config.password(storage_auth_token.clone());
} else {
info!("Storage auth token not set");
}
config
})
@@ -108,9 +105,11 @@ fn acquire_lsn_lease_with_retry(
bail!("Permanent error: lease could not be obtained, LSN is behind the GC cutoff");
}
Err(e) => {
warn!("Failed to acquire lsn lease: {e} (attempt {attempts}");
warn!("Failed to acquire lsn lease: {e} (attempt {attempts})");
thread::sleep(Duration::from_millis(retry_period_ms as u64));
compute.wait_timeout_while_pageserver_connstr_unchanged(Duration::from_millis(
retry_period_ms as u64,
));
retry_period_ms *= 1.5;
retry_period_ms = retry_period_ms.min(MAX_RETRY_PERIOD_MS);
}

View File

@@ -9,6 +9,7 @@ anyhow.workspace = true
camino.workspace = true
clap.workspace = true
comfy-table.workspace = true
futures.workspace = true
humantime.workspace = true
nix.workspace = true
once_cell.workspace = true

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
//! Branch mappings for convenience
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use anyhow::{bail, Context};
use serde::{Deserialize, Serialize};
use utils::id::{TenantId, TenantTimelineId, TimelineId};
/// Keep human-readable aliases in memory (and persist them to config XXX), to hide tenant/timeline hex strings from the user.
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct BranchMappings {
/// Default tenant ID to use with the 'neon_local' command line utility, when
/// --tenant_id is not explicitly specified. This comes from the branches.
pub default_tenant_id: Option<TenantId>,
// A `HashMap<String, HashMap<TenantId, TimelineId>>` would be more appropriate here,
// but deserialization into a generic toml object as `toml::Value::try_from` fails with an error.
// https://toml.io/en/v1.0.0 does not contain a concept of "a table inside another table".
pub mappings: HashMap<String, Vec<(TenantId, TimelineId)>>,
}
impl BranchMappings {
pub fn register_branch_mapping(
&mut self,
branch_name: String,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> anyhow::Result<()> {
let existing_values = self.mappings.entry(branch_name.clone()).or_default();
let existing_ids = existing_values
.iter()
.find(|(existing_tenant_id, _)| existing_tenant_id == &tenant_id);
if let Some((_, old_timeline_id)) = existing_ids {
if old_timeline_id == &timeline_id {
Ok(())
} else {
bail!("branch '{branch_name}' is already mapped to timeline {old_timeline_id}, cannot map to another timeline {timeline_id}");
}
} else {
existing_values.push((tenant_id, timeline_id));
Ok(())
}
}
pub fn get_branch_timeline_id(
&self,
branch_name: &str,
tenant_id: TenantId,
) -> Option<TimelineId> {
// If it looks like a timeline ID, return it as it is
if let Ok(timeline_id) = branch_name.parse::<TimelineId>() {
return Some(timeline_id);
}
self.mappings
.get(branch_name)?
.iter()
.find(|(mapped_tenant_id, _)| mapped_tenant_id == &tenant_id)
.map(|&(_, timeline_id)| timeline_id)
.map(TimelineId::from)
}
pub fn timeline_name_mappings(&self) -> HashMap<TenantTimelineId, String> {
self.mappings
.iter()
.flat_map(|(name, tenant_timelines)| {
tenant_timelines.iter().map(|&(tenant_id, timeline_id)| {
(TenantTimelineId::new(tenant_id, timeline_id), name.clone())
})
})
.collect()
}
pub fn persist(&self, path: &Path) -> anyhow::Result<()> {
let content = &toml::to_string_pretty(self)?;
fs::write(path, content).with_context(|| {
format!(
"Failed to write branch information into path '{}'",
path.display()
)
})
}
pub fn load(path: &Path) -> anyhow::Result<BranchMappings> {
let branches_file_contents = fs::read_to_string(path)?;
Ok(toml::from_str(branches_file_contents.as_str())?)
}
}

View File

@@ -561,6 +561,7 @@ impl Endpoint {
operation_uuid: None,
features: self.features.clone(),
swap_size_bytes: None,
disk_quota_bytes: None,
cluster: Cluster {
cluster_id: None, // project ID: not used
name: None, // project name: not used

View File

@@ -113,7 +113,7 @@ impl SafekeeperNode {
pub async fn start(
&self,
extra_opts: Vec<String>,
extra_opts: &[String],
retry_timeout: &Duration,
) -> anyhow::Result<()> {
print!(
@@ -196,7 +196,7 @@ impl SafekeeperNode {
]);
}
args.extend(extra_opts);
args.extend_from_slice(extra_opts);
background_process::start_process(
&format!("safekeeper-{id}"),

View File

@@ -347,7 +347,7 @@ impl StorageController {
if !tokio::fs::try_exists(&pg_data_path).await? {
let initdb_args = [
"-D",
"--pgdata",
pg_data_path.as_ref(),
"--username",
&username(),

View File

@@ -50,6 +50,16 @@ pub struct ComputeSpec {
#[serde(default)]
pub swap_size_bytes: Option<u64>,
/// If compute_ctl was passed `--set-disk-quota-for-fs`, a value of `Some(_)` instructs
/// compute_ctl to run `/neonvm/bin/set-disk-quota` with the given size and fs, when the
/// spec is first received.
///
/// Both this field and `--set-disk-quota-for-fs` are required, so that the control plane's
/// spec generation doesn't need to be aware of the actual compute it's running on, while
/// guaranteeing gradual rollout of disk quota.
#[serde(default)]
pub disk_quota_bytes: Option<u64>,
/// Expected cluster state at the end of transition process.
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,
@@ -268,6 +278,22 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
/// Configured the local-proxy application with the relevant JWKS and roles it should
/// use for authorizing connect requests using JWT.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LocalProxySpec {
pub jwks: Vec<JwksSettings>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct JwksSettings {
pub id: String,
pub role_names: Vec<String>,
pub jwks_url: String,
pub provider_name: String,
pub jwt_audience: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -984,6 +984,7 @@ pub fn short_error(e: &QueryError) -> String {
}
fn log_query_error(query: &str, e: &QueryError) {
// If you want to change the log level of a specific error, also re-categorize it in `BasebackupQueryTimeOngoingRecording`.
match e {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {

View File

@@ -93,9 +93,9 @@ impl Conf {
);
let output = self
.new_pg_command("initdb")?
.arg("-D")
.arg("--pgdata")
.arg(&self.datadir)
.args(["-U", "postgres", "--no-instructions", "--no-sync"])
.args(["--username", "postgres", "--no-instructions", "--no-sync"])
.output()?;
debug!("initdb output: {:?}", output);
ensure!(

View File

@@ -6,12 +6,14 @@ license.workspace = true
[dependencies]
hyper.workspace = true
opentelemetry = { workspace = true, features=["rt-tokio"] }
opentelemetry-otlp = { workspace = true, default-features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry = { workspace = true, features = ["trace"] }
opentelemetry_sdk = { workspace = true, features = ["rt-tokio"] }
opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tracing.workspace = true
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
[dev-dependencies]
tracing-subscriber.workspace = true # For examples in docs

View File

@@ -10,7 +10,6 @@
//!
//! ```rust,no_run
//! use tracing_subscriber::prelude::*;
//! use tracing_opentelemetry::OpenTelemetryLayer;
//!
//! #[tokio::main]
//! async fn main() {
@@ -22,7 +21,7 @@
//! .with_writer(std::io::stderr);
//!
//! // Initialize OpenTelemetry. Exports tracing spans as OpenTelemetry traces
//! let otlp_layer = tracing_utils::init_tracing("my_application").await.map(OpenTelemetryLayer::new);
//! let otlp_layer = tracing_utils::init_tracing("my_application").await;
//!
//! // Put it all together
//! tracing_subscriber::registry()
@@ -35,15 +34,15 @@
#![deny(unsafe_code)]
#![deny(clippy::undocumented_unsafe_blocks)]
use opentelemetry::sdk::Resource;
use opentelemetry::KeyValue;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_otlp::{OTEL_EXPORTER_OTLP_ENDPOINT, OTEL_EXPORTER_OTLP_TRACES_ENDPOINT};
pub use tracing_opentelemetry::OpenTelemetryLayer;
pub mod http;
use opentelemetry::trace::TracerProvider;
use opentelemetry::KeyValue;
use opentelemetry_sdk::Resource;
use tracing::Subscriber;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::Layer;
/// Set up OpenTelemetry exporter, using configuration from environment variables.
///
/// `service_name` is set as the OpenTelemetry 'service.name' resource (see
@@ -71,7 +70,10 @@ pub mod http;
///
/// This doesn't block, but is marked as 'async' to hint that this must be called in
/// asynchronous execution context.
pub async fn init_tracing(service_name: &str) -> Option<opentelemetry::sdk::trace::Tracer> {
pub async fn init_tracing<S>(service_name: &str) -> Option<impl Layer<S>>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
@@ -80,9 +82,10 @@ pub async fn init_tracing(service_name: &str) -> Option<opentelemetry::sdk::trac
/// Like `init_tracing`, but creates a separate tokio Runtime for the tracing
/// tasks.
pub fn init_tracing_without_runtime(
service_name: &str,
) -> Option<opentelemetry::sdk::trace::Tracer> {
pub fn init_tracing_without_runtime<S>(service_name: &str) -> Option<impl Layer<S>>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
@@ -113,54 +116,36 @@ pub fn init_tracing_without_runtime(
Some(init_tracing_internal(service_name.to_string()))
}
fn init_tracing_internal(service_name: String) -> opentelemetry::sdk::trace::Tracer {
// Set up exporter from the OTEL_EXPORTER_* environment variables
let mut exporter = opentelemetry_otlp::new_exporter().http().with_env();
fn init_tracing_internal<S>(service_name: String) -> impl Layer<S>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
// Sets up exporter from the OTEL_EXPORTER_* environment variables.
let exporter = opentelemetry_otlp::new_exporter().http();
// XXX opentelemetry-otlp v0.18.0 has a bug in how it uses the
// OTEL_EXPORTER_OTLP_ENDPOINT env variable. According to the
// OpenTelemetry spec at
// <https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md#endpoint-urls-for-otlphttp>,
// the full exporter URL is formed by appending "/v1/traces" to the value
// of OTEL_EXPORTER_OTLP_ENDPOINT. However, opentelemetry-otlp only does
// that with the grpc-tonic exporter. Other exporters, like the HTTP
// exporter, use the URL from OTEL_EXPORTER_OTLP_ENDPOINT as is, without
// appending "/v1/traces".
//
// See https://github.com/open-telemetry/opentelemetry-rust/pull/950
//
// Work around that by checking OTEL_EXPORTER_OTLP_ENDPOINT, and setting
// the endpoint url with the "/v1/traces" path ourselves. If the bug is
// fixed in a later version, we can remove this code. But if we don't
// remember to remove this, it won't do any harm either, as the crate will
// just ignore the OTEL_EXPORTER_OTLP_ENDPOINT setting when the endpoint
// is set directly with `with_endpoint`.
if std::env::var(OTEL_EXPORTER_OTLP_TRACES_ENDPOINT).is_err() {
if let Ok(mut endpoint) = std::env::var(OTEL_EXPORTER_OTLP_ENDPOINT) {
if !endpoint.ends_with('/') {
endpoint.push('/');
}
endpoint.push_str("v1/traces");
exporter = exporter.with_endpoint(endpoint);
}
}
// TODO: opentelemetry::global::set_error_handler() with custom handler that
// bypasses default tracing layers, but logs regular looking log
// messages.
// Propagate trace information in the standard W3C TraceContext format.
opentelemetry::global::set_text_map_propagator(
opentelemetry::sdk::propagation::TraceContextPropagator::new(),
opentelemetry_sdk::propagation::TraceContextPropagator::new(),
);
opentelemetry_otlp::new_pipeline()
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(exporter)
.with_trace_config(
opentelemetry::sdk::trace::config().with_resource(Resource::new(vec![KeyValue::new(
.with_trace_config(opentelemetry_sdk::trace::Config::default().with_resource(
Resource::new(vec![KeyValue::new(
opentelemetry_semantic_conventions::resource::SERVICE_NAME,
service_name,
)])),
)
.install_batch(opentelemetry::runtime::Tokio)
)]),
))
.install_batch(opentelemetry_sdk::runtime::Tokio)
.expect("could not initialize opentelemetry exporter")
.tracer("global");
tracing_opentelemetry::layer().with_tracer(tracer)
}
// Shutdown trace pipeline gracefully, so that it has a chance to send any

View File

@@ -736,4 +736,22 @@ impl Client {
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_init_lsn_lease(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Result<LsnLease> {
let uri = format!(
"{}/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/lsn_lease",
self.mgmt_api_endpoint,
);
self.request(Method::POST, &uri, LsnLeaseRequest { lsn })
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
}

View File

@@ -15,7 +15,7 @@ use clap::{Arg, ArgAction, Command};
use metrics::launch_timestamp::{set_launch_timestamp_metric, LaunchTimestamp};
use pageserver::config::PageserverIdentity;
use pageserver::control_plane_client::ControlPlaneClient;
use pageserver::controller_upcall_client::ControllerUpcallClient;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::task_mgr::{COMPUTE_REQUEST_RUNTIME, WALRECEIVER_RUNTIME};
@@ -396,7 +396,7 @@ fn start_pageserver(
// Set up deletion queue
let (deletion_queue, deletion_workers) = DeletionQueue::new(
remote_storage.clone(),
ControlPlaneClient::new(conf, &shutdown_pageserver),
ControllerUpcallClient::new(conf, &shutdown_pageserver),
conf,
);
if let Some(deletion_workers) = deletion_workers {

View File

@@ -17,9 +17,12 @@ use utils::{backoff, failpoint_support, generation::Generation, id::NodeId};
use crate::{config::PageServerConf, virtual_file::on_fatal_io_error};
use pageserver_api::config::NodeMetadata;
/// The Pageserver's client for using the control plane API: this is a small subset
/// of the overall control plane API, for dealing with generations (see docs/rfcs/025-generation-numbers.md)
pub struct ControlPlaneClient {
/// The Pageserver's client for using the storage controller upcall API: this is a small API
/// for dealing with generations (see docs/rfcs/025-generation-numbers.md).
///
/// The server presenting this API may either be the storage controller or some other
/// service (such as the Neon control plane) providing a store of generation numbers.
pub struct ControllerUpcallClient {
http_client: reqwest::Client,
base_url: Url,
node_id: NodeId,
@@ -45,7 +48,7 @@ pub trait ControlPlaneGenerationsApi {
) -> impl Future<Output = Result<HashMap<TenantShardId, bool>, RetryForeverError>> + Send;
}
impl ControlPlaneClient {
impl ControllerUpcallClient {
/// A None return value indicates that the input `conf` object does not have control
/// plane API enabled.
pub fn new(conf: &'static PageServerConf, cancel: &CancellationToken) -> Option<Self> {
@@ -114,7 +117,7 @@ impl ControlPlaneClient {
}
}
impl ControlPlaneGenerationsApi for ControlPlaneClient {
impl ControlPlaneGenerationsApi for ControllerUpcallClient {
/// Block until we get a successful response, or error out if we are shut down
async fn re_attach(
&self,
@@ -216,29 +219,38 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient {
.join("validate")
.expect("Failed to build validate path");
let request = ValidateRequest {
tenants: tenants
.into_iter()
.map(|(id, gen)| ValidateRequestTenant {
id,
gen: gen
.into()
.expect("Generation should always be valid for a Tenant doing deletions"),
})
.collect(),
};
// When sending validate requests, break them up into chunks so that we
// avoid possible edge cases of generating any HTTP requests that
// require database I/O across many thousands of tenants.
let mut result: HashMap<TenantShardId, bool> = HashMap::with_capacity(tenants.len());
for tenant_chunk in (tenants).chunks(128) {
let request = ValidateRequest {
tenants: tenant_chunk
.iter()
.map(|(id, generation)| ValidateRequestTenant {
id: *id,
gen: (*generation).into().expect(
"Generation should always be valid for a Tenant doing deletions",
),
})
.collect(),
};
failpoint_support::sleep_millis_async!("control-plane-client-validate-sleep", &self.cancel);
if self.cancel.is_cancelled() {
return Err(RetryForeverError::ShuttingDown);
failpoint_support::sleep_millis_async!(
"control-plane-client-validate-sleep",
&self.cancel
);
if self.cancel.is_cancelled() {
return Err(RetryForeverError::ShuttingDown);
}
let response: ValidateResponse =
self.retry_http_forever(&re_attach_path, request).await?;
for rt in response.tenants {
result.insert(rt.id, rt.valid);
}
}
let response: ValidateResponse = self.retry_http_forever(&re_attach_path, request).await?;
Ok(response
.tenants
.into_iter()
.map(|rt| (rt.id, rt.valid))
.collect())
Ok(result.into_iter().collect())
}
}

View File

@@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use crate::control_plane_client::ControlPlaneGenerationsApi;
use crate::controller_upcall_client::ControlPlaneGenerationsApi;
use crate::metrics;
use crate::tenant::remote_timeline_client::remote_layer_path;
use crate::tenant::remote_timeline_client::remote_timeline_path;
@@ -622,7 +622,7 @@ impl DeletionQueue {
/// If remote_storage is None, then the returned workers will also be None.
pub fn new<C>(
remote_storage: GenericRemoteStorage,
control_plane_client: Option<C>,
controller_upcall_client: Option<C>,
conf: &'static PageServerConf,
) -> (Self, Option<DeletionQueueWorkers<C>>)
where
@@ -662,7 +662,7 @@ impl DeletionQueue {
conf,
backend_rx,
executor_tx,
control_plane_client,
controller_upcall_client,
lsn_table.clone(),
cancel.clone(),
),
@@ -704,7 +704,7 @@ mod test {
use tokio::task::JoinHandle;
use crate::{
control_plane_client::RetryForeverError,
controller_upcall_client::RetryForeverError,
repository::Key,
tenant::{harness::TenantHarness, storage_layer::DeltaLayerName},
};

View File

@@ -25,8 +25,8 @@ use tracing::info;
use tracing::warn;
use crate::config::PageServerConf;
use crate::control_plane_client::ControlPlaneGenerationsApi;
use crate::control_plane_client::RetryForeverError;
use crate::controller_upcall_client::ControlPlaneGenerationsApi;
use crate::controller_upcall_client::RetryForeverError;
use crate::metrics;
use crate::virtual_file::MaybeFatalIo;
@@ -61,7 +61,7 @@ where
tx: tokio::sync::mpsc::Sender<DeleterMessage>,
// Client for calling into control plane API for validation of deletes
control_plane_client: Option<C>,
controller_upcall_client: Option<C>,
// DeletionLists which are waiting generation validation. Not safe to
// execute until [`validate`] has processed them.
@@ -94,7 +94,7 @@ where
conf: &'static PageServerConf,
rx: tokio::sync::mpsc::Receiver<ValidatorQueueMessage>,
tx: tokio::sync::mpsc::Sender<DeleterMessage>,
control_plane_client: Option<C>,
controller_upcall_client: Option<C>,
lsn_table: Arc<std::sync::RwLock<VisibleLsnUpdates>>,
cancel: CancellationToken,
) -> Self {
@@ -102,7 +102,7 @@ where
conf,
rx,
tx,
control_plane_client,
controller_upcall_client,
lsn_table,
pending_lists: Vec::new(),
validated_lists: Vec::new(),
@@ -145,8 +145,8 @@ where
return Ok(());
}
let tenants_valid = if let Some(control_plane_client) = &self.control_plane_client {
match control_plane_client
let tenants_valid = if let Some(controller_upcall_client) = &self.controller_upcall_client {
match controller_upcall_client
.validate(tenant_generations.iter().map(|(k, v)| (*k, *v)).collect())
.await
{

View File

@@ -56,6 +56,7 @@ use utils::http::endpoint::request_span;
use utils::http::request::must_parse_query_param;
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
use crate::deletion_queue::DeletionQueueClient;
use crate::pgdatadir_mapping::LsnForTimestamp;
@@ -80,7 +81,6 @@ use crate::tenant::timeline::CompactionError;
use crate::tenant::timeline::Timeline;
use crate::tenant::GetTimelineError;
use crate::tenant::{LogicalSizeCalculationCause, PageReconstructError};
use crate::{config::PageServerConf, tenant::mgr};
use crate::{disk_usage_eviction_task, tenant};
use pageserver_api::models::{
StatusResponse, TenantConfigRequest, TenantInfo, TimelineCreateRequest, TimelineGcRequest,
@@ -824,7 +824,7 @@ async fn get_lsn_by_timestamp_handler(
let lease = if with_lease {
timeline
.make_lsn_lease(lsn, timeline.get_lsn_lease_length_for_ts(), &ctx)
.init_lsn_lease(lsn, timeline.get_lsn_lease_length_for_ts(), &ctx)
.inspect_err(|_| {
warn!("fail to grant a lease to {}", lsn);
})
@@ -1692,9 +1692,18 @@ async fn lsn_lease_handler(
let timeline =
active_timeline_of_active_tenant(&state.tenant_manager, tenant_shard_id, timeline_id)
.await?;
let result = timeline
.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), &ctx)
.map_err(|e| ApiError::InternalServerError(e.context("lsn lease http handler")))?;
let result = async {
timeline
.init_lsn_lease(lsn, timeline.get_lsn_lease_length(), &ctx)
.map_err(|e| {
ApiError::InternalServerError(
e.context(format!("invalid lsn lease request at {lsn}")),
)
})
}
.instrument(info_span!("init_lsn_lease", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %timeline_id))
.await?;
json_response(StatusCode::OK, result)
}
@@ -1710,8 +1719,13 @@ async fn timeline_gc_handler(
let gc_req: TimelineGcRequest = json_request(&mut request).await?;
let state = get_state(&request);
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let gc_result = mgr::immediate_gc(tenant_shard_id, timeline_id, gc_req, cancel, &ctx).await?;
let gc_result = state
.tenant_manager
.immediate_gc(tenant_shard_id, timeline_id, gc_req, cancel, &ctx)
.await?;
json_response(StatusCode::OK, gc_result)
}

View File

@@ -6,7 +6,7 @@ pub mod basebackup;
pub mod config;
pub mod consumption_metrics;
pub mod context;
pub mod control_plane_client;
pub mod controller_upcall_client;
pub mod deletion_queue;
pub mod disk_usage_eviction_task;
pub mod http;

View File

@@ -8,6 +8,8 @@ use metrics::{
};
use once_cell::sync::Lazy;
use pageserver_api::shard::TenantShardId;
use postgres_backend::{is_expected_io_error, QueryError};
use pq_proto::framed::ConnectionError;
use strum::{EnumCount, VariantNames};
use strum_macros::{IntoStaticStr, VariantNames};
use tracing::warn;
@@ -1508,6 +1510,7 @@ static COMPUTE_STARTUP_BUCKETS: Lazy<[f64; 28]> = Lazy::new(|| {
pub(crate) struct BasebackupQueryTime {
ok: Histogram,
error: Histogram,
client_error: Histogram,
}
pub(crate) static BASEBACKUP_QUERY_TIME: Lazy<BasebackupQueryTime> = Lazy::new(|| {
@@ -1521,6 +1524,7 @@ pub(crate) static BASEBACKUP_QUERY_TIME: Lazy<BasebackupQueryTime> = Lazy::new(|
BasebackupQueryTime {
ok: vec.get_metric_with_label_values(&["ok"]).unwrap(),
error: vec.get_metric_with_label_values(&["error"]).unwrap(),
client_error: vec.get_metric_with_label_values(&["client_error"]).unwrap(),
}
});
@@ -1557,7 +1561,7 @@ impl BasebackupQueryTime {
}
impl<'a, 'c> BasebackupQueryTimeOngoingRecording<'a, 'c> {
pub(crate) fn observe<T, E>(self, res: &Result<T, E>) {
pub(crate) fn observe<T>(self, res: &Result<T, QueryError>) {
let elapsed = self.start.elapsed();
let ex_throttled = self
.ctx
@@ -1576,10 +1580,15 @@ impl<'a, 'c> BasebackupQueryTimeOngoingRecording<'a, 'c> {
elapsed
}
};
let metric = if res.is_ok() {
&self.parent.ok
} else {
&self.parent.error
// If you want to change categorize of a specific error, also change it in `log_query_error`.
let metric = match res {
Ok(_) => &self.parent.ok,
Err(QueryError::Disconnected(ConnectionError::Io(io_error)))
if is_expected_io_error(io_error) =>
{
&self.parent.client_error
}
Err(_) => &self.parent.error,
};
metric.observe(ex_throttled.as_secs_f64());
}

View File

@@ -273,10 +273,20 @@ async fn page_service_conn_main(
info!("Postgres client disconnected ({io_error})");
Ok(())
} else {
Err(io_error).context("Postgres connection error")
let tenant_id = conn_handler.timeline_handles.tenant_id();
Err(io_error).context(format!(
"Postgres connection error for tenant_id={:?} client at peer_addr={}",
tenant_id, peer_addr
))
}
}
other => other.context("Postgres query error"),
other => {
let tenant_id = conn_handler.timeline_handles.tenant_id();
other.context(format!(
"Postgres query error for tenant_id={:?} client peer_addr={}",
tenant_id, peer_addr
))
}
}
}
@@ -340,6 +350,10 @@ impl TimelineHandles {
}
})
}
fn tenant_id(&self) -> Option<TenantId> {
self.wrapper.tenant_id.get().copied()
}
}
pub(crate) struct TenantManagerWrapper {
@@ -819,7 +833,7 @@ impl PageServerHandler {
set_tracing_field_shard_id(&timeline);
let lease = timeline
.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)
.renew_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)
.inspect_err(|e| {
warn!("{e}");
})

View File

@@ -21,6 +21,7 @@ use futures::stream::FuturesUnordered;
use futures::StreamExt;
use pageserver_api::models;
use pageserver_api::models::AuxFilePolicy;
use pageserver_api::models::LsnLease;
use pageserver_api::models::TimelineArchivalState;
use pageserver_api::models::TimelineState;
use pageserver_api::models::TopTenantShardItem;
@@ -182,27 +183,54 @@ pub struct TenantSharedResources {
pub(super) struct AttachedTenantConf {
tenant_conf: TenantConfOpt,
location: AttachedLocationConfig,
/// The deadline before which we are blocked from GC so that
/// leases have a chance to be renewed.
lsn_lease_deadline: Option<tokio::time::Instant>,
}
impl AttachedTenantConf {
fn new(tenant_conf: TenantConfOpt, location: AttachedLocationConfig) -> Self {
// Sets a deadline before which we cannot proceed to GC due to lsn lease.
//
// We do this as the leases mapping are not persisted to disk. By delaying GC by lease
// length, we guarantee that all the leases we granted before will have a chance to renew
// when we run GC for the first time after restart / transition from AttachedMulti to AttachedSingle.
let lsn_lease_deadline = if location.attach_mode == AttachmentMode::Single {
Some(
tokio::time::Instant::now()
+ tenant_conf
.lsn_lease_length
.unwrap_or(LsnLease::DEFAULT_LENGTH),
)
} else {
// We don't use `lsn_lease_deadline` to delay GC in AttachedMulti and AttachedStale
// because we don't do GC in these modes.
None
};
Self {
tenant_conf,
location,
lsn_lease_deadline,
}
}
fn try_from(location_conf: LocationConf) -> anyhow::Result<Self> {
match &location_conf.mode {
LocationMode::Attached(attach_conf) => Ok(Self {
tenant_conf: location_conf.tenant_conf,
location: *attach_conf,
}),
LocationMode::Attached(attach_conf) => {
Ok(Self::new(location_conf.tenant_conf, *attach_conf))
}
LocationMode::Secondary(_) => {
anyhow::bail!("Attempted to construct AttachedTenantConf from a LocationConf in secondary mode")
}
}
}
fn is_gc_blocked_by_lsn_lease_deadline(&self) -> bool {
self.lsn_lease_deadline
.map(|d| tokio::time::Instant::now() < d)
.unwrap_or(false)
}
}
struct TimelinePreload {
timeline_id: TimelineId,
@@ -1822,6 +1850,11 @@ impl Tenant {
info!("Skipping GC in location state {:?}", conf.location);
return Ok(GcResult::default());
}
if conf.is_gc_blocked_by_lsn_lease_deadline() {
info!("Skipping GC because lsn lease deadline is not reached");
return Ok(GcResult::default());
}
}
let _guard = match self.gc_block.start().await {
@@ -2630,6 +2663,8 @@ impl Tenant {
Arc::new(AttachedTenantConf {
tenant_conf: new_tenant_conf.clone(),
location: inner.location,
// Attached location is not changed, no need to update lsn lease deadline.
lsn_lease_deadline: inner.lsn_lease_deadline,
})
});
@@ -3887,9 +3922,9 @@ async fn run_initdb(
let _permit = INIT_DB_SEMAPHORE.acquire().await;
let initdb_command = tokio::process::Command::new(&initdb_bin_path)
.args(["-D", initdb_target_dir.as_ref()])
.args(["-U", &conf.superuser])
.args(["-E", "utf8"])
.args(["--pgdata", initdb_target_dir.as_ref()])
.args(["--username", &conf.superuser])
.args(["--encoding", "utf8"])
.arg("--no-instructions")
.arg("--no-sync")
.env_clear()
@@ -4461,13 +4496,17 @@ mod tests {
tline.freeze_and_flush().await.map_err(|e| e.into())
}
#[tokio::test]
#[tokio::test(start_paused = true)]
async fn test_prohibit_branch_creation_on_garbage_collected_data() -> anyhow::Result<()> {
let (tenant, ctx) =
TenantHarness::create("test_prohibit_branch_creation_on_garbage_collected_data")
.await?
.load()
.await;
// Advance to the lsn lease deadline so that GC is not blocked by
// initial transition into AttachedSingle.
tokio::time::advance(tenant.get_lsn_lease_length()).await;
tokio::time::resume();
let tline = tenant
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
.await?;
@@ -7244,9 +7283,17 @@ mod tests {
Ok(())
}
#[tokio::test]
#[tokio::test(start_paused = true)]
async fn test_lsn_lease() -> anyhow::Result<()> {
let (tenant, ctx) = TenantHarness::create("test_lsn_lease").await?.load().await;
let (tenant, ctx) = TenantHarness::create("test_lsn_lease")
.await
.unwrap()
.load()
.await;
// Advance to the lsn lease deadline so that GC is not blocked by
// initial transition into AttachedSingle.
tokio::time::advance(tenant.get_lsn_lease_length()).await;
tokio::time::resume();
let key = Key::from_hex("010000000033333333444444445500000000").unwrap();
let end_lsn = Lsn(0x100);
@@ -7274,24 +7321,33 @@ mod tests {
let leased_lsns = [0x30, 0x50, 0x70];
let mut leases = Vec::new();
let _: anyhow::Result<_> = leased_lsns.iter().try_for_each(|n| {
leases.push(timeline.make_lsn_lease(Lsn(*n), timeline.get_lsn_lease_length(), &ctx)?);
Ok(())
leased_lsns.iter().for_each(|n| {
leases.push(
timeline
.init_lsn_lease(Lsn(*n), timeline.get_lsn_lease_length(), &ctx)
.expect("lease request should succeed"),
);
});
// Renewing with shorter lease should not change the lease.
let updated_lease_0 =
timeline.make_lsn_lease(Lsn(leased_lsns[0]), Duration::from_secs(0), &ctx)?;
assert_eq!(updated_lease_0.valid_until, leases[0].valid_until);
let updated_lease_0 = timeline
.renew_lsn_lease(Lsn(leased_lsns[0]), Duration::from_secs(0), &ctx)
.expect("lease renewal should succeed");
assert_eq!(
updated_lease_0.valid_until, leases[0].valid_until,
" Renewing with shorter lease should not change the lease."
);
// Renewing with a long lease should renew lease with later expiration time.
let updated_lease_1 = timeline.make_lsn_lease(
Lsn(leased_lsns[1]),
timeline.get_lsn_lease_length() * 2,
&ctx,
)?;
assert!(updated_lease_1.valid_until > leases[1].valid_until);
let updated_lease_1 = timeline
.renew_lsn_lease(
Lsn(leased_lsns[1]),
timeline.get_lsn_lease_length() * 2,
&ctx,
)
.expect("lease renewal should succeed");
assert!(
updated_lease_1.valid_until > leases[1].valid_until,
"Renewing with a long lease should renew lease with later expiration time."
);
// Force set disk consistent lsn so we can get the cutoff at `end_lsn`.
info!(
@@ -7308,7 +7364,8 @@ mod tests {
&CancellationToken::new(),
&ctx,
)
.await?;
.await
.unwrap();
// Keeping everything <= Lsn(0x80) b/c leases:
// 0/10: initdb layer
@@ -7322,13 +7379,16 @@ mod tests {
// Make lease on a already GC-ed LSN.
// 0/80 does not have a valid lease + is below latest_gc_cutoff
assert!(Lsn(0x80) < *timeline.get_latest_gc_cutoff_lsn());
let res = timeline.make_lsn_lease(Lsn(0x80), timeline.get_lsn_lease_length(), &ctx);
assert!(res.is_err());
timeline
.init_lsn_lease(Lsn(0x80), timeline.get_lsn_lease_length(), &ctx)
.expect_err("lease request on GC-ed LSN should fail");
// Should still be able to renew a currently valid lease
// Assumption: original lease to is still valid for 0/50.
let _ =
timeline.make_lsn_lease(Lsn(leased_lsns[1]), timeline.get_lsn_lease_length(), &ctx)?;
// (use `Timeline::init_lsn_lease` for testing so it always does validation)
timeline
.init_lsn_lease(Lsn(leased_lsns[1]), timeline.get_lsn_lease_length(), &ctx)
.expect("lease renewal with validation should succeed");
Ok(())
}

View File

@@ -1,29 +1,12 @@
use std::{collections::HashMap, time::Duration};
use std::collections::HashMap;
use super::remote_timeline_client::index::GcBlockingReason;
use tokio::time::Instant;
use utils::id::TimelineId;
type TimelinesBlocked = HashMap<TimelineId, enumset::EnumSet<GcBlockingReason>>;
use super::remote_timeline_client::index::GcBlockingReason;
#[derive(Default)]
struct Storage {
timelines_blocked: TimelinesBlocked,
/// The deadline before which we are blocked from GC so that
/// leases have a chance to be renewed.
lsn_lease_deadline: Option<Instant>,
}
type Storage = HashMap<TimelineId, enumset::EnumSet<GcBlockingReason>>;
impl Storage {
fn is_blocked_by_lsn_lease_deadline(&self) -> bool {
self.lsn_lease_deadline
.map(|d| Instant::now() < d)
.unwrap_or(false)
}
}
/// GcBlock provides persistent (per-timeline) gc blocking and facilitates transient time based gc
/// blocking.
/// GcBlock provides persistent (per-timeline) gc blocking.
#[derive(Default)]
pub(crate) struct GcBlock {
/// The timelines which have current reasons to block gc.
@@ -66,17 +49,6 @@ impl GcBlock {
}
}
/// Sets a deadline before which we cannot proceed to GC due to lsn lease.
///
/// We do this as the leases mapping are not persisted to disk. By delaying GC by lease
/// length, we guarantee that all the leases we granted before will have a chance to renew
/// when we run GC for the first time after restart / transition from AttachedMulti to AttachedSingle.
pub(super) fn set_lsn_lease_deadline(&self, lsn_lease_length: Duration) {
let deadline = Instant::now() + lsn_lease_length;
let mut g = self.reasons.lock().unwrap();
g.lsn_lease_deadline = Some(deadline);
}
/// Describe the current gc blocking reasons.
///
/// TODO: make this json serializable.
@@ -102,7 +74,7 @@ impl GcBlock {
) -> anyhow::Result<bool> {
let (added, uploaded) = {
let mut g = self.reasons.lock().unwrap();
let set = g.timelines_blocked.entry(timeline.timeline_id).or_default();
let set = g.entry(timeline.timeline_id).or_default();
let added = set.insert(reason);
// LOCK ORDER: intentionally hold the lock, see self.reasons.
@@ -133,7 +105,7 @@ impl GcBlock {
let (remaining_blocks, uploaded) = {
let mut g = self.reasons.lock().unwrap();
match g.timelines_blocked.entry(timeline.timeline_id) {
match g.entry(timeline.timeline_id) {
Entry::Occupied(mut oe) => {
let set = oe.get_mut();
set.remove(reason);
@@ -147,7 +119,7 @@ impl GcBlock {
}
}
let remaining_blocks = g.timelines_blocked.len();
let remaining_blocks = g.len();
// LOCK ORDER: intentionally hold the lock while scheduling; see self.reasons
let uploaded = timeline
@@ -172,11 +144,11 @@ impl GcBlock {
pub(crate) fn before_delete(&self, timeline: &super::Timeline) {
let unblocked = {
let mut g = self.reasons.lock().unwrap();
if g.timelines_blocked.is_empty() {
if g.is_empty() {
return;
}
g.timelines_blocked.remove(&timeline.timeline_id);
g.remove(&timeline.timeline_id);
BlockingReasons::clean_and_summarize(g).is_none()
};
@@ -187,11 +159,10 @@ impl GcBlock {
}
/// Initialize with the non-deleted timelines of this tenant.
pub(crate) fn set_scanned(&self, scanned: TimelinesBlocked) {
pub(crate) fn set_scanned(&self, scanned: Storage) {
let mut g = self.reasons.lock().unwrap();
assert!(g.timelines_blocked.is_empty());
g.timelines_blocked
.extend(scanned.into_iter().filter(|(_, v)| !v.is_empty()));
assert!(g.is_empty());
g.extend(scanned.into_iter().filter(|(_, v)| !v.is_empty()));
if let Some(reasons) = BlockingReasons::clean_and_summarize(g) {
tracing::info!(summary=?reasons, "initialized with gc blocked");
@@ -205,7 +176,6 @@ pub(super) struct Guard<'a> {
#[derive(Debug)]
pub(crate) struct BlockingReasons {
tenant_blocked_by_lsn_lease_deadline: bool,
timelines: usize,
reasons: enumset::EnumSet<GcBlockingReason>,
}
@@ -214,8 +184,8 @@ impl std::fmt::Display for BlockingReasons {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"tenant_blocked_by_lsn_lease_deadline: {}, {} timelines block for {:?}",
self.tenant_blocked_by_lsn_lease_deadline, self.timelines, self.reasons
"{} timelines block for {:?}",
self.timelines, self.reasons
)
}
}
@@ -223,15 +193,13 @@ impl std::fmt::Display for BlockingReasons {
impl BlockingReasons {
fn clean_and_summarize(mut g: std::sync::MutexGuard<'_, Storage>) -> Option<Self> {
let mut reasons = enumset::EnumSet::empty();
g.timelines_blocked.retain(|_key, value| {
g.retain(|_key, value| {
reasons = reasons.union(*value);
!value.is_empty()
});
let blocked_by_lsn_lease_deadline = g.is_blocked_by_lsn_lease_deadline();
if !g.timelines_blocked.is_empty() || blocked_by_lsn_lease_deadline {
if !g.is_empty() {
Some(BlockingReasons {
tenant_blocked_by_lsn_lease_deadline: blocked_by_lsn_lease_deadline,
timelines: g.timelines_blocked.len(),
timelines: g.len(),
reasons,
})
} else {
@@ -240,17 +208,14 @@ impl BlockingReasons {
}
fn summarize(g: &std::sync::MutexGuard<'_, Storage>) -> Option<Self> {
let blocked_by_lsn_lease_deadline = g.is_blocked_by_lsn_lease_deadline();
if g.timelines_blocked.is_empty() && !blocked_by_lsn_lease_deadline {
if g.is_empty() {
None
} else {
let reasons = g
.timelines_blocked
.values()
.fold(enumset::EnumSet::empty(), |acc, next| acc.union(*next));
Some(BlockingReasons {
tenant_blocked_by_lsn_lease_deadline: blocked_by_lsn_lease_deadline,
timelines: g.timelines_blocked.len(),
timelines: g.len(),
reasons,
})
}

View File

@@ -30,8 +30,8 @@ use utils::{backoff, completion, crashsafe};
use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
use crate::control_plane_client::{
ControlPlaneClient, ControlPlaneGenerationsApi, RetryForeverError,
use crate::controller_upcall_client::{
ControlPlaneGenerationsApi, ControllerUpcallClient, RetryForeverError,
};
use crate::deletion_queue::DeletionQueueClient;
use crate::http::routes::ACTIVE_TENANT_TIMEOUT;
@@ -122,7 +122,7 @@ pub(crate) enum ShardSelector {
Known(ShardIndex),
}
/// A convenience for use with the re_attach ControlPlaneClient function: rather
/// A convenience for use with the re_attach ControllerUpcallClient function: rather
/// than the serializable struct, we build this enum that encapsulates
/// the invariant that attached tenants always have generations.
///
@@ -219,7 +219,11 @@ async fn safe_rename_tenant_dir(path: impl AsRef<Utf8Path>) -> std::io::Result<U
+ TEMP_FILE_SUFFIX;
let tmp_path = path_with_suffix_extension(&path, &rand_suffix);
fs::rename(path.as_ref(), &tmp_path).await?;
fs::File::open(parent).await?.sync_all().await?;
fs::File::open(parent)
.await?
.sync_all()
.await
.maybe_fatal_err("safe_rename_tenant_dir")?;
Ok(tmp_path)
}
@@ -341,7 +345,7 @@ async fn init_load_generations(
"Emergency mode! Tenants will be attached unsafely using their last known generation"
);
emergency_generations(tenant_confs)
} else if let Some(client) = ControlPlaneClient::new(conf, cancel) {
} else if let Some(client) = ControllerUpcallClient::new(conf, cancel) {
info!("Calling control plane API to re-attach tenants");
// If we are configured to use the control plane API, then it is the source of truth for what tenants to load.
match client.re_attach(conf).await {
@@ -949,12 +953,6 @@ impl TenantManager {
(LocationMode::Attached(attach_conf), Some(TenantSlot::Attached(tenant))) => {
match attach_conf.generation.cmp(&tenant.generation) {
Ordering::Equal => {
if attach_conf.attach_mode == AttachmentMode::Single {
tenant
.gc_block
.set_lsn_lease_deadline(tenant.get_lsn_lease_length());
}
// A transition from Attached to Attached in the same generation, we may
// take our fast path and just provide the updated configuration
// to the tenant.
@@ -2199,6 +2197,82 @@ impl TenantManager {
Ok((wanted_bytes, shard_count as u32))
}
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), %timeline_id))]
pub(crate) async fn immediate_gc(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
gc_req: TimelineGcRequest,
cancel: CancellationToken,
ctx: &RequestContext,
) -> Result<GcResult, ApiError> {
let tenant = {
let guard = self.tenants.read().unwrap();
guard
.get(&tenant_shard_id)
.cloned()
.with_context(|| format!("tenant {tenant_shard_id}"))
.map_err(|e| ApiError::NotFound(e.into()))?
};
let gc_horizon = gc_req.gc_horizon.unwrap_or_else(|| tenant.get_gc_horizon());
// Use tenant's pitr setting
let pitr = tenant.get_pitr_interval();
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
// Run in task_mgr to avoid race with tenant_detach operation
let ctx: RequestContext =
ctx.detached_child(TaskKind::GarbageCollector, DownloadBehavior::Download);
let _gate_guard = tenant.gate.enter().map_err(|_| ApiError::ShuttingDown)?;
fail::fail_point!("immediate_gc_task_pre");
#[allow(unused_mut)]
let mut result = tenant
.gc_iteration(Some(timeline_id), gc_horizon, pitr, &cancel, &ctx)
.await;
// FIXME: `gc_iteration` can return an error for multiple reasons; we should handle it
// better once the types support it.
#[cfg(feature = "testing")]
{
// we need to synchronize with drop completion for python tests without polling for
// log messages
if let Ok(result) = result.as_mut() {
let mut js = tokio::task::JoinSet::new();
for layer in std::mem::take(&mut result.doomed_layers) {
js.spawn(layer.wait_drop());
}
tracing::info!(
total = js.len(),
"starting to wait for the gc'd layers to be dropped"
);
while let Some(res) = js.join_next().await {
res.expect("wait_drop should not panic");
}
}
let timeline = tenant.get_timeline(timeline_id, false).ok();
let rtc = timeline.as_ref().map(|x| &x.remote_client);
if let Some(rtc) = rtc {
// layer drops schedule actions on remote timeline client to actually do the
// deletions; don't care about the shutdown error, just exit fast
drop(rtc.wait_completion().await);
}
}
result.map_err(|e| match e {
GcError::TenantCancelled | GcError::TimelineCancelled => ApiError::ShuttingDown,
GcError::TimelineNotFound => {
ApiError::NotFound(anyhow::anyhow!("Timeline not found").into())
}
other => ApiError::InternalServerError(anyhow::anyhow!(other)),
})
}
}
#[derive(Debug, thiserror::Error)]
@@ -2343,7 +2417,7 @@ enum TenantSlotDropError {
/// Errors that can happen any time we are walking the tenant map to try and acquire
/// the TenantSlot for a particular tenant.
#[derive(Debug, thiserror::Error)]
pub enum TenantMapError {
pub(crate) enum TenantMapError {
// Tried to read while initializing
#[error("tenant map is still initializing")]
StillInitializing,
@@ -2373,7 +2447,7 @@ pub enum TenantMapError {
/// The `old_value` may be dropped before the SlotGuard is dropped, by calling
/// `drop_old_value`. It is an error to call this without shutting down
/// the conents of `old_value`.
pub struct SlotGuard {
pub(crate) struct SlotGuard {
tenant_shard_id: TenantShardId,
old_value: Option<TenantSlot>,
upserted: bool,
@@ -2766,81 +2840,6 @@ use {
utils::http::error::ApiError,
};
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), %timeline_id))]
pub(crate) async fn immediate_gc(
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
gc_req: TimelineGcRequest,
cancel: CancellationToken,
ctx: &RequestContext,
) -> Result<GcResult, ApiError> {
let tenant = {
let guard = TENANTS.read().unwrap();
guard
.get(&tenant_shard_id)
.cloned()
.with_context(|| format!("tenant {tenant_shard_id}"))
.map_err(|e| ApiError::NotFound(e.into()))?
};
let gc_horizon = gc_req.gc_horizon.unwrap_or_else(|| tenant.get_gc_horizon());
// Use tenant's pitr setting
let pitr = tenant.get_pitr_interval();
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
// Run in task_mgr to avoid race with tenant_detach operation
let ctx: RequestContext =
ctx.detached_child(TaskKind::GarbageCollector, DownloadBehavior::Download);
let _gate_guard = tenant.gate.enter().map_err(|_| ApiError::ShuttingDown)?;
fail::fail_point!("immediate_gc_task_pre");
#[allow(unused_mut)]
let mut result = tenant
.gc_iteration(Some(timeline_id), gc_horizon, pitr, &cancel, &ctx)
.await;
// FIXME: `gc_iteration` can return an error for multiple reasons; we should handle it
// better once the types support it.
#[cfg(feature = "testing")]
{
// we need to synchronize with drop completion for python tests without polling for
// log messages
if let Ok(result) = result.as_mut() {
let mut js = tokio::task::JoinSet::new();
for layer in std::mem::take(&mut result.doomed_layers) {
js.spawn(layer.wait_drop());
}
tracing::info!(
total = js.len(),
"starting to wait for the gc'd layers to be dropped"
);
while let Some(res) = js.join_next().await {
res.expect("wait_drop should not panic");
}
}
let timeline = tenant.get_timeline(timeline_id, false).ok();
let rtc = timeline.as_ref().map(|x| &x.remote_client);
if let Some(rtc) = rtc {
// layer drops schedule actions on remote timeline client to actually do the
// deletions; don't care about the shutdown error, just exit fast
drop(rtc.wait_completion().await);
}
}
result.map_err(|e| match e {
GcError::TenantCancelled | GcError::TimelineCancelled => ApiError::ShuttingDown,
GcError::TimelineNotFound => {
ApiError::NotFound(anyhow::anyhow!("Timeline not found").into())
}
other => ApiError::InternalServerError(anyhow::anyhow!(other)),
})
}
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;

View File

@@ -178,6 +178,7 @@ async fn download_object<'a>(
destination_file
.flush()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("flush source file at {dst_path}"))
.map_err(DownloadError::Other)?;
@@ -185,6 +186,7 @@ async fn download_object<'a>(
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;
@@ -232,6 +234,7 @@ async fn download_object<'a>(
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;

View File

@@ -40,11 +40,11 @@ use crate::tenant::storage_layer::layer::S3_UPLOAD_LIMIT;
use crate::tenant::timeline::GetVectoredError;
use crate::tenant::vectored_blob_io::{
BlobFlag, BufView, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
VectoredReadCoalesceMode, VectoredReadPlanner,
VectoredReadPlanner,
};
use crate::tenant::PageReconstructError;
use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt};
use crate::virtual_file::{self, VirtualFile};
use crate::virtual_file::{self, MaybeFatalIo, VirtualFile};
use crate::{walrecord, TEMP_FILE_SUFFIX};
use crate::{DELTA_FILE_MAGIC, STORAGE_FORMAT_VERSION};
use anyhow::{anyhow, bail, ensure, Context, Result};
@@ -589,7 +589,9 @@ impl DeltaLayerWriterInner {
);
// fsync the file
file.sync_all().await?;
file.sync_all()
.await
.maybe_fatal_err("delta_layer sync_all")?;
trace!("created delta layer {}", self.path);
@@ -1133,7 +1135,7 @@ impl DeltaLayerInner {
ctx: &RequestContext,
) -> anyhow::Result<usize> {
use crate::tenant::vectored_blob_io::{
BlobMeta, VectoredReadBuilder, VectoredReadExtended,
BlobMeta, ChunkedVectoredReadBuilder, VectoredReadExtended,
};
use futures::stream::TryStreamExt;
@@ -1183,8 +1185,8 @@ impl DeltaLayerInner {
let mut prev: Option<(Key, Lsn, BlobRef)> = None;
let mut read_builder: Option<VectoredReadBuilder> = None;
let read_mode = VectoredReadCoalesceMode::get();
let mut read_builder: Option<ChunkedVectoredReadBuilder> = None;
let align = virtual_file::get_io_buffer_alignment();
let max_read_size = self
.max_vectored_read_bytes
@@ -1228,12 +1230,12 @@ impl DeltaLayerInner {
{
None
} else {
read_builder.replace(VectoredReadBuilder::new(
read_builder.replace(ChunkedVectoredReadBuilder::new(
offsets.start.pos(),
offsets.end.pos(),
meta,
max_read_size,
read_mode,
align,
))
}
} else {

View File

@@ -41,7 +41,7 @@ use crate::tenant::vectored_blob_io::{
};
use crate::tenant::PageReconstructError;
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
use crate::virtual_file::{self, VirtualFile};
use crate::virtual_file::{self, MaybeFatalIo, VirtualFile};
use crate::{IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX};
use anyhow::{anyhow, bail, ensure, Context, Result};
use bytes::{Bytes, BytesMut};
@@ -889,7 +889,9 @@ impl ImageLayerWriterInner {
// set inner.file here. The first read will have to re-open it.
// fsync the file
file.sync_all().await?;
file.sync_all()
.await
.maybe_fatal_err("image_layer sync_all")?;
trace!("created image layer {}", self.path);

View File

@@ -330,7 +330,6 @@ async fn gc_loop(tenant: Arc<Tenant>, cancel: CancellationToken) {
RequestContext::todo_child(TaskKind::GarbageCollector, DownloadBehavior::Download);
let mut first = true;
tenant.gc_block.set_lsn_lease_deadline(tenant.get_lsn_lease_length());
loop {
tokio::select! {
_ = cancel.cancelled() => {

View File

@@ -66,6 +66,7 @@ use std::{
use crate::{
aux_file::AuxFileSizeEstimator,
tenant::{
config::AttachmentMode,
layer_map::{LayerMap, SearchResult},
metadata::TimelineMetadata,
storage_layer::{inmemory_layer::IndexEntry, PersistentLayerDesc},
@@ -1324,16 +1325,38 @@ impl Timeline {
Ok(())
}
/// Obtains a temporary lease blocking garbage collection for the given LSN.
///
/// This function will error if the requesting LSN is less than the `latest_gc_cutoff_lsn` and there is also
/// no existing lease to renew. If there is an existing lease in the map, the lease will be renewed only if
/// the request extends the lease. The returned lease is therefore the maximum between the existing lease and
/// the requesting lease.
pub(crate) fn make_lsn_lease(
/// Initializes an LSN lease. The function will return an error if the requested LSN is less than the `latest_gc_cutoff_lsn`.
pub(crate) fn init_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
self.make_lsn_lease(lsn, length, true, ctx)
}
/// Renews a lease at a particular LSN. The requested LSN is not validated against the `latest_gc_cutoff_lsn` when we are in the grace period.
pub(crate) fn renew_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
self.make_lsn_lease(lsn, length, false, ctx)
}
/// Obtains a temporary lease blocking garbage collection for the given LSN.
///
/// If we are in `AttachedSingle` mode and is not blocked by the lsn lease deadline, this function will error
/// if the requesting LSN is less than the `latest_gc_cutoff_lsn` and there is no existing request present.
///
/// If there is an existing lease in the map, the lease will be renewed only if the request extends the lease.
/// The returned lease is therefore the maximum between the existing lease and the requesting lease.
fn make_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
init: bool,
_ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
let lease = {
@@ -1347,8 +1370,8 @@ impl Timeline {
let entry = gc_info.leases.entry(lsn);
let lease = {
if let Entry::Occupied(mut occupied) = entry {
match entry {
Entry::Occupied(mut occupied) => {
let existing_lease = occupied.get_mut();
if valid_until > existing_lease.valid_until {
existing_lease.valid_until = valid_until;
@@ -1360,20 +1383,28 @@ impl Timeline {
}
existing_lease.clone()
} else {
// Reject already GC-ed LSN (lsn < latest_gc_cutoff)
let latest_gc_cutoff_lsn = self.get_latest_gc_cutoff_lsn();
if lsn < *latest_gc_cutoff_lsn {
bail!("tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, *latest_gc_cutoff_lsn);
}
Entry::Vacant(vacant) => {
// Reject already GC-ed LSN (lsn < latest_gc_cutoff) if we are in AttachedSingle and
// not blocked by the lsn lease deadline.
let validate = {
let conf = self.tenant_conf.load();
conf.location.attach_mode == AttachmentMode::Single
&& !conf.is_gc_blocked_by_lsn_lease_deadline()
};
if init || validate {
let latest_gc_cutoff_lsn = self.get_latest_gc_cutoff_lsn();
if lsn < *latest_gc_cutoff_lsn {
bail!("tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, *latest_gc_cutoff_lsn);
}
}
let dt: DateTime<Utc> = valid_until.into();
info!("lease created, valid until {}", dt);
entry.or_insert(LsnLease { valid_until }).clone()
vacant.insert(LsnLease { valid_until }).clone()
}
};
lease
}
};
Ok(lease)
@@ -1950,8 +1981,6 @@ impl Timeline {
.unwrap_or(self.conf.default_tenant_conf.lsn_lease_length)
}
// TODO(yuchen): remove unused flag after implementing https://github.com/neondatabase/neon/issues/8072
#[allow(unused)]
pub(crate) fn get_lsn_lease_length_for_ts(&self) -> Duration {
let tenant_conf = self.tenant_conf.load();
tenant_conf

View File

@@ -185,171 +185,7 @@ pub(crate) enum VectoredReadExtended {
No,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum VectoredReadCoalesceMode {
/// Only coalesce exactly adjacent reads.
AdjacentOnly,
/// In addition to adjacent reads, also consider reads whose corresponding
/// `end` and `start` offsets reside at the same chunk.
Chunked(usize),
}
impl VectoredReadCoalesceMode {
/// [`AdjacentVectoredReadBuilder`] is used if alignment requirement is 0,
/// whereas [`ChunkedVectoredReadBuilder`] is used for alignment requirement 1 and higher.
pub(crate) fn get() -> Self {
let align = virtual_file::get_io_buffer_alignment_raw();
if align == 0 {
VectoredReadCoalesceMode::AdjacentOnly
} else {
VectoredReadCoalesceMode::Chunked(align)
}
}
}
pub(crate) enum VectoredReadBuilder {
Adjacent(AdjacentVectoredReadBuilder),
Chunked(ChunkedVectoredReadBuilder),
}
impl VectoredReadBuilder {
fn new_impl(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: Option<usize>,
mode: VectoredReadCoalesceMode,
) -> Self {
match mode {
VectoredReadCoalesceMode::AdjacentOnly => Self::Adjacent(
AdjacentVectoredReadBuilder::new(start_offset, end_offset, meta, max_read_size),
),
VectoredReadCoalesceMode::Chunked(chunk_size) => {
Self::Chunked(ChunkedVectoredReadBuilder::new(
start_offset,
end_offset,
meta,
max_read_size,
chunk_size,
))
}
}
}
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: usize,
mode: VectoredReadCoalesceMode,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), mode)
}
pub(crate) fn new_streaming(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
mode: VectoredReadCoalesceMode,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, None, mode)
}
pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
match self {
VectoredReadBuilder::Adjacent(builder) => builder.extend(start, end, meta),
VectoredReadBuilder::Chunked(builder) => builder.extend(start, end, meta),
}
}
pub(crate) fn build(self) -> VectoredRead {
match self {
VectoredReadBuilder::Adjacent(builder) => builder.build(),
VectoredReadBuilder::Chunked(builder) => builder.build(),
}
}
pub(crate) fn size(&self) -> usize {
match self {
VectoredReadBuilder::Adjacent(builder) => builder.size(),
VectoredReadBuilder::Chunked(builder) => builder.size(),
}
}
}
pub(crate) struct AdjacentVectoredReadBuilder {
/// Start offset of the read.
start: u64,
// End offset of the read.
end: u64,
/// Start offset and metadata for each blob in this read
blobs_at: VecMap<u64, BlobMeta>,
max_read_size: Option<usize>,
}
impl AdjacentVectoredReadBuilder {
/// Start building a new vectored read.
///
/// Note that by design, this does not check against reading more than `max_read_size` to
/// support reading larger blobs than the configuration value. The builder will be single use
/// however after that.
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: Option<usize>,
) -> Self {
let mut blobs_at = VecMap::default();
blobs_at
.append(start_offset, meta)
.expect("First insertion always succeeds");
Self {
start: start_offset,
end: end_offset,
blobs_at,
max_read_size,
}
}
/// Attempt to extend the current read with a new blob if the start
/// offset matches with the current end of the vectored read
/// and the resuting size is below the max read size
pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
tracing::trace!(start, end, "trying to extend");
let size = (end - start) as usize;
let not_limited_by_max_read_size = {
if let Some(max_read_size) = self.max_read_size {
self.size() + size <= max_read_size
} else {
true
}
};
if self.end == start && not_limited_by_max_read_size {
self.end = end;
self.blobs_at
.append(start, meta)
.expect("LSNs are ordered within vectored reads");
return VectoredReadExtended::Yes;
}
VectoredReadExtended::No
}
pub(crate) fn size(&self) -> usize {
(self.end - self.start) as usize
}
pub(crate) fn build(self) -> VectoredRead {
VectoredRead {
start: self.start,
end: self.end,
blobs_at: self.blobs_at,
}
}
}
/// A vectored read builder that tries to coalesce all reads that fits in a chunk.
pub(crate) struct ChunkedVectoredReadBuilder {
/// Start block number
start_blk_no: usize,
@@ -373,7 +209,7 @@ impl ChunkedVectoredReadBuilder {
/// Note that by design, this does not check against reading more than `max_read_size` to
/// support reading larger blobs than the configuration value. The builder will be single use
/// however after that.
pub(crate) fn new(
fn new_impl(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
@@ -396,6 +232,25 @@ impl ChunkedVectoredReadBuilder {
}
}
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: usize,
align: usize,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), align)
}
pub(crate) fn new_streaming(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
align: usize,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, None, align)
}
/// Attempts to extend the current read with a new blob if the new blob resides in the same or the immediate next chunk.
///
/// The resulting size also must be below the max read size.
@@ -474,17 +329,17 @@ pub struct VectoredReadPlanner {
max_read_size: usize,
mode: VectoredReadCoalesceMode,
align: usize,
}
impl VectoredReadPlanner {
pub fn new(max_read_size: usize) -> Self {
let mode = VectoredReadCoalesceMode::get();
let align = virtual_file::get_io_buffer_alignment();
Self {
blobs: BTreeMap::new(),
prev: None,
max_read_size,
mode,
align,
}
}
@@ -545,7 +400,7 @@ impl VectoredReadPlanner {
}
pub fn finish(self) -> Vec<VectoredRead> {
let mut current_read_builder: Option<VectoredReadBuilder> = None;
let mut current_read_builder: Option<ChunkedVectoredReadBuilder> = None;
let mut reads = Vec::new();
for (key, blobs_for_key) in self.blobs {
@@ -558,12 +413,12 @@ impl VectoredReadPlanner {
};
if extended == VectoredReadExtended::No {
let next_read_builder = VectoredReadBuilder::new(
let next_read_builder = ChunkedVectoredReadBuilder::new(
start_offset,
end_offset,
BlobMeta { key, lsn },
self.max_read_size,
self.mode,
self.align,
);
let prev_read_builder = current_read_builder.replace(next_read_builder);
@@ -688,7 +543,7 @@ impl<'a> VectoredBlobReader<'a> {
/// `handle` gets called and when the current key would just exceed the read_size and
/// max_cnt constraints.
pub struct StreamingVectoredReadPlanner {
read_builder: Option<VectoredReadBuilder>,
read_builder: Option<ChunkedVectoredReadBuilder>,
// Arguments for previous blob passed into [`StreamingVectoredReadPlanner::handle`]
prev: Option<(Key, Lsn, u64)>,
/// Max read size per batch. This is not a strict limit. If there are [0, 100) and [100, 200), while the `max_read_size` is 150,
@@ -699,21 +554,21 @@ pub struct StreamingVectoredReadPlanner {
/// Size of the current batch
cnt: usize,
mode: VectoredReadCoalesceMode,
align: usize,
}
impl StreamingVectoredReadPlanner {
pub fn new(max_read_size: u64, max_cnt: usize) -> Self {
assert!(max_cnt > 0);
assert!(max_read_size > 0);
let mode = VectoredReadCoalesceMode::get();
let align = virtual_file::get_io_buffer_alignment();
Self {
read_builder: None,
prev: None,
max_cnt,
max_read_size,
cnt: 0,
mode,
align,
}
}
@@ -762,11 +617,11 @@ impl StreamingVectoredReadPlanner {
}
None => {
self.read_builder = {
Some(VectoredReadBuilder::new_streaming(
Some(ChunkedVectoredReadBuilder::new_streaming(
start_offset,
end_offset,
BlobMeta { key, lsn },
self.mode,
self.align,
))
};
}
@@ -1092,7 +947,7 @@ mod tests {
let reserved_bytes = blobs.iter().map(|bl| bl.len()).max().unwrap() * 2 + 16;
let mut buf = BytesMut::with_capacity(reserved_bytes);
let mode = VectoredReadCoalesceMode::get();
let align = virtual_file::get_io_buffer_alignment();
let vectored_blob_reader = VectoredBlobReader::new(&file);
let meta = BlobMeta {
key: Key::MIN,
@@ -1104,7 +959,8 @@ mod tests {
if idx + 1 == offsets.len() {
continue;
}
let read_builder = VectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, mode);
let read_builder =
ChunkedVectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, align);
let read = read_builder.build();
let result = vectored_blob_reader.read_blobs(&read, buf, &ctx).await?;
assert_eq!(result.blobs.len(), 1);

View File

@@ -466,6 +466,7 @@ impl VirtualFile {
&[]
};
utils::crashsafe::overwrite(&final_path, &tmp_path, content)
.maybe_fatal_err("crashsafe_overwrite")
})
.await
.expect("blocking task is never aborted")
@@ -475,7 +476,7 @@ impl VirtualFile {
pub async fn sync_all(&self) -> Result<(), Error> {
with_file!(self, StorageIoOperation::Fsync, |file_guard| {
let (_file_guard, res) = io_engine::get().sync_all(file_guard).await;
res
res.maybe_fatal_err("sync_all")
})
}
@@ -483,7 +484,7 @@ impl VirtualFile {
pub async fn sync_data(&self) -> Result<(), Error> {
with_file!(self, StorageIoOperation::Fsync, |file_guard| {
let (_file_guard, res) = io_engine::get().sync_data(file_guard).await;
res
res.maybe_fatal_err("sync_data")
})
}
@@ -1147,7 +1148,9 @@ pub fn init(num_slots: usize, engine: IoEngineKind, io_buffer_alignment: usize)
panic!("virtual_file::init called twice");
}
if set_io_buffer_alignment(io_buffer_alignment).is_err() {
panic!("IO buffer alignment ({io_buffer_alignment}) is not a power of two");
panic!(
"IO buffer alignment needs to be a power of two and greater than 512, got {io_buffer_alignment}"
);
}
io_engine::init(engine);
crate::metrics::virtual_file_descriptor_cache::SIZE_MAX.set(num_slots as u64);
@@ -1174,14 +1177,16 @@ fn get_open_files() -> &'static OpenFiles {
static IO_BUFFER_ALIGNMENT: AtomicUsize = AtomicUsize::new(DEFAULT_IO_BUFFER_ALIGNMENT);
/// Returns true if `x` is zero or a power of two.
fn is_zero_or_power_of_two(x: usize) -> bool {
(x == 0) || ((x & (x - 1)) == 0)
/// Returns true if the alignment is a power of two and is greater or equal to 512.
fn is_valid_io_buffer_alignment(align: usize) -> bool {
align.is_power_of_two() && align >= 512
}
/// Sets IO buffer alignment requirement. Returns error if the alignment requirement is
/// not a power of two or less than 512 bytes.
#[allow(unused)]
pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
if is_zero_or_power_of_two(align) {
if is_valid_io_buffer_alignment(align) {
IO_BUFFER_ALIGNMENT.store(align, std::sync::atomic::Ordering::Relaxed);
Ok(())
} else {
@@ -1189,19 +1194,19 @@ pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
}
}
/// Gets the io buffer alignment requirement. Returns 0 if there is no requirement specified.
/// Gets the io buffer alignment.
///
/// This function should be used to check the raw config value.
pub(crate) fn get_io_buffer_alignment_raw() -> usize {
/// This function should be used for getting the actual alignment value to use.
pub(crate) fn get_io_buffer_alignment() -> usize {
let align = IO_BUFFER_ALIGNMENT.load(std::sync::atomic::Ordering::Relaxed);
if cfg!(test) {
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT";
if let Some(test_align) = utils::env::var(env_var_name) {
if is_zero_or_power_of_two(test_align) {
if is_valid_io_buffer_alignment(test_align) {
test_align
} else {
panic!("IO buffer alignment ({test_align}) is not a power of two");
panic!("IO buffer alignment needs to be a power of two and greater than 512, got {test_align}");
}
} else {
align
@@ -1211,14 +1216,6 @@ pub(crate) fn get_io_buffer_alignment_raw() -> usize {
}
}
/// Gets the io buffer alignment requirement. Returns 1 if the alignment config is set to zero.
///
/// This function should be used for getting the actual alignment value to use.
pub(crate) fn get_io_buffer_alignment() -> usize {
let align = get_io_buffer_alignment_raw();
align.max(1)
}
#[cfg(test)]
mod tests {
use crate::context::DownloadBehavior;

View File

@@ -1,8 +1,6 @@
# neon extension
comment = 'cloud storage for PostgreSQL'
# TODO: bump default version to 1.5, after we are certain that we don't
# need to rollback the compute image
default_version = '1.4'
default_version = '1.5'
module_pathname = '$libdir/neon'
relocatable = true
trusted = true

View File

@@ -1473,11 +1473,33 @@ walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count,
{
NeonWALReadResult res;
res = NeonWALRead(sk->xlogreader,
buf,
startptr,
count,
walprop_pg_get_timeline_id());
#if PG_MAJORVERSION_NUM >= 17
if (!sk->wp->config->syncSafekeepers)
{
Size rbytes;
rbytes = WALReadFromBuffers(buf, startptr, count,
walprop_pg_get_timeline_id());
startptr += rbytes;
count -= rbytes;
}
#endif
if (count == 0)
{
res = NEON_WALREAD_SUCCESS;
}
else
{
Assert(count > 0);
/* Now read the remaining WAL from the WAL file */
res = NeonWALRead(sk->xlogreader,
buf,
startptr,
count,
walprop_pg_get_timeline_id());
}
if (res == NEON_WALREAD_SUCCESS)
{

View File

@@ -24,6 +24,7 @@ bytes = { workspace = true, features = ["serde"] }
camino.workspace = true
chrono.workspace = true
clap.workspace = true
compute_api.workspace = true
consumption_metrics.workspace = true
dashmap.workspace = true
env_logger.workspace = true
@@ -81,7 +82,6 @@ tokio-postgres-rustls.workspace = true
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio = { workspace = true, features = ["signal"] }
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true

View File

@@ -80,6 +80,14 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn dyn_clone(&self) -> Box<dyn TestBackend>;
}
#[cfg(test)]
impl Clone for Box<dyn TestBackend> {
fn clone(&self) -> Self {
TestBackend::dyn_clone(&**self)
}
}
impl std::fmt::Display for Backend<'_, (), ()> {
@@ -557,7 +565,7 @@ mod tests {
stream::{PqStream, Stream},
};
use super::{auth_quirks, AuthRateLimiter};
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
struct Auth {
ips: Vec<IpPattern>,
@@ -585,6 +593,14 @@ mod tests {
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
_endpoint: crate::EndpointId,
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
unimplemented!()
}
async fn wake_compute(
&self,
_ctx: &RequestMonitoring,
@@ -595,12 +611,15 @@ mod tests {
}
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: false,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {

View File

@@ -1,5 +1,6 @@
use std::{
future::Future,
marker::PhantomData,
sync::Arc,
time::{Duration, SystemTime},
};
@@ -8,11 +9,14 @@ use anyhow::{bail, ensure, Context};
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use serde::{Deserialize, Deserializer};
use serde::{de::Visitor, Deserialize, Deserializer};
use signature::Verifier;
use tokio::time::Instant;
use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName};
use crate::{
context::RequestMonitoring, http::parse_json_body_with_limit, intern::RoleNameInt, EndpointId,
RoleName,
};
// TODO(conrad): make these configurable.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
@@ -27,7 +31,6 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
role_name: RoleName,
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
}
@@ -35,10 +38,11 @@ pub(crate) struct AuthRule {
pub(crate) id: String,
pub(crate) jwks_url: url::Url,
pub(crate) audience: Option<String>,
pub(crate) role_names: Vec<RoleNameInt>,
}
#[derive(Default)]
pub(crate) struct JwkCache {
pub struct JwkCache {
client: reqwest::Client,
map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
@@ -54,18 +58,28 @@ pub(crate) struct JwkCacheEntry {
}
impl JwkCacheEntry {
fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
self.key_sets.values().find_map(|key_set| {
key_set
.find_key(key_id)
.map(|jwk| (jwk, key_set.audience.as_deref()))
})
fn find_jwk_and_audience(
&self,
key_id: &str,
role_name: &RoleName,
) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
self.key_sets
.values()
// make sure our requested role has access to the key set
.filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
// try and find the requested key-id in the key set
.find_map(|key_set| {
key_set
.find_key(key_id)
.map(|jwk| (jwk, key_set.audience.as_deref()))
})
}
}
struct KeySet {
jwks: jose_jwk::JwkSet,
audience: Option<String>,
role_names: Vec<RoleNameInt>,
}
impl KeySet {
@@ -106,7 +120,6 @@ impl JwkCacheEntryLock {
ctx: &RequestMonitoring,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
auth_rules: &F,
) -> anyhow::Result<Arc<JwkCacheEntry>> {
// double check that no one beat us to updating the cache.
@@ -119,11 +132,10 @@ impl JwkCacheEntryLock {
}
}
let rules = auth_rules
.fetch_auth_rules(ctx, endpoint, role_name)
.await?;
let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
let mut key_sets =
ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
// TODO(conrad): run concurrently
// TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
for rule in rules {
@@ -136,14 +148,15 @@ impl JwkCacheEntryLock {
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
Ok(r) => {
let resp: http::Response<reqwest::Body> = r.into();
match parse_json_body_with_limit::<jose_jwk::JwkSet>(
match parse_json_body_with_limit::<jose_jwk::JwkSet, _>(
PhantomData,
resp.into_body(),
MAX_JWK_BODY_SIZE,
)
.await
{
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
tracing::warn!(url=?rule.jwks_url, error=%e, "could not decode JWKs");
}
Ok(jwks) => {
key_sets.insert(
@@ -151,6 +164,7 @@ impl JwkCacheEntryLock {
KeySet {
jwks,
audience: rule.audience,
role_names: rule.role_names,
},
);
}
@@ -173,7 +187,6 @@ impl JwkCacheEntryLock {
ctx: &RequestMonitoring,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
fetch: &F,
) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
let now = Instant::now();
@@ -183,9 +196,7 @@ impl JwkCacheEntryLock {
let Some(cached) = guard else {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
return self
.renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
.await;
return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
};
let last_update = now.duration_since(cached.last_retrieved);
@@ -196,9 +207,7 @@ impl JwkCacheEntryLock {
let permit = self.acquire_permit().await;
// it's been too long since we checked the keys. wait for them to update.
return self
.renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
.await;
return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
}
// every 5 minutes we should spawn a job to eagerly update the token.
@@ -212,7 +221,7 @@ impl JwkCacheEntryLock {
let ctx = ctx.clone();
tokio::spawn(async move {
if let Err(e) = entry
.renew_jwks(permit, &ctx, &client, endpoint, role_name, &fetch)
.renew_jwks(permit, &ctx, &client, endpoint, &fetch)
.await
{
tracing::warn!(error=?e, "could not fetch JWKs in background job");
@@ -232,7 +241,7 @@ impl JwkCacheEntryLock {
jwt: &str,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
role_name: &RoleName,
fetch: &F,
) -> Result<(), anyhow::Error> {
// JWT compact form is defined to be
@@ -254,30 +263,22 @@ impl JwkCacheEntryLock {
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
ensure!(header.typ == "JWT");
let kid = header.key_id.context("missing key id")?;
let mut guard = self
.get_or_update_jwk_cache(ctx, client, endpoint.clone(), role_name.clone(), fetch)
.get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
.await?;
// get the key from the JWKs if possible. If not, wait for the keys to update.
let (jwk, expected_audience) = loop {
match guard.find_jwk_and_audience(kid) {
match guard.find_jwk_and_audience(kid, role_name) {
Some(jwk) => break jwk,
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
guard = self
.renew_jwks(
permit,
ctx,
client,
endpoint.clone(),
role_name.clone(),
fetch,
)
.renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
.await?;
}
_ => {
@@ -296,7 +297,7 @@ impl JwkCacheEntryLock {
verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
}
jose_jwk::Key::Rsa(key) => {
verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?;
verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
}
key => bail!("unsupported key type {key:?}"),
};
@@ -308,23 +309,24 @@ impl JwkCacheEntryLock {
tracing::debug!(?payload, "JWT signature valid with claims");
match (expected_audience, payload.audience) {
// check the audience matches
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
// the audience is expected but is missing
(Some(_), None) => bail!("invalid JWT token audience"),
// we don't care for the audience field
(None, _) => {}
if let Some(aud) = expected_audience {
ensure!(
payload.audience.0.iter().any(|s| s == aud),
"invalid JWT token audience"
);
}
let now = SystemTime::now();
if let Some(exp) = payload.expiration {
ensure!(now < exp + CLOCK_SKEW_LEEWAY);
ensure!(now < exp + CLOCK_SKEW_LEEWAY, "JWT token has expired");
}
if let Some(nbf) = payload.not_before {
ensure!(nbf < now + CLOCK_SKEW_LEEWAY);
ensure!(
nbf < now + CLOCK_SKEW_LEEWAY,
"JWT token is not yet ready to use"
);
}
Ok(())
@@ -336,7 +338,7 @@ impl JwkCache {
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
role_name: RoleName,
role_name: &RoleName,
fetch: &F,
jwt: &str,
) -> Result<(), anyhow::Error> {
@@ -377,7 +379,7 @@ fn verify_rsa_signature(
data: &[u8],
sig: &[u8],
key: &jose_jwk::Rsa,
alg: &Option<jose_jwa::Algorithm>,
alg: &jose_jwa::Algorithm,
) -> anyhow::Result<()> {
use jose_jwa::{Algorithm, Signing};
use rsa::{
@@ -388,7 +390,7 @@ fn verify_rsa_signature(
let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?;
match alg {
Some(Algorithm::Signing(Signing::Rs256)) => {
Algorithm::Signing(Signing::Rs256) => {
let key = VerifyingKey::<sha2::Sha256>::new(key);
let sig = Signature::try_from(sig)?;
key.verify(data, &sig)?;
@@ -402,9 +404,6 @@ fn verify_rsa_signature(
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
/// must be "JWT"
#[serde(rename = "typ")]
typ: &'a str,
/// must be a supported alg
#[serde(rename = "alg")]
algorithm: jose_jwa::Algorithm,
@@ -414,11 +413,12 @@ struct JwtHeader<'a> {
}
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
#[derive(serde::Deserialize, serde::Serialize, Debug)]
#[derive(serde::Deserialize, Debug)]
#[allow(dead_code)]
struct JwtPayload<'a> {
/// Audience - Recipient for which the JWT is intended
#[serde(rename = "aud")]
audience: Option<&'a str>,
#[serde(rename = "aud", default)]
audience: OneOrMany,
/// Expiration - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
expiration: Option<SystemTime>,
@@ -441,6 +441,59 @@ struct JwtPayload<'a> {
session_id: Option<&'a str>,
}
/// `OneOrMany` supports parsing either a single item or an array of items.
///
/// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
///
/// > The "aud" (audience) claim identifies the recipients that the JWT is
/// > intended for. Each principal intended to process the JWT MUST
/// > identify itself with a value in the audience claim. If the principal
/// > processing the claim does not identify itself with a value in the
/// > "aud" claim when this claim is present, then the JWT MUST be
/// > rejected. In the general case, the "aud" value is **an array of case-
/// > sensitive strings**, each containing a StringOrURI value. In the
/// > special case when the JWT has one audience, the "aud" value MAY be a
/// > **single case-sensitive string** containing a StringOrURI value. The
/// > interpretation of audience values is generally application specific.
/// > Use of this claim is OPTIONAL.
#[derive(Default, Debug)]
struct OneOrMany(Vec<String>);
impl<'de> Deserialize<'de> for OneOrMany {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct OneOrManyVisitor;
impl<'de> Visitor<'de> for OneOrManyVisitor {
type Value = OneOrMany;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a single string or an array of strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(OneOrMany(vec![v.to_owned()]))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut v = vec![];
while let Some(s) = seq.next_element()? {
v.push(s);
}
Ok(OneOrMany(v))
}
}
deserializer.deserialize_any(OneOrManyVisitor)
}
}
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
let d = <Option<u64>>::deserialize(d)?;
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
@@ -534,7 +587,6 @@ mod tests {
key: jose_jwk::Key::Ec(pk),
prm: jose_jwk::Parameters {
kid: Some(kid),
alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
..Default::default()
},
};
@@ -548,7 +600,6 @@ mod tests {
key: jose_jwk::Key::Rsa(pk),
prm: jose_jwk::Parameters {
kid: Some(kid),
alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
..Default::default()
},
};
@@ -557,7 +608,6 @@ mod tests {
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
let header = JwtHeader {
typ: "JWT",
algorithm: jose_jwa::Algorithm::Signing(sig),
key_id: Some(&kid),
};
@@ -572,7 +622,7 @@ mod tests {
format!("{header}.{body}")
}
fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
use p256::ecdsa::{Signature, SigningKey};
let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
@@ -660,11 +710,6 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
let (ec1, jwk3) = new_ec_jwk("3".into());
let (ec2, jwk4) = new_ec_jwk("4".into());
let jwt1 = new_rsa_jwt("1".into(), rs1);
let jwt2 = new_rsa_jwt("2".into(), rs2);
let jwt3 = new_ec_jwt("3".into(), ec1);
let jwt4 = new_ec_jwt("4".into(), ec2);
let foo_jwks = jose_jwk::JwkSet {
keys: vec![jwk1, jwk3],
};
@@ -706,47 +751,98 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
let client = reqwest::Client::new();
#[derive(Clone)]
struct Fetch(SocketAddr);
struct Fetch(SocketAddr, Vec<RoleNameInt>);
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
_role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
Ok(vec![
AuthRule {
id: "foo".to_owned(),
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
audience: None,
role_names: self.1.clone(),
},
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
audience: None,
role_names: self.1.clone(),
},
])
}
}
let role_name = RoleName::from("user");
let role_name1 = RoleName::from("anonymous");
let role_name2 = RoleName::from("authenticated");
let fetch = Fetch(
addr,
vec![
RoleNameInt::from(&role_name1),
RoleNameInt::from(&role_name2),
],
);
let endpoint = EndpointId::from("ep");
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
for token in [jwt1, jwt2, jwt3, jwt4] {
jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&token,
&client,
endpoint.clone(),
role_name.clone(),
&Fetch(addr),
)
.await
.unwrap();
let jwt1 = new_rsa_jwt("1".into(), rs1);
let jwt2 = new_rsa_jwt("2".into(), rs2);
let jwt3 = new_ec_jwt("3".into(), &ec1);
let jwt4 = new_ec_jwt("4".into(), &ec2);
// had the wrong kid, therefore will have the wrong ecdsa signature
let bad_jwt = new_ec_jwt("3".into(), &ec2);
// this role_name is not accepted
let bad_role_name = RoleName::from("cloud_admin");
let err = jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&bad_jwt,
&client,
endpoint.clone(),
&role_name1,
&fetch,
)
.await
.unwrap_err();
assert!(err.to_string().contains("signature error"));
let err = jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&jwt1,
&client,
endpoint.clone(),
&bad_role_name,
&fetch,
)
.await
.unwrap_err();
assert!(err.to_string().contains("jwk not found"));
let tokens = [jwt1, jwt2, jwt3, jwt4];
let role_names = [role_name1, role_name2];
for role in &role_names {
for token in &tokens {
jwk_cache
.check_jwt(
&RequestMonitoring::test(),
token,
&client,
endpoint.clone(),
role,
&fetch,
)
.await
.unwrap();
}
}
}
}

View File

@@ -1,4 +1,4 @@
use std::{collections::HashMap, net::SocketAddr};
use std::net::SocketAddr;
use anyhow::Context;
use arc_swap::ArcSwapOption;
@@ -10,21 +10,19 @@ use crate::{
NodeInfo,
},
context::RequestMonitoring,
intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag},
EndpointId, RoleName,
intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag},
EndpointId,
};
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
use super::jwt::{AuthRule, FetchAuthRules};
pub struct LocalBackend {
pub(crate) jwks_cache: JwkCache,
pub(crate) node_info: NodeInfo,
}
impl LocalBackend {
pub fn new(postgres_addr: SocketAddr) -> Self {
LocalBackend {
jwks_cache: JwkCache::default(),
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();
@@ -48,26 +46,17 @@ impl LocalBackend {
#[derive(Clone, Copy)]
pub(crate) struct StaticAuthRules;
pub static JWKS_ROLE_MAP: ArcSwapOption<JwksRoleSettings> = ArcSwapOption::const_empty();
#[derive(Debug, Clone)]
pub struct JwksRoleSettings {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
}
pub static JWKS_ROLE_MAP: ArcSwapOption<EndpointJwksResponse> = ArcSwapOption::const_empty();
impl FetchAuthRules for StaticAuthRules {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
let mappings = JWKS_ROLE_MAP.load();
let role_mappings = mappings
.as_deref()
.and_then(|m| m.roles.get(&role_name))
.context("JWKs settings for this role were not configured")?;
let mut rules = vec![];
for setting in &role_mappings.jwks {
@@ -75,6 +64,7 @@ impl FetchAuthRules for StaticAuthRules {
id: setting.id.clone(),
jwks_url: setting.jwks_url.clone(),
audience: setting.jwt_audience.clone(),
role_names: setting.role_names.clone(),
});
}

View File

@@ -1,34 +1,38 @@
use std::{
net::SocketAddr,
path::{Path, PathBuf},
pin::pin,
sync::Arc,
time::Duration,
};
use std::{net::SocketAddr, pin::pin, str::FromStr, sync::Arc, time::Duration};
use anyhow::{bail, ensure};
use anyhow::{bail, ensure, Context};
use camino::{Utf8Path, Utf8PathBuf};
use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::{future::Either, FutureExt};
use futures::future::Either;
use proxy::{
auth::backend::local::{JwksRoleSettings, LocalBackend, JWKS_ROLE_MAP},
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
console::{locks::ApiLocks, messages::JwksRoleMapping},
console::{
locks::ApiLocks,
messages::{EndpointJwksResponse, JwksSettings},
},
http::health_server::AppMetrics,
intern::RoleNameInt,
metrics::{Metrics, ThreadPoolMetrics},
rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo},
scram::threadpool::ThreadPool,
serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions},
RoleName,
};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use clap::Parser;
use tokio::{net::TcpListener, task::JoinSet};
use tokio::{net::TcpListener, sync::Notify, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
use utils::{pid_file, project_build_tag, project_git_version, sentry_init::init_sentry};
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
@@ -72,9 +76,12 @@ struct LocalProxyCliArgs {
/// Address of the postgres server
#[clap(long, default_value = "127.0.0.1:5432")]
compute: SocketAddr,
/// File address of the local proxy config file
/// Path of the local proxy config file
#[clap(long, default_value = "./localproxy.json")]
config_path: PathBuf,
config_path: Utf8PathBuf,
/// Path of the local proxy PID file
#[clap(long, default_value = "./localproxy.pid")]
pid_path: Utf8PathBuf,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -126,6 +133,24 @@ async fn main() -> anyhow::Result<()> {
let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
// before we bind to any ports, write the process ID to a file
// so that compute-ctl can find our process later
// in order to trigger the appropriate SIGHUP on config change.
//
// This also claims a "lock" that makes sure only one instance
// of local-proxy runs at a time.
let _process_guard = loop {
match pid_file::claim_for_current_process(&args.pid_path) {
Ok(guard) => break guard,
Err(e) => {
// compute-ctl might have tried to read the pid-file to let us
// know about some config change. We should try again.
error!(path=?args.pid_path, "could not claim PID file guard: {e:?}");
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
};
let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
let http_listener = TcpListener::bind(args.http).await?;
let shutdown = CancellationToken::new();
@@ -139,12 +164,30 @@ async fn main() -> anyhow::Result<()> {
16,
));
refresh_config(args.config_path.clone()).await;
// write the process ID to a file so that compute-ctl can find our process later
// in order to trigger the appropriate SIGHUP on config change.
let pid = std::process::id();
info!("process running in PID {pid}");
std::fs::write(args.pid_path, format!("{pid}\n")).context("writing PID to file")?;
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), move || {
refresh_config(args.config_path.clone()).map(Ok)
let refresh_config_notify = Arc::new(Notify::new());
maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), {
let refresh_config_notify = Arc::clone(&refresh_config_notify);
move || {
refresh_config_notify.notify_one();
}
}));
// trigger the first config load **after** setting up the signal hook
// to avoid the race condition where:
// 1. No config file registered when local-proxy starts up
// 2. The config file is written but the signal hook is not yet received
// 3. local-proxy completes startup but has no config loaded, despite there being a registerd config.
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(args.config_path, refresh_config_notify));
maintenance_tasks.spawn(proxy::http::health_server::task_main(
metrics_listener,
AppMetrics {
@@ -227,14 +270,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
allow_self_signed_compute: false,
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: true,
},
require_client_ip: false,
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
@@ -245,81 +291,84 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
})))
}
async fn refresh_config(path: PathBuf) {
match refresh_config_inner(&path).await {
Ok(()) => {}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
loop {
rx.notified().await;
match refresh_config_inner(&path).await {
Ok(()) => {}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
}
}
async fn refresh_config_inner(path: &Path) -> anyhow::Result<()> {
async fn refresh_config_inner(path: &Utf8Path) -> anyhow::Result<()> {
let bytes = tokio::fs::read(&path).await?;
let mut data: JwksRoleMapping = serde_json::from_slice(&bytes)?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut settings = None;
let mut jwks_set = vec![];
for mapping in data.roles.values_mut() {
for jwks in &mut mapping.jwks {
ensure!(
jwks.jwks_url.has_authority()
&& (jwks.jwks_url.scheme() == "http" || jwks.jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
for jwks in data.jwks {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks.jwks_url
.host()
.is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
// clear username, password and ports
jwks.jwks_url.set_username("").expect(
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
jwks.jwks_url.set_password(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks.jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks.jwks_url.set_fragment(None);
jwks.jwks_url.query_pairs_mut().clear().finish();
if jwks.jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks.jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks.jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
let (pr, br) = settings.get_or_insert((jwks.project_id, jwks.branch_id));
ensure!(
*pr == jwks.project_id,
"inconsistent project IDs configured"
);
ensure!(*br == jwks.branch_id, "inconsistent branch IDs configured");
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
jwks_set.push(JwksSettings {
id: jwks.id,
jwks_url,
provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
if let Some((project_id, branch_id)) = settings {
JWKS_ROLE_MAP.store(Some(Arc::new(JwksRoleSettings {
roles: data.roles,
project_id,
branch_id,
})));
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
Ok(())
}

View File

@@ -133,9 +133,7 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || async {
Ok(())
}));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || {}));
// the signal task cant ever succeed.
// the main task can error, or can succeed on cancellation.

View File

@@ -8,6 +8,7 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
@@ -17,6 +18,7 @@ use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
use proxy::config::ProjectInfoCacheOptions;
use proxy::config::ProxyProtocolV2;
use proxy::console;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http;
@@ -102,6 +104,9 @@ struct ProxyCliArgs {
default_value = "http://localhost:3000/authenticate_proxy_request/"
)]
auth_endpoint: String,
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_auth_broker: bool,
/// path to TLS key for client postgres connections
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
@@ -144,9 +149,6 @@ struct ProxyCliArgs {
/// size of the threadpool for password hashing
#[clap(long, default_value_t = 4)]
scram_thread_pool_size: u8,
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
require_client_ip: bool,
/// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_dynamic_rate_limiter: bool,
@@ -229,6 +231,11 @@ struct ProxyCliArgs {
/// Configure if this is a private access proxy for the POC: In that case the proxy will ignore the IP allowlist
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_private_access_proxy: bool,
/// Configure whether all incoming requests have a Proxy Protocol V2 packet.
// TODO(conradludgate): switch default to rejected or required once we've updated all deployments
#[clap(value_enum, long, default_value_t = ProxyProtocolV2::Supported)]
proxy_protocol_v2: ProxyProtocolV2,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -290,6 +297,7 @@ async fn main() -> anyhow::Result<()> {
build_tag: BUILD_TAG,
});
proxy::jemalloc::inspect_thp()?;
let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
Ok(t) => Some(t),
Err(e) => {
@@ -382,9 +390,27 @@ async fn main() -> anyhow::Result<()> {
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
Some(TcpListener::bind(proxy_address).await?)
} else {
None
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
let cancellation_token = CancellationToken::new();
let cancel_map = CancelMap::default();
@@ -430,21 +456,17 @@ async fn main() -> anyhow::Result<()> {
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
let serverless_listener = TcpListener::bind(serverless_address).await?;
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
serverless_listener,
@@ -461,10 +483,7 @@ async fn main() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(proxy::handle_signals(
cancellation_token.clone(),
|| async { Ok(()) },
));
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -677,7 +696,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?;
let http_config = HttpConfig {
accept_websockets: true,
accept_websockets: !args.is_auth_broker,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
@@ -692,12 +711,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
};
let config = Box::leak(Box::new(ProxyConfig {
@@ -707,7 +729,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
authentication_config,
require_client_ip: args.require_client_ip,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,

View File

@@ -1,5 +1,8 @@
use crate::{
auth::{self, backend::AuthRateLimiter},
auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
console::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
@@ -7,6 +10,7 @@ use crate::{
Host,
};
use anyhow::{bail, ensure, Context, Ok};
use clap::ValueEnum;
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::{
@@ -30,7 +34,7 @@ pub struct ProxyConfig {
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub require_client_ip: bool,
pub proxy_protocol_v2: ProxyProtocolV2,
pub region: String,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
@@ -38,6 +42,16 @@ pub struct ProxyConfig {
pub connect_to_compute_retry_config: RetryConfig,
}
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
pub enum ProxyProtocolV2 {
/// Connection will error if PROXY protocol v2 header is missing
Required,
/// Connection will parse PROXY protocol v2 header, but accept the connection if it's missing.
Supported,
/// Connection will error if PROXY protocol v2 header is provided
Rejected,
}
#[derive(Debug)]
pub struct MetricCollectionConfig {
pub endpoint: reqwest::Url,
@@ -67,6 +81,9 @@ pub struct AuthenticationConfig {
pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool,
pub jwks_cache: JwkCache,
pub is_auth_broker: bool,
pub accept_jwts: bool,
}
impl TlsConfig {
@@ -250,18 +267,26 @@ impl CertResolver {
let common_name = pem.subject().to_string();
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
let common_name = if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
.context("Failed to parse common name from certificate")?;
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

View File

@@ -1,13 +1,11 @@
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{self, Display};
use crate::auth::IpPattern;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::proxy::retry::CouldRetry;
use crate::RoleName;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
@@ -348,11 +346,6 @@ impl ColdStartInfo {
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct JwksRoleMapping {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct EndpointJwksResponse {
pub jwks: Vec<JwksSettings>,
@@ -361,11 +354,10 @@ pub struct EndpointJwksResponse {
#[derive(Debug, Deserialize, Clone)]
pub struct JwksSettings {
pub id: String,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
pub jwks_url: url::Url,
pub provider_name: String,
pub jwt_audience: Option<String>,
pub role_names: Vec<RoleNameInt>,
}
#[cfg(test)]

View File

@@ -5,7 +5,10 @@ pub mod neon;
use super::messages::{ConsoleError, MetricsAuxInfo};
use crate::{
auth::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
backend::{
jwt::{AuthRule, FetchAuthRules},
ComputeCredentialKeys, ComputeUserInfo,
},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
@@ -16,7 +19,7 @@ use crate::{
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey,
scram, EndpointCacheKey, EndpointId,
};
use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration};
@@ -334,6 +337,12 @@ pub(crate) trait Api {
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
@@ -343,6 +352,7 @@ pub(crate) trait Api {
}
#[non_exhaustive]
#[derive(Clone)]
pub enum ConsoleBackend {
/// Current Cloud API (V2).
Console(neon::Api),
@@ -386,6 +396,20 @@ impl Api for ConsoleBackend {
}
}
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
match self {
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(any(test, feature = "testing"))]
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(test)]
Self::Test(_api) => Ok(vec![]),
}
}
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -552,3 +576,13 @@ impl WakeComputePermit {
res
}
}
impl FetchAuthRules for ConsoleBackend {
async fn fetch_auth_rules(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.get_endpoint_jwks(ctx, endpoint).await
}
}

View File

@@ -4,7 +4,9 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::context::RequestMonitoring;
use crate::{
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
@@ -118,6 +120,39 @@ impl Api {
})
}
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
let connection = tokio::spawn(connection);
let res = client.query(
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
&[&endpoint.as_str()],
)
.await?;
let mut rows = vec![];
for row in res {
rows.push(AuthRule {
id: row.get("id"),
jwks_url: url::Url::parse(row.get("jwks_url"))?,
audience: row.get("audience"),
role_names: row
.get::<_, Vec<String>>("role_names")
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
});
}
drop(client);
connection.await??;
Ok(rows)
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
@@ -185,6 +220,14 @@ impl super::Api for Api {
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -7,27 +7,33 @@ use super::{
NodeInfo,
};
use crate::{
auth::backend::ComputeUserInfo,
auth::backend::{jwt::AuthRule, ComputeUserInfo},
compute,
console::messages::{ColdStartInfo, Reason},
console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey,
scram, EndpointCacheKey, EndpointId,
};
use crate::{cache::Cached, context::RequestMonitoring};
use ::http::{header::AUTHORIZATION, HeaderName};
use anyhow::bail;
use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, error, info, info_span, warn, Instrument};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
jwt: String,
// put in a shared ref so we don't copy secrets all over in memory
jwt: Arc<str>,
}
impl Api {
@@ -38,7 +44,9 @@ impl Api {
locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
) -> Self {
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
.unwrap_or_default()
.into();
Self {
endpoint,
caches,
@@ -71,9 +79,9 @@ impl Api {
async {
let request = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.get_path("proxy_get_role_secret")
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
@@ -125,6 +133,61 @@ impl Api {
.await
}
async fn do_get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &endpoint.normalize())
.await
{
bail!("endpoint not found");
}
let request_id = ctx.session_id().to_string();
async {
let request = self
.endpoint
.get_with_url(|url| {
url.path_segments_mut()
.push("endpoints")
.push(endpoint.as_str())
.push("jwks");
})
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(response).await?;
let rules = body
.jwks
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();
Ok(rules)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
async fn do_wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -135,7 +198,7 @@ impl Api {
async {
let mut request_builder = self
.endpoint
.get("proxy_wake_compute")
.get_path("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
@@ -262,6 +325,15 @@ impl super::Api for Api {
))
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -6,11 +6,10 @@ pub mod health_server;
use std::time::Duration;
use anyhow::bail;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper1::body::Body;
use serde::de::DeserializeOwned;
use serde::de::DeserializeSeed;
pub(crate) use reqwest::{Request, Response};
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
@@ -86,9 +85,17 @@ impl Endpoint {
/// Return a [builder](RequestBuilder) for a `GET` request,
/// appending a single `path` segment to the base endpoint URL.
pub(crate) fn get(&self, path: &str) -> RequestBuilder {
pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
self.get_with_url(|u| {
u.path_segments_mut().push(path);
})
}
/// Return a [builder](RequestBuilder) for a `GET` request,
/// accepting a closure to modify the url path segments for more complex paths queries.
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
let mut url = self.endpoint.clone();
url.path_segments_mut().push(path);
f(&mut url);
self.client.get(url.into_inner())
}
@@ -105,10 +112,21 @@ impl Endpoint {
}
}
pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError<E> {
#[error("could not read the HTTP body: {0}")]
Read(E),
#[error("could not parse the HTTP body: {0}")]
Parse(#[from] serde_json::Error),
#[error("could not parse the HTTP body: content length exceeds limit of {0} bytes")]
LengthExceeded(usize),
}
pub(crate) async fn parse_json_body_with_limit<D, E>(
seed: impl for<'de> DeserializeSeed<'de, Value = D>,
mut b: impl Body<Data = Bytes, Error = E> + Unpin,
limit: usize,
) -> anyhow::Result<D> {
) -> Result<D, ReadPayloadError<E>> {
// We could use `b.limited().collect().await.to_bytes()` here
// but this ends up being slightly more efficient as far as I can tell.
@@ -116,20 +134,25 @@ pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
// in reqwest, this value is influenced by the Content-Length header.
let lower_bound = match usize::try_from(b.size_hint().lower()) {
Ok(bound) if bound <= limit => bound,
_ => bail!("Content length exceeds limit of {limit} bytes"),
_ => return Err(ReadPayloadError::LengthExceeded(limit)),
};
let mut bytes = Vec::with_capacity(lower_bound);
while let Some(frame) = b.frame().await.transpose()? {
while let Some(frame) = b
.frame()
.await
.transpose()
.map_err(ReadPayloadError::Read)?
{
if let Ok(data) = frame.into_data() {
if bytes.len() + data.len() > limit {
bail!("Content length exceeds limit of {limit} bytes")
return Err(ReadPayloadError::LengthExceeded(limit));
}
bytes.extend_from_slice(&data);
}
}
Ok(serde_json::from_slice::<D>(&bytes)?)
Ok(seed.deserialize(&mut serde_json::Deserializer::from_slice(&bytes))?)
}
#[cfg(test)]
@@ -144,7 +167,7 @@ mod tests {
// Validate that this pattern makes sense.
let req = endpoint
.get("frobnicate")
.get_path("frobnicate")
.query(&[
("foo", Some("10")), // should be just `foo=10`
("bar", None), // shouldn't be passed at all
@@ -162,7 +185,7 @@ mod tests {
let endpoint = Endpoint::new(url, Client::new());
let req = endpoint
.get("frobnicate")
.get_path("frobnicate")
.query(&[("session_id", uuid::Uuid::nil())])
.build()?;

View File

@@ -130,14 +130,14 @@ impl<Id: InternId> Default for StringInterner<Id> {
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct RoleNameTag;
pub struct RoleNameTag;
impl InternId for RoleNameTag {
fn get_interner() -> &'static StringInterner<Self> {
static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
ROLE_NAMES.get_or_init(Default::default)
}
}
pub(crate) type RoleNameInt = InternedString<RoleNameTag>;
pub type RoleNameInt = InternedString<RoleNameTag>;
impl From<&RoleName> for RoleNameInt {
fn from(value: &RoleName) -> Self {
RoleNameTag::get_interner().get_or_intern(value)

View File

@@ -9,7 +9,8 @@ use measured::{
text::TextEncoder,
LabelGroup, MetricGroup,
};
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version};
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version, Access, AsName, Name};
use tracing::info;
pub struct MetricRecorder {
epoch: epoch_mib,
@@ -114,3 +115,10 @@ jemalloc_gauge!(mapped, mapped_mib);
jemalloc_gauge!(metadata, metadata_mib);
jemalloc_gauge!(resident, resident_mib);
jemalloc_gauge!(retained, retained_mib);
pub fn inspect_thp() -> Result<(), tikv_jemalloc_ctl::Error> {
let opt_thp: &Name = c"opt.thp".to_bytes_with_nul().name();
let s: &str = opt_thp.read()?;
info!("jemalloc opt.thp {s}");
Ok(())
}

View File

@@ -82,7 +82,7 @@
impl_trait_overcaptures,
)]
use std::{convert::Infallible, future::Future};
use std::convert::Infallible;
use anyhow::{bail, Context};
use intern::{EndpointIdInt, EndpointIdTag, InternId};
@@ -117,13 +117,12 @@ pub mod usage_metrics;
pub mod waiters;
/// Handle unix signals appropriately.
pub async fn handle_signals<F, Fut>(
pub async fn handle_signals<F>(
token: CancellationToken,
mut refresh_config: F,
) -> anyhow::Result<Infallible>
where
F: FnMut() -> Fut,
Fut: Future<Output = anyhow::Result<()>>,
F: FnMut(),
{
use tokio::signal::unix::{signal, SignalKind};
@@ -136,7 +135,7 @@ where
// Hangup is commonly used for config reload.
_ = hangup.recv() => {
warn!("received SIGHUP");
refresh_config().await?;
refresh_config();
}
// Shut down the whole application.
_ = interrupt.recv() => {

View File

@@ -1,4 +1,3 @@
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::{
filter::{EnvFilter, LevelFilter},
prelude::*,
@@ -23,9 +22,7 @@ pub async fn init() -> anyhow::Result<LoggingGuard> {
.with_writer(std::io::stderr)
.with_target(false);
let otlp_layer = tracing_utils::init_tracing("proxy")
.await
.map(OpenTelemetryLayer::new);
let otlp_layer = tracing_utils::init_tracing("proxy").await;
tracing_subscriber::registry()
.with(env_filter)

View File

@@ -10,6 +10,7 @@ pub(crate) mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use crate::config::ProxyProtocolV2;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
@@ -93,15 +94,19 @@ pub async fn task_main(
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Ok((socket, Some(addr))) => (socket, addr.ip()),
Err(e) => {
error!("per-client task finished with an error: {e:#}");
return;
}
Ok((_socket, None)) if config.require_client_ip => {
error!("missing required client IP");
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
error!("missing required proxy protocol header");
return;
}
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
error!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
Ok((socket, None)) => (socket, peer_addr.ip()),
};

View File

@@ -525,6 +525,10 @@ impl TestBackend for TestConnectMechanism {
{
unimplemented!("not used in tests")
}
fn dyn_clone(&self) -> Box<dyn TestBackend> {
Box::new(self.clone())
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {

View File

@@ -43,6 +43,13 @@ impl ThreadPool {
pub fn new(n_workers: u8) -> Arc<Self> {
// rayon would be nice here, but yielding in rayon does not work well afaict.
if n_workers == 0 {
return Arc::new(Self {
runtime: None,
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
});
}
Arc::new_cyclic(|pool| {
let pool = pool.clone();
let worker_id = AtomicUsize::new(0);

View File

@@ -5,6 +5,7 @@
mod backend;
pub mod cancel_set;
mod conn_pool;
mod http_conn_pool;
mod http_util;
mod json;
mod sql_over_http;
@@ -19,7 +20,8 @@ use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::Full;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
@@ -81,7 +83,28 @@ pub async fn task_main(
}
});
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
{
let http_conn_pool = Arc::clone(&http_conn_pool);
tokio::spawn(async move {
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
});
}
// shutdown the connection pool
tokio::spawn({
let cancellation_token = cancellation_token.clone();
let http_conn_pool = http_conn_pool.clone();
async move {
cancellation_token.cancelled().await;
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
.await
.unwrap();
}
});
let backend = Arc::new(PoolingBackend {
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
@@ -342,7 +365,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let host = request
.headers()
.get("host")
@@ -386,7 +409,7 @@ async fn request_handler(
);
// Return the response so the spawned future can continue.
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
@@ -409,7 +432,7 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Full::new(Bytes::new()))
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -1,6 +1,8 @@
use std::{sync::Arc, time::Duration};
use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info};
use crate::{
@@ -27,9 +29,13 @@ use crate::{
Host,
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -103,32 +109,44 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: &str,
) -> Result<ComputeCredentials, AuthError> {
jwt: String,
) -> Result<(), AuthError> {
match &self.config.auth_backend {
crate::auth::Backend::Console(_, ()) => {
Err(AuthError::auth_failed("JWT login is not yet supported"))
}
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(cache) => {
cache
crate::auth::Backend::Console(console, ()) => {
config
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
user_info.user.clone(),
&StaticAuthRules,
jwt,
&user_info.user,
&**console,
&jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
Ok(())
}
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
config
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
&user_info.user,
&StaticAuthRules,
&jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
Ok(())
}
}
}
@@ -174,14 +192,55 @@ impl PoolingBackend {
)
.await
}
// Wake up the destination if needed
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to compute")]
ConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to postgres in compute")]
PostgresConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to local-proxy in compute")]
LocalProxyConnectionError(#[from] LocalProxyConnError),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -193,11 +252,20 @@ pub(crate) enum HttpConnError {
TooManyConnectionAttempts(#[from] ApiLockError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper1::Error),
}
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::ConnectionError(p) => p.get_error_kind(),
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -210,7 +278,8 @@ impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::ConnectionError(p) => p.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -224,7 +293,8 @@ impl UserFacingError for HttpConnError {
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::ConnectionError(e) => e.could_retry(),
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
@@ -236,7 +306,7 @@ impl CouldRetry for HttpConnError {
impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true,
@@ -244,6 +314,38 @@ impl ShouldRetryWakeCompute for HttpConnError {
}
}
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
}
impl UserFacingError for LocalProxyConnError {
fn to_string_client(&self) -> String {
"Could not establish HTTP connection to the database".to_string()
}
}
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo,
@@ -293,3 +395,99 @@ impl ConnectMechanism for TokioMechanism {
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
struct HyperMechanism {
pool: Arc<http_conn_pool::GlobalConnPool>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestMonitoring,
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let permit = self.locks.get_permit(&host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
let res = connect_http2(&host, 10432, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(
host: &str,
port: u16,
timeout: Duration,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
// assumption: host is an ip address so this should not actually perform any requests.
// todo: add that assumption as a guarantee in the control-plane API.
let mut addrs = lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?;
let mut last_err = None;
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
};
};
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;
Ok((client, connection))
}

View File

@@ -0,0 +1,342 @@
use dashmap::DashMap;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::{sync::Arc, sync::Weak};
use tokio::net::TcpStream;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use tracing::{debug, error};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
#[derive(Clone)]
struct ConnPoolEntry {
conn: Send,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool {
// TODO(conrad):
// either we should open more connections depending on stream count
// (not exposed by hyper, need our own counter)
// or we can change this to an Option rather than a VecDeque.
//
// Opening more connections to the same db because we run out of streams
// seems somewhat redundant though.
//
// Probably we should run a semaphore and just the single conn. TBD.
conns: VecDeque<ConnPoolEntry>,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
}
impl EndpointConnPool {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
let Self { conns, .. } = self;
loop {
let conn = conns.pop_front()?;
if !conn.conn.is_closed() {
conns.push_back(conn.clone());
return Some(conn);
}
}
}
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
let Self {
conns,
global_connections_count,
..
} = self;
let old_len = conns.len();
conns.retain(|conn| conn.conn_id != conn_id);
let new_len = conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
removed > 0
}
}
impl Drop for EndpointConnPool {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.conns.len() as i64);
}
}
}
pub(crate) struct GlobalConnPool {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly conteded map, so return reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
impl GlobalConnPool {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool { conns, .. } = pool.get_mut();
let old_len = conns.len();
conns.retain(|conn| !conn.conn.is_closed());
let new_len = conns.len();
let removed = old_len - new_len;
clients_removed += removed;
// we only remove this pool if it has no active connections
if conns.is_empty() {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Option<Client> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let client = endpoint_pool.write().get_conn_entry()?;
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Some(Client::new(client.conn, client.aux))
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
pool.write().conns.push_back(ConnPoolEntry {
conn: client.clone(),
conn_id,
aux: aux.clone(),
});
Arc::downgrade(&pool)
}
None => Weak::new(),
};
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!(%session_id, "connection error: {}", e),
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),
);
Client::new(client, aux)
}
pub(crate) struct Client {
pub(crate) inner: Send,
aux: MetricsAuxInfo,
}
impl Client {
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
USAGE_METRICS.register(Ids {
endpoint_id: self.aux.endpoint_id,
branch_id: self.aux.branch_id,
})
}
}

View File

@@ -5,13 +5,13 @@ use bytes::Bytes;
use anyhow::Context;
use http::{Response, StatusCode};
use http_body_util::Full;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
@@ -64,17 +64,24 @@ struct HttpErrorBody {
impl HttpErrorBody {
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
fn response_from_msg_and_status(
msg: String,
status: StatusCode,
) -> Response<BoxBody<Bytes, hyper1::Error>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.body(
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
.map_err(|x| match x {})
.boxed(),
)
.unwrap()
}
}
@@ -83,14 +90,14 @@ impl HttpErrorBody {
pub(crate) fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)))
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}

View File

@@ -1,18 +1,534 @@
use std::fmt;
use std::marker::PhantomData;
use std::ops::Range;
use itertools::Itertools;
use serde::de;
use serde::de::DeserializeSeed;
use serde::Deserialize;
use serde::Deserializer;
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::Row;
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter().map(json_value_to_pg_text).collect()
use super::sql_over_http::BatchQueryData;
use super::sql_over_http::Payload;
use super::sql_over_http::QueryData;
#[derive(Clone, Copy)]
pub struct Slice {
pub start: u32,
pub len: u32,
}
fn json_value_to_pg_text(value: &Value) -> Option<String> {
impl Slice {
pub fn into_range(self) -> Range<usize> {
let start = self.start as usize;
let end = start + self.len as usize;
start..end
}
}
#[derive(Default)]
pub struct Arena {
pub str_arena: String,
pub params_arena: Vec<Option<Slice>>,
}
impl Arena {
fn alloc_str(&mut self, s: &str) -> Slice {
let start = self.str_arena.len() as u32;
let len = s.len() as u32;
self.str_arena.push_str(s);
Slice { start, len }
}
}
pub struct SerdeArena<'a, T> {
pub arena: &'a mut Arena,
pub _t: PhantomData<T>,
}
impl<'a, T> SerdeArena<'a, T> {
fn alloc_str(&mut self, s: &str) -> Slice {
self.arena.alloc_str(s)
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Vec<QueryData>> {
type Value = Vec<QueryData>;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct VecVisitor<'a>(SerdeArena<'a, Vec<QueryData>>);
impl<'a, 'de> de::Visitor<'de> for VecVisitor<'a> {
type Value = Vec<QueryData>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: de::SeqAccess<'de>,
{
let mut values = Vec::new();
while let Some(value) = seq.next_element_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<QueryData>,
})? {
values.push(value);
}
Ok(values)
}
}
d.deserialize_seq(VecVisitor(self))
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Slice> {
type Value = Slice;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'a>(SerdeArena<'a, Slice>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = Slice;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string")
}
fn visit_str<E>(mut self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(self.0.alloc_str(v))
}
}
d.deserialize_str(Visitor(self))
}
}
enum States {
Empty,
HasQueries(Vec<QueryData>),
HasPartialQueryData {
query: Option<Slice>,
params: Option<Slice>,
#[allow(clippy::option_option)]
array_mode: Option<Option<bool>>,
},
}
enum Field {
Queries,
Query,
Params,
ArrayMode,
Ignore,
}
struct FieldVisitor;
impl<'de> de::Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
r#"a JSON object string of either "query", "params", "arrayMode", or "queries"."#,
)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
self.visit_bytes(v.as_bytes())
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
match v {
b"queries" => Ok(Field::Queries),
b"query" => Ok(Field::Query),
b"params" => Ok(Field::Params),
b"arrayMode" => Ok(Field::ArrayMode),
_ => Ok(Field::Ignore),
}
}
}
impl<'de> Deserialize<'de> for Field {
#[inline]
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
d.deserialize_identifier(FieldVisitor)
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, QueryData> {
type Value = QueryData;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'a>(SerdeArena<'a, QueryData>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = QueryData;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
"a json object containing either a query object, or a list of query objects",
)
}
#[inline]
fn visit_map<A>(self, mut m: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut state = States::Empty;
while let Some(key) = m.next_key()? {
match key {
Field::Query => {
let (params, array_mode) = match state {
States::HasQueries(_) => unreachable!(),
States::HasPartialQueryData { query: Some(_), .. } => {
return Err(<A::Error as de::Error>::duplicate_field("query"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query: None,
params,
array_mode,
} => (params, array_mode),
};
state = States::HasPartialQueryData {
query: Some(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Slice>,
})?),
params,
array_mode,
};
}
Field::Params => {
let (query, array_mode) = match state {
States::HasQueries(_) => unreachable!(),
States::HasPartialQueryData {
params: Some(_), ..
} => {
return Err(<A::Error as de::Error>::duplicate_field("params"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params: None,
array_mode,
} => (query, array_mode),
};
let params = m.next_value::<PgText>()?.value;
let start = self.0.arena.params_arena.len() as u32;
let len = params.len() as u32;
for param in params {
match param {
Some(s) => {
let s = self.0.arena.alloc_str(&s);
self.0.arena.params_arena.push(Some(s));
}
None => self.0.arena.params_arena.push(None),
}
}
state = States::HasPartialQueryData {
query,
params: Some(Slice { start, len }),
array_mode,
};
}
Field::ArrayMode => {
let (query, params) = match state {
States::HasQueries(_) => unreachable!(),
States::HasPartialQueryData {
array_mode: Some(_),
..
} => {
return Err(<A::Error as de::Error>::duplicate_field(
"arrayMode",
))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params,
array_mode: None,
} => (query, params),
};
state = States::HasPartialQueryData {
query,
params,
array_mode: Some(m.next_value()?),
};
}
Field::Queries | Field::Ignore => {
let _ = m.next_value::<de::IgnoredAny>()?;
}
}
}
match state {
States::HasQueries(_) => unreachable!(),
States::HasPartialQueryData {
query: Some(query),
params: Some(params),
array_mode,
} => Ok(QueryData {
query,
params,
array_mode: array_mode.unwrap_or_default(),
}),
States::Empty | States::HasPartialQueryData { query: None, .. } => {
Err(<A::Error as de::Error>::missing_field("query"))
}
States::HasPartialQueryData { params: None, .. } => {
Err(<A::Error as de::Error>::missing_field("params"))
}
}
}
}
Deserializer::deserialize_struct(
d,
"QueryData",
&["query", "params", "arrayMode"],
Visitor(self),
)
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Payload> {
type Value = Payload;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'a>(SerdeArena<'a, Payload>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = Payload;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
"a json object containing either a query object, or a list of query objects",
)
}
#[inline]
fn visit_map<A>(self, mut m: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut state = States::Empty;
while let Some(key) = m.next_key()? {
match key {
Field::Queries => match state {
States::Empty => {
state = States::HasQueries(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Vec<QueryData>>,
})?);
}
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::duplicate_field("queries"))
}
States::HasPartialQueryData { .. } => {
return Err(<A::Error as de::Error>::unknown_field(
"queries",
&["query", "params", "arrayMode"],
))
}
},
Field::Query => {
let (params, array_mode) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"query",
&["queries"],
))
}
States::HasPartialQueryData { query: Some(_), .. } => {
return Err(<A::Error as de::Error>::duplicate_field("query"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query: None,
params,
array_mode,
} => (params, array_mode),
};
state = States::HasPartialQueryData {
query: Some(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Slice>,
})?),
params,
array_mode,
};
}
Field::Params => {
let (query, array_mode) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"params",
&["queries"],
))
}
States::HasPartialQueryData {
params: Some(_), ..
} => {
return Err(<A::Error as de::Error>::duplicate_field("params"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params: None,
array_mode,
} => (query, array_mode),
};
let params = m.next_value::<PgText>()?.value;
let start = self.0.arena.params_arena.len() as u32;
let len = params.len() as u32;
for param in params {
match param {
Some(s) => {
let s = self.0.arena.alloc_str(&s);
self.0.arena.params_arena.push(Some(s));
}
None => self.0.arena.params_arena.push(None),
}
}
state = States::HasPartialQueryData {
query,
params: Some(Slice { start, len }),
array_mode,
};
}
Field::ArrayMode => {
let (query, params) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"arrayMode",
&["queries"],
))
}
States::HasPartialQueryData {
array_mode: Some(_),
..
} => {
return Err(<A::Error as de::Error>::duplicate_field(
"arrayMode",
))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params,
array_mode: None,
} => (query, params),
};
state = States::HasPartialQueryData {
query,
params,
array_mode: Some(m.next_value()?),
};
}
Field::Ignore => {
let _ = m.next_value::<de::IgnoredAny>()?;
}
}
}
match state {
States::HasQueries(queries) => Ok(Payload::Batch(BatchQueryData { queries })),
States::HasPartialQueryData {
query: Some(query),
params: Some(params),
array_mode,
} => Ok(Payload::Single(QueryData {
query,
params,
array_mode: array_mode.unwrap_or_default(),
})),
States::Empty | States::HasPartialQueryData { query: None, .. } => {
Err(<A::Error as de::Error>::missing_field("query"))
}
States::HasPartialQueryData { params: None, .. } => {
Err(<A::Error as de::Error>::missing_field("params"))
}
}
}
}
Deserializer::deserialize_struct(
d,
"Payload",
&["queries", "query", "params", "arrayMode"],
Visitor(self),
)
}
}
struct PgText {
value: Vec<Option<String>>,
}
impl<'de> Deserialize<'de> for PgText {
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct VecVisitor;
impl<'de> de::Visitor<'de> for VecVisitor {
type Value = Vec<Option<String>>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence of postgres parameters")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Vec<Option<String>>, A::Error>
where
A: de::SeqAccess<'de>,
{
let mut values = Vec::new();
// TODO: consider avoiding the allocations for json::Value here.
while let Some(value) = seq.next_element()? {
values.push(json_value_to_pg_text(value));
}
Ok(values)
}
}
let value = d.deserialize_seq(VecVisitor)?;
Ok(PgText { value })
}
}
fn json_value_to_pg_text(value: Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
@@ -21,10 +537,10 @@ fn json_value_to_pg_text(value: &Value) -> Option<String> {
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.to_string()),
Value::String(s) => Some(s),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
Value::Array(arr) => Some(json_array_to_pg_array(arr)),
}
}
@@ -36,7 +552,17 @@ fn json_value_to_pg_text(value: &Value) -> Option<String> {
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
fn json_array_to_pg_array(arr: Vec<Value>) -> String {
let vals = arr
.into_iter()
.map(json_array_value_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.join(",");
format!("{{{vals}}}")
}
fn json_array_value_to_pg_array(value: Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
@@ -44,19 +570,10 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
v @ Value::Object(_) => json_array_value_to_pg_array(Value::String(v.to_string())),
// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");
Some(format!("{{{vals}}}"))
}
Value::Array(arr) => Some(json_array_to_pg_array(arr)),
}
}
@@ -261,24 +778,22 @@ mod tests {
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];
let pg_params = json_to_pg_text(json);
assert_eq!(
pg_params,
vec![Some("true".to_owned()), Some("false".to_owned())]
);
let pg_params = json_value_to_pg_text(Value::Bool(true));
assert_eq!(pg_params, Some("true".to_owned()));
let pg_params = json_value_to_pg_text(Value::Bool(false));
assert_eq!(pg_params, Some("false".to_owned()));
let json = vec![Value::Number(serde_json::Number::from(42))];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![Some("42".to_owned())]);
let json = Value::Number(serde_json::Number::from(42));
let pg_params = json_value_to_pg_text(json);
assert_eq!(pg_params, Some("42".to_owned()));
let json = vec![Value::String("foo\"".to_string())];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
let json = Value::String("foo\"".to_string());
let pg_params = json_value_to_pg_text(json);
assert_eq!(pg_params, Some("foo\"".to_owned()));
let json = vec![Value::Null];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![None]);
let json = Value::Null;
let pg_params = json_value_to_pg_text(json);
assert_eq!(pg_params, None);
}
#[test]
@@ -286,31 +801,27 @@ mod tests {
// atoms and escaping
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_value_to_pg_text(json);
assert_eq!(
pg_params,
vec![Some(
"{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned()
)]
Some("{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned())
);
// nested arrays
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_value_to_pg_text(json);
assert_eq!(
pg_params,
vec![Some(
"{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
)]
Some("{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned())
);
// array of objects
let json = r#"[{"foo": 1},{"bar": 2}]"#;
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_value_to_pg_text(json);
assert_eq!(
pg_params,
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())
);
}

View File

@@ -1,3 +1,4 @@
use std::marker::PhantomData;
use std::pin::pin;
use std::sync::Arc;
@@ -8,6 +9,8 @@ use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http::header::AUTHORIZATION;
use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
@@ -19,8 +22,6 @@ use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
use tokio::time;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
@@ -38,52 +39,51 @@ use url::Url;
use urlencoding;
use utils::http::error::ApiError;
use crate::auth::backend::ComputeCredentials;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::error::ErrorKind;
use crate::error::ReportableError;
use crate::error::UserFacingError;
use crate::http::parse_json_body_with_limit;
use crate::metrics::HttpDirection;
use crate::metrics::Metrics;
use crate::proxy::run_until_cancelled;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::serverless::json::Arena;
use crate::serverless::json::SerdeArena;
use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
use super::backend::HttpConnError;
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::conn_pool::ConnInfoWithAuth;
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
use super::json::Slice;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
#[serde(default)]
array_mode: Option<bool>,
pub(crate) struct QueryData {
pub(crate) query: Slice,
pub(crate) params: Slice,
pub(crate) array_mode: Option<bool>,
}
#[derive(serde::Deserialize)]
struct BatchQueryData {
queries: Vec<QueryData>,
pub(crate) struct BatchQueryData {
pub(crate) queries: Vec<QueryData>,
}
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
pub(crate) enum Payload {
Single(QueryData),
Batch(BatchQueryData),
}
@@ -98,15 +98,6 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
Ok(json_to_pg_text(json))
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
@@ -123,8 +114,8 @@ pub(crate) enum ConnInfoError {
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing password")]
MissingPassword,
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
@@ -133,6 +124,14 @@ pub(crate) enum ConnInfoError {
MalformedEndpoint,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
@@ -146,6 +145,7 @@ impl UserFacingError for ConnInfoError {
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
@@ -181,21 +181,32 @@ fn get_conn_info(
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingPassword)?
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingPassword);
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint = match connection_url.host() {
@@ -247,7 +258,7 @@ pub(crate) async fn handle(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
@@ -279,7 +290,7 @@ pub(crate) async fn handle(
let mut message = e.to_string_client();
let db_error = match &e {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -356,7 +367,7 @@ pub(crate) async fn handle(
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpError {
#[error("{0}")]
ReadPayload(#[from] ReadPayloadError),
ReadPayload(ReadPayloadError),
#[error("{0}")]
ConnectCompute(#[from] HttpConnError),
#[error("{0}")]
@@ -410,9 +421,9 @@ impl UserFacingError for SqlOverHttpError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper1::Error),
Read(hyper1::Error),
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
Parse(serde_json::Error),
}
impl ReportableError for ReadPayloadError {
@@ -424,6 +435,18 @@ impl ReportableError for ReadPayloadError {
}
}
impl From<crate::http::ReadPayloadError<hyper1::Error>> for SqlOverHttpError {
fn from(value: crate::http::ReadPayloadError<hyper1::Error>) -> Self {
match value {
crate::http::ReadPayloadError::Read(e) => Self::ReadPayload(ReadPayloadError::Read(e)),
crate::http::ReadPayloadError::Parse(e) => {
Self::ReadPayload(ReadPayloadError::Parse(e))
}
crate::http::ReadPayloadError::LengthExceeded(x) => Self::RequestTooLarge(x as u64),
}
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
@@ -504,7 +527,7 @@ async fn handle_inner(
ctx: &RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
@@ -514,26 +537,58 @@ async fn handle_inner(
"handling interactive connection from client"
);
//
// Determine the destination and connection params
//
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
config.tls_config.as_ref(),
)?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
);
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
cancel,
config,
ctx,
request,
conn_info.conn_info,
auth,
backend,
)
.await
}
}
}
async fn handle_db_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
auth: AuthData,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
let (parts, body) = request.into_parts();
// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
let allow_pool = !config.http_config.pool_options.opt_in
|| headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
|| parts.headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
let parsed_headers = HttpHeaders::try_parse(headers)?;
let parsed_headers = HttpHeaders::try_parse(&parts.headers)?;
let request_content_length = match request.body().size_hint().upper() {
let request_content_length = match body.size_hint().upper() {
Some(v) => v,
None => config.http_config.max_request_size_bytes + 1,
};
@@ -551,38 +606,53 @@ async fn handle_inner(
));
}
let fetch_and_process_request = Box::pin(
async {
let body = request.into_body().collect().await?.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from),
);
let fetch_and_process_request = Box::pin(async move {
let mut arena = Arena::default();
let seed = SerdeArena {
arena: &mut arena,
_t: PhantomData::<Payload>,
};
let payload = parse_json_body_with_limit(
seed,
body,
config.http_config.max_request_size_bytes as usize,
)
.await?;
Ok::<(Arena, Payload), SqlOverHttpError>((arena, payload)) // Adjust error type accordingly
});
let authenticate_and_connect = Box::pin(
async {
let keys = match &conn_info.auth {
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.conn_info.user_info,
pw,
&conn_info.user_info,
&pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
.await?
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await?;
ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
}
}
};
let client = backend
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
@@ -592,7 +662,7 @@ async fn handle_inner(
.map_err(SqlOverHttpError::from),
);
let (payload, mut client) = match run_until_cancelled(
let ((mut arena, payload), mut client) = match run_until_cancelled(
// Run both operations in parallel
try_join(
pin!(fetch_and_process_request),
@@ -606,6 +676,9 @@ async fn handle_inner(
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
};
arena.params_arena.shrink_to_fit();
arena.str_arena.shrink_to_fit();
let mut response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json");
@@ -613,7 +686,7 @@ async fn handle_inner(
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(config, cancel, &mut client, parsed_headers)
stmt.process(config, &arena, cancel, &mut client, parsed_headers)
.await?
}
Payload::Batch(statements) => {
@@ -631,16 +704,27 @@ async fn handle_inner(
}
statements
.process(config, cancel, &mut client, parsed_headers)
.process(config, &arena, cancel, &mut client, parsed_headers)
.await?
}
};
info!(
str_len = arena.str_arena.len(),
params = arena.params_arena.len(),
response = json_output.len(),
"data size"
);
let metrics = client.metrics();
let len = json_output.len();
let response = response
.body(Full::new(Bytes::from(json_output)))
.body(
Full::new(Bytes::from(json_output))
.map_err(|x| match x {})
.boxed(),
)
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
@@ -656,10 +740,70 @@ async fn handle_inner(
Ok(response)
}
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&AUTHORIZATION,
&CONN_STRING,
&RAW_TEXT_OUTPUT,
&ARRAY_MODE,
&TXN_ISOLATION_LEVEL,
&TXN_READ_ONLY,
&TXN_DEFERRABLE,
];
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
let (mut parts, body) = request.into_parts();
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
// these instead just to ensure they remain normalised.
for &h in HEADERS_TO_FORWARD {
if let Some(hv) = parts.headers.remove(h) {
req = req.header(h, hv);
}
}
let req = req
.body(body)
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress
let _metrics = client.metrics();
Ok(client
.inner
.send_request(req)
.await
.map_err(LocalProxyConnError::from)
.map_err(HttpConnError::from)?
.map(|b| b.boxed()))
}
impl QueryData {
async fn process(
self,
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -668,7 +812,14 @@ impl QueryData {
let cancel_token = inner.cancel_token();
let res = match select(
pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)),
pin!(query_to_json(
config,
arena,
&*inner,
self,
&mut 0,
parsed_headers
)),
pin!(cancel.cancelled()),
)
.await
@@ -676,10 +827,7 @@ impl QueryData {
// The query successfully completed.
Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
discard.check_idle(status);
let json_output =
serde_json::to_string(&results).expect("json serialization should not fail");
Ok(json_output)
Ok(results)
}
// The query failed with an error
Either::Left((Err(e), __not_yet_cancelled)) => {
@@ -705,7 +853,9 @@ impl QueryData {
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = match &error {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
SqlOverHttpError::ConnectCompute(
HttpConnError::PostgresConnectionError(e),
)
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -732,6 +882,7 @@ impl BatchQueryData {
async fn process(
self,
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -758,6 +909,7 @@ impl BatchQueryData {
let json_output = match query_batch(
config,
arena,
cancel.child_token(),
&transaction,
self,
@@ -802,16 +954,20 @@ impl BatchQueryData {
async fn query_batch(
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let mut results = Vec::with_capacity(queries.queries.len());
let mut comma = false;
let mut results = r#"{"results":["#.to_string();
let mut current_size = 0;
for stmt in queries.queries {
let query = pin!(query_to_json(
config,
arena,
transaction,
stmt,
&mut current_size,
@@ -822,7 +978,11 @@ async fn query_batch(
match res {
// TODO: maybe we should check that the transaction bit is set here
Either::Left((Ok((_, values)), _cancelled)) => {
results.push(values);
if comma {
results.push(',');
}
results.push_str(&values);
comma = true;
}
Either::Left((Err(e), _cancelled)) => {
return Err(e);
@@ -833,22 +993,28 @@ async fn query_batch(
}
}
let results = json!({ "results": results });
let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
results.push_str("]}");
Ok(json_output)
Ok(results)
}
async fn query_to_json<T: GenericClient>(
config: &'static ProxyConfig,
arena: &Arena,
client: &T,
data: QueryData,
current_size: &mut usize,
parsed_headers: HttpHeaders,
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
) -> Result<(ReadyForQueryStatus, String), SqlOverHttpError> {
info!("executing query");
let query_params = data.params;
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
let query_params = arena.params_arena[data.params.into_range()]
.iter()
.map(|p| p.map(|p| &arena.str_arena[p.into_range()]));
let query = &arena.str_arena[data.query.into_range()];
let mut row_stream = std::pin::pin!(client.query_raw_txt(query, query_params).await?);
info!("finished executing query");
// Manually drain the stream into a vector to leave row_stream hanging
@@ -917,12 +1083,13 @@ async fn query_to_json<T: GenericClient>(
// Resulting JSON format is based on the format of node-postgres result.
let results = json!({
"command": command_tag_name.to_string(),
"command": command_tag_name,
"rowCount": command_tag_count,
"rows": rows,
"fields": fields,
"rowAsArray": array_mode,
});
})
.to_string();
Ok((ready, results))
}

View File

@@ -374,14 +374,16 @@ type JoinTaskRes = Result<anyhow::Result<()>, JoinError>;
async fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
// fsync the datadir to make sure we have a consistent state on disk.
let dfd = File::open(&conf.workdir).context("open datadir for syncfs")?;
let started = Instant::now();
utils::crashsafe::syncfs(dfd)?;
let elapsed = started.elapsed();
info!(
elapsed_ms = elapsed.as_millis(),
"syncfs data directory done"
);
if !conf.no_sync {
let dfd = File::open(&conf.workdir).context("open datadir for syncfs")?;
let started = Instant::now();
utils::crashsafe::syncfs(dfd)?;
let elapsed = started.elapsed();
info!(
elapsed_ms = elapsed.as_millis(),
"syncfs data directory done"
);
}
info!("starting safekeeper WAL service on {}", conf.listen_pg_addr);
let pg_listener = tcp_listener::bind(conf.listen_pg_addr.clone()).map_err(|e| {

View File

@@ -9,7 +9,7 @@ use crate::walproposer_sim::{
pub mod walproposer_sim;
// Generates 2000 random seeds and runs a schedule for each of them.
// Generates 500 random seeds and runs a schedule for each of them.
// If you see this test fail, please report the last seed to the
// @safekeeper team.
#[test]
@@ -17,7 +17,7 @@ fn test_random_schedules() -> anyhow::Result<()> {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
for _ in 0..2000 {
for _ in 0..500 {
let seed: u64 = rand::thread_rng().gen();
config.network = generate_network_opts(seed);

View File

@@ -572,30 +572,7 @@ impl Reconciler {
// During a live migration it is unhelpful to proceed if we couldn't notify compute: if we detach
// the origin without notifying compute, we will render the tenant unavailable.
let mut notify_attempts = 0;
while let Err(e) = self.compute_notify().await {
match e {
NotifyError::Fatal(_) => return Err(ReconcileError::Notify(e)),
NotifyError::ShuttingDown => return Err(ReconcileError::Cancel),
_ => {
tracing::warn!(
"Live migration blocked by compute notification error, retrying: {e}"
);
}
}
exponential_backoff(
notify_attempts,
// Generous waits: control plane operations which might be blocking us usually complete on the order
// of hundreds to thousands of milliseconds, so no point busy polling.
1.0,
10.0,
&self.cancel,
)
.await;
notify_attempts += 1;
}
self.compute_notify_blocking(&origin_ps).await?;
pausable_failpoint!("reconciler-live-migrate-post-notify");
// Downgrade the origin to secondary. If the tenant's policy is PlacementPolicy::Attached(0), then
@@ -869,6 +846,117 @@ impl Reconciler {
Ok(())
}
}
/// Keep trying to notify the compute indefinitely, only dropping out if:
/// - the node `origin` becomes unavailable -> Ok(())
/// - the node `origin` no longer has our tenant shard attached -> Ok(())
/// - our cancellation token fires -> Err(ReconcileError::Cancelled)
///
/// This is used during live migration, where we do not wish to detach
/// an origin location until the compute definitely knows about the new
/// location.
///
/// In cases where the origin node becomes unavailable, we return success, indicating
/// to the caller that they should continue irrespective of whether the compute was notified,
/// because the origin node is unusable anyway. Notification will be retried later via the
/// [`Self::compute_notify_failure`] flag.
async fn compute_notify_blocking(&mut self, origin: &Node) -> Result<(), ReconcileError> {
let mut notify_attempts = 0;
while let Err(e) = self.compute_notify().await {
match e {
NotifyError::Fatal(_) => return Err(ReconcileError::Notify(e)),
NotifyError::ShuttingDown => return Err(ReconcileError::Cancel),
_ => {
tracing::warn!(
"Live migration blocked by compute notification error, retrying: {e}"
);
}
}
// Did the origin pageserver become unavailable?
if !origin.is_available() {
tracing::info!("Giving up on compute notification because {origin} is unavailable");
break;
}
// Does the origin pageserver still host the shard we are interested in? We should only
// continue waiting for compute notification to be acked if the old location is still usable.
let tenant_shard_id = self.tenant_shard_id;
match origin
.with_client_retries(
|client| async move { client.get_location_config(tenant_shard_id).await },
&self.service_config.jwt_token,
1,
3,
Duration::from_secs(5),
&self.cancel,
)
.await
{
Some(Ok(Some(location_conf))) => {
if matches!(
location_conf.mode,
LocationConfigMode::AttachedMulti
| LocationConfigMode::AttachedSingle
| LocationConfigMode::AttachedStale
) {
tracing::debug!(
"Still attached to {origin}, will wait & retry compute notification"
);
} else {
tracing::info!(
"Giving up on compute notification because {origin} is in state {:?}",
location_conf.mode
);
return Ok(());
}
// Fall through
}
Some(Ok(None)) => {
tracing::info!(
"No longer attached to {origin}, giving up on compute notification"
);
return Ok(());
}
Some(Err(e)) => {
match e {
mgmt_api::Error::Cancelled => {
tracing::info!(
"Giving up on compute notification because {origin} is unavailable"
);
return Ok(());
}
mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _) => {
tracing::info!(
"No longer attached to {origin}, giving up on compute notification"
);
return Ok(());
}
e => {
// Other API errors are unexpected here.
tracing::warn!("Unexpected error checking location on {origin}: {e}");
// Fall through, we will retry compute notification.
}
}
}
None => return Err(ReconcileError::Cancel),
};
exponential_backoff(
notify_attempts,
// Generous waits: control plane operations which might be blocking us usually complete on the order
// of hundreds to thousands of milliseconds, so no point busy polling.
1.0,
10.0,
&self.cancel,
)
.await;
notify_attempts += 1;
}
Ok(())
}
}
/// We tweak the externally-set TenantConfig while configuring

View File

@@ -4974,7 +4974,12 @@ impl Service {
{
let mut nodes_mut = (**nodes).clone();
nodes_mut.remove(&node_id);
if let Some(mut removed_node) = nodes_mut.remove(&node_id) {
// Ensure that any reconciler holding an Arc<> to this node will
// drop out when trying to RPC to it (setting Offline state sets the
// cancellation token on the Node object).
removed_node.set_availability(NodeAvailability::Offline);
}
*nodes = Arc::new(nodes_mut);
}
}

View File

@@ -4,7 +4,7 @@ use std::time::Duration;
use crate::checks::{list_timeline_blobs, BlobDataParseResult};
use crate::metadata_stream::{stream_tenant_timelines, stream_tenants};
use crate::{init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId};
use crate::{init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId, MAX_RETRIES};
use futures_util::{StreamExt, TryStreamExt};
use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata;
use pageserver::tenant::remote_timeline_client::{parse_remote_index_path, remote_layer_path};
@@ -18,6 +18,7 @@ use serde::Serialize;
use storage_controller_client::control_api;
use tokio_util::sync::CancellationToken;
use tracing::{info_span, Instrument};
use utils::backoff;
use utils::generation::Generation;
use utils::id::{TenantId, TenantTimelineId};
@@ -326,15 +327,25 @@ async fn maybe_delete_index(
}
// All validations passed: erase the object
match remote_client
.delete(&obj.key, &CancellationToken::new())
.await
let cancel = CancellationToken::new();
match backoff::retry(
|| remote_client.delete(&obj.key, &cancel),
|_| false,
3,
MAX_RETRIES as u32,
"maybe_delete_index",
&cancel,
)
.await
{
Ok(_) => {
None => {
unreachable!("Using a dummy cancellation token");
}
Some(Ok(_)) => {
tracing::info!("Successfully deleted index");
summary.indices_deleted += 1;
}
Err(e) => {
Some(Err(e)) => {
tracing::warn!("Failed to delete index: {e}");
summary.remote_storage_errors += 1;
}

View File

@@ -950,9 +950,6 @@ class NeonEnv:
safekeepers - An array containing objects representing the safekeepers
pg_bin - pg_bin.run() can be used to execute Postgres client binaries,
like psql or pg_dump
initial_tenant - tenant ID of the initial tenant created in the repository
neon_cli - can be used to run the 'neon' CLI tool
@@ -3300,6 +3297,8 @@ class PgBin:
@pytest.fixture(scope="function")
def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: PgVersion) -> PgBin:
"""pg_bin.run() can be used to execute Postgres client binaries, like psql or pg_dump"""
return PgBin(test_output_dir, pg_distrib_dir, pg_version)
@@ -3311,7 +3310,7 @@ class VanillaPostgres(PgProtocol):
self.pg_bin = pg_bin
self.running = False
if init:
self.pg_bin.run_capture(["initdb", "-D", str(pgdatadir)])
self.pg_bin.run_capture(["initdb", "--pgdata", str(pgdatadir)])
self.configure([f"port = {port}\n"])
def enable_tls(self):

View File

@@ -8,6 +8,7 @@ import requests
from fixtures.common_types import Lsn, TenantId, TenantTimelineId, TimelineId
from fixtures.log_helper import log
from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
from fixtures.utils import wait_until
# Walreceiver as returned by sk's timeline status endpoint.
@@ -161,6 +162,16 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter):
walreceivers=walreceivers,
)
# Get timeline_start_lsn, waiting until it's nonzero. It is a way to ensure
# that the timeline is fully initialized at the safekeeper.
def get_non_zero_timeline_start_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn:
def timeline_start_lsn_non_zero() -> Lsn:
s = self.timeline_status(tenant_id, timeline_id).timeline_start_lsn
assert s > Lsn(0)
return s
return wait_until(30, 1, timeline_start_lsn_non_zero)
def get_commit_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn:
return self.timeline_status(tenant_id, timeline_id).commit_lsn

View File

@@ -53,7 +53,7 @@ def test_branch_and_gc(neon_simple_env: NeonEnv, build_type: str):
env = neon_simple_env
pageserver_http_client = env.pageserver.http_client()
tenant, _ = env.neon_cli.create_tenant(
tenant, timeline_main = env.neon_cli.create_tenant(
conf={
# disable background GC
"gc_period": "0s",
@@ -70,8 +70,7 @@ def test_branch_and_gc(neon_simple_env: NeonEnv, build_type: str):
}
)
timeline_main = env.neon_cli.create_timeline("test_main", tenant_id=tenant)
endpoint_main = env.endpoints.create_start("test_main", tenant_id=tenant)
endpoint_main = env.endpoints.create_start("main", tenant_id=tenant)
main_cur = endpoint_main.connect().cursor()
@@ -92,7 +91,7 @@ def test_branch_and_gc(neon_simple_env: NeonEnv, build_type: str):
pageserver_http_client.timeline_gc(tenant, timeline_main, lsn2 - lsn1 + 1024)
env.neon_cli.create_branch(
"test_branch", "test_main", tenant_id=tenant, ancestor_start_lsn=lsn1
"test_branch", ancestor_branch_name="main", ancestor_start_lsn=lsn1, tenant_id=tenant
)
endpoint_branch = env.endpoints.create_start("test_branch", tenant_id=tenant)

View File

@@ -21,7 +21,7 @@ from fixtures.pageserver.http import PageserverApiException
from fixtures.pageserver.utils import (
timeline_delete_wait_completed,
)
from fixtures.pg_version import PgVersion, skip_on_postgres
from fixtures.pg_version import PgVersion
from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage
from fixtures.workload import Workload
@@ -156,9 +156,6 @@ ingest_lag_log_line = ".*ingesting record with timestamp lagging more than wait_
@check_ondisk_data_compatibility_if_enabled
@pytest.mark.xdist_group("compatibility")
@pytest.mark.order(after="test_create_snapshot")
@skip_on_postgres(
PgVersion.V17, "There are no snapshots yet"
) # TODO: revert this once we have snapshots
def test_backward_compatibility(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
@@ -206,9 +203,6 @@ def test_backward_compatibility(
@check_ondisk_data_compatibility_if_enabled
@pytest.mark.xdist_group("compatibility")
@pytest.mark.order(after="test_create_snapshot")
@skip_on_postgres(
PgVersion.V17, "There are no snapshots yet"
) # TODO: revert this once we have snapshots
def test_forward_compatibility(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
@@ -258,7 +252,7 @@ def test_forward_compatibility(
# not using env.pageserver.version because it was initialized before
prev_pageserver_version_str = env.get_binary_version("pageserver")
prev_pageserver_version_match = re.search(
"Neon page server git-env:(.*) failpoints: (.*), features: (.*)",
"Neon page server git(?:-env)?:(.*) failpoints: (.*), features: (.*)",
prev_pageserver_version_str,
)
if prev_pageserver_version_match is not None:
@@ -269,12 +263,12 @@ def test_forward_compatibility(
)
# does not include logs from previous runs
assert not env.pageserver.log_contains("git-env:" + prev_pageserver_version)
assert not env.pageserver.log_contains(f"git(-env)?:{prev_pageserver_version}")
env.start()
# ensure the specified pageserver is running
assert env.pageserver.log_contains("git-env:" + prev_pageserver_version)
assert env.pageserver.log_contains(f"git(-env)?:{prev_pageserver_version}")
check_neon_works(
env,

View File

@@ -31,9 +31,7 @@ def helper_compare_timeline_list(
)
)
timelines_cli = env.neon_cli.list_timelines()
assert timelines_cli == env.neon_cli.list_timelines(initial_tenant)
timelines_cli = env.neon_cli.list_timelines(initial_tenant)
cli_timeline_ids = sorted([timeline_id for (_, timeline_id) in timelines_cli])
assert timelines_api == cli_timeline_ids

View File

@@ -24,7 +24,7 @@ def test_neon_extension(neon_env_builder: NeonEnvBuilder):
# IMPORTANT:
# If the version has changed, the test should be updated.
# Ensure that the default version is also updated in the neon.control file
assert cur.fetchone() == ("1.4",)
assert cur.fetchone() == ("1.5",)
cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE")
res = cur.fetchall()
log.info(res)
@@ -48,7 +48,7 @@ def test_neon_extension_compatibility(neon_env_builder: NeonEnvBuilder):
# IMPORTANT:
# If the version has changed, the test should be updated.
# Ensure that the default version is also updated in the neon.control file
assert cur.fetchone() == ("1.4",)
assert cur.fetchone() == ("1.5",)
cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE")
all_versions = ["1.5", "1.4", "1.3", "1.2", "1.1", "1.0"]
current_version = "1.5"

View File

@@ -549,6 +549,14 @@ def test_multi_attach(
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
# Instruct the storage controller to not interfere with our low level configuration
# of the pageserver's attachment states. Otherwise when it sees nodes go offline+return,
# it would send its own requests that would conflict with the test's.
env.storage_controller.tenant_policy_update(tenant_id, {"scheduling": "Stop"})
env.storage_controller.allowed_errors.extend(
[".*Scheduling is disabled by policy Stop.*", ".*Skipping reconcile for policy Stop.*"]
)
# Initially, the tenant will be attached to the first pageserver (first is default in our test harness)
wait_until(10, 0.2, lambda: assert_tenant_state(http_clients[0], tenant_id, "Active"))
_detail = http_clients[0].timeline_detail(tenant_id, timeline_id)

View File

@@ -174,8 +174,7 @@ def test_pageserver_chaos(
"checkpoint_distance": "5000000",
}
)
env.neon_cli.create_timeline("test_pageserver_chaos", tenant_id=tenant)
endpoint = env.endpoints.create_start("test_pageserver_chaos", tenant_id=tenant)
endpoint = env.endpoints.create_start("main", tenant_id=tenant)
# Create table, and insert some rows. Make it big enough that it doesn't fit in
# shared_buffers, otherwise the SELECT after restart will just return answer

View File

@@ -27,7 +27,7 @@ def test_readonly_node(neon_simple_env: NeonEnv):
env.pageserver.allowed_errors.extend(
[
".*basebackup .* failed: invalid basebackup lsn.*",
".*page_service.*handle_make_lsn_lease.*.*tried to request a page version that was garbage collected",
".*/lsn_lease.*invalid lsn lease request.*",
]
)
@@ -108,7 +108,7 @@ def test_readonly_node(neon_simple_env: NeonEnv):
assert cur.fetchone() == (1,)
# Create node at pre-initdb lsn
with pytest.raises(Exception, match="invalid basebackup lsn"):
with pytest.raises(Exception, match="invalid lsn lease request"):
# compute node startup with invalid LSN should fail
env.endpoints.create_start(
branch_name="main",
@@ -167,6 +167,23 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
)
return last_flush_lsn
def trigger_gc_and_select(env: NeonEnv, ep_static: Endpoint):
"""
Trigger GC manually on all pageservers. Then run an `SELECT` query.
"""
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()
gc_result = client.timeline_gc(shard, env.initial_timeline, 0)
log.info(f"{gc_result=}")
assert (
gc_result["layers_removed"] == 0
), "No layers should be removed, old layers are guarded by leases."
with ep_static.cursor() as cur:
cur.execute("SELECT count(*) FROM t0")
assert cur.fetchone() == (ROW_COUNT,)
# Insert some records on main branch
with env.endpoints.create_start("main") as ep_main:
with ep_main.cursor() as cur:
@@ -193,25 +210,31 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
generate_updates_on_main(env, ep_main, i, end=100)
# Trigger GC
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()
gc_result = client.timeline_gc(shard, env.initial_timeline, 0)
log.info(f"{gc_result=}")
trigger_gc_and_select(env, ep_static)
assert (
gc_result["layers_removed"] == 0
), "No layers should be removed, old layers are guarded by leases."
# Trigger Pageserver restarts
for ps in env.pageservers:
ps.stop()
# Static compute should have at least one lease request failure due to connection.
time.sleep(LSN_LEASE_LENGTH / 2)
ps.start()
with ep_static.cursor() as cur:
cur.execute("SELECT count(*) FROM t0")
assert cur.fetchone() == (ROW_COUNT,)
trigger_gc_and_select(env, ep_static)
# Reconfigure pageservers
env.pageservers[0].stop()
env.storage_controller.node_configure(
env.pageservers[0].id, {"availability": "Offline"}
)
env.storage_controller.reconcile_until_idle()
trigger_gc_and_select(env, ep_static)
# Do some update so we can increment latest_gc_cutoff
generate_updates_on_main(env, ep_main, i, end=100)
# Wait for the existing lease to expire.
time.sleep(LSN_LEASE_LENGTH)
time.sleep(LSN_LEASE_LENGTH + 1)
# Now trigger GC again, layers should be removed.
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()

View File

@@ -567,6 +567,149 @@ def test_storage_controller_compute_hook(
env.storage_controller.consistency_check()
def test_storage_controller_stuck_compute_hook(
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address,
):
"""
Test the migration process's behavior when the compute hook does not enable it to proceed
"""
neon_env_builder.num_pageservers = 2
(host, port) = httpserver_listen_address
neon_env_builder.control_plane_compute_hook_api = f"http://{host}:{port}/notify"
handle_params = {"status": 200}
notifications = []
def handler(request: Request):
status = handle_params["status"]
log.info(f"Notify request[{status}]: {request}")
notifications.append(request.json)
return Response(status=status)
httpserver.expect_request("/notify", method="PUT").respond_with_handler(handler)
# Start running
env = neon_env_builder.init_start(initial_tenant_conf={"lsn_lease_length": "0s"})
# Initial notification from tenant creation
assert len(notifications) == 1
expect: Dict[str, Union[List[Dict[str, int]], str, None, int]] = {
"tenant_id": str(env.initial_tenant),
"stripe_size": None,
"shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}],
}
assert notifications[0] == expect
# Do a migration while the compute hook is returning 423 status
tenant_id = env.initial_tenant
origin_pageserver = env.get_tenant_pageserver(tenant_id)
dest_ps_id = [p.id for p in env.pageservers if p.id != origin_pageserver.id][0]
dest_pageserver = env.get_pageserver(dest_ps_id)
shard_0_id = TenantShardId(tenant_id, 0, 0)
NOTIFY_BLOCKED_LOG = ".*Live migration blocked.*"
env.storage_controller.allowed_errors.extend(
[
NOTIFY_BLOCKED_LOG,
".*Failed to notify compute.*",
".*Reconcile error.*Cancelled",
".*Reconcile error.*Control plane tenant busy",
]
)
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
# We expect the controller to hit the 423 (locked) and retry. Migration shouldn't complete until that
# status is cleared.
handle_params["status"] = 423
migrate_fut = executor.submit(
env.storage_controller.tenant_shard_migrate, shard_0_id, dest_ps_id
)
def logged_stuck():
env.storage_controller.assert_log_contains(NOTIFY_BLOCKED_LOG)
wait_until(10, 0.25, logged_stuck)
contains_r = env.storage_controller.log_contains(NOTIFY_BLOCKED_LOG)
assert contains_r is not None # Appease mypy
(_, log_cursor) = contains_r
assert migrate_fut.running()
# Permit the compute hook to proceed
handle_params["status"] = 200
migrate_fut.result(timeout=10)
# Advance log cursor past the last 'stuck' message (we already waited for one, but
# there could be more than one)
while True:
contains_r = env.storage_controller.log_contains(NOTIFY_BLOCKED_LOG, offset=log_cursor)
if contains_r is None:
break
else:
(_, log_cursor) = contains_r
# Now, do a migration in the opposite direction
handle_params["status"] = 423
migrate_fut = executor.submit(
env.storage_controller.tenant_shard_migrate, shard_0_id, origin_pageserver.id
)
def logged_stuck_again():
env.storage_controller.assert_log_contains(NOTIFY_BLOCKED_LOG, offset=log_cursor)
wait_until(10, 0.25, logged_stuck_again)
assert migrate_fut.running()
# This time, the compute hook remains stuck, but we mark the origin node offline: this should
# also allow the migration to complete -- we only wait for the compute hook as long as we think
# the old location is still usable for computes.
# This is a regression test for issue https://github.com/neondatabase/neon/issues/8901
dest_pageserver.stop()
env.storage_controller.node_configure(dest_ps_id, {"availability": "Offline"})
try:
migrate_fut.result(timeout=10)
except StorageControllerApiException as e:
# The reconciler will fail because it can't detach from the origin: the important
# thing is that it finishes, rather than getting stuck in the compute notify loop.
assert "Reconcile error" in str(e)
# A later background reconciliation will clean up and leave things in a neat state, even
# while the compute hook is still blocked
try:
env.storage_controller.reconcile_all()
except StorageControllerApiException as e:
# We expect that the reconciler will do its work, but be unable to fully succeed
# because it can't send a compute notification. It will complete, but leave
# the internal flag set for "retry compute notification later"
assert "Control plane tenant busy" in str(e)
# Confirm that we are AttachedSingle on the node we last called the migrate API for
loc = origin_pageserver.http_client().tenant_get_location(shard_0_id)
assert loc["mode"] == "AttachedSingle"
# When the origin node comes back, it should get cleaned up
dest_pageserver.start()
try:
env.storage_controller.reconcile_all()
except StorageControllerApiException as e:
# Compute hook is still blocked: reconciler will configure PS but not fully succeed
assert "Control plane tenant busy" in str(e)
with pytest.raises(PageserverApiException, match="Tenant shard not found"):
dest_pageserver.http_client().tenant_get_location(shard_0_id)
# Once the compute hook is unblocked, we should be able to get into a totally
# quiescent state again
handle_params["status"] = 200
env.storage_controller.reconcile_until_idle()
env.storage_controller.consistency_check()
def test_storage_controller_debug_apis(neon_env_builder: NeonEnvBuilder):
"""
Verify that occasional-use debug APIs work as expected. This is a lightweight test

View File

@@ -27,20 +27,15 @@ def test_empty_tenant_size(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_configs()
env.start()
(tenant_id, _) = env.neon_cli.create_tenant()
(tenant_id, timeline_id) = env.neon_cli.create_tenant()
http_client = env.pageserver.http_client()
initial_size = http_client.tenant_size(tenant_id)
# we should never have zero, because there should be the initdb "changes"
assert initial_size > 0, "initial implementation returns ~initdb tenant_size"
main_branch_name = "main"
branch_name, main_timeline_id = env.neon_cli.list_timelines(tenant_id)[0]
assert branch_name == main_branch_name
endpoint = env.endpoints.create_start(
main_branch_name,
"main",
tenant_id=tenant_id,
config_lines=["autovacuum=off", "checkpoint_timeout=10min"],
)
@@ -54,7 +49,7 @@ def test_empty_tenant_size(neon_env_builder: NeonEnvBuilder):
# The transaction above will make the compute generate a checkpoint.
# In turn, the pageserver persists the checkpoint. This should only be
# one key with a size of a couple hundred bytes.
wait_for_last_flush_lsn(env, endpoint, tenant_id, main_timeline_id)
wait_for_last_flush_lsn(env, endpoint, tenant_id, timeline_id)
size = http_client.tenant_size(tenant_id)
assert size >= initial_size and size - initial_size < 1024
@@ -306,7 +301,8 @@ def test_single_branch_get_tenant_size_grows(
env = neon_env_builder.init_start(initial_tenant_conf=tenant_config)
tenant_id = env.initial_tenant
branch_name, timeline_id = env.neon_cli.list_timelines(tenant_id)[0]
timeline_id = env.initial_timeline
branch_name = "main"
http_client = env.pageserver.http_client()
@@ -516,7 +512,8 @@ def test_get_tenant_size_with_multiple_branches(
env.pageserver.allowed_errors.append(".*InternalServerError\\(No such file or directory.*")
tenant_id = env.initial_tenant
main_branch_name, main_timeline_id = env.neon_cli.list_timelines(tenant_id)[0]
main_timeline_id = env.initial_timeline
main_branch_name = "main"
http_client = env.pageserver.http_client()

View File

@@ -71,10 +71,9 @@ def test_tenants_many(neon_env_builder: NeonEnvBuilder):
"checkpoint_distance": "5000000",
}
)
env.neon_cli.create_timeline("test_tenants_many", tenant_id=tenant)
endpoint = env.endpoints.create_start(
"test_tenants_many",
"main",
tenant_id=tenant,
)
tenants_endpoints.append((tenant, endpoint))

View File

@@ -638,7 +638,7 @@ def test_delete_timeline_client_hangup(neon_env_builder: NeonEnvBuilder):
wait_until(50, 0.1, first_request_finished)
# check that the timeline is gone
wait_timeline_detail_404(ps_http, env.initial_tenant, child_timeline_id, iterations=2)
wait_timeline_detail_404(ps_http, env.initial_tenant, child_timeline_id, iterations=10)
def test_timeline_delete_works_for_remote_smoke(

View File

@@ -45,10 +45,7 @@ def test_gc_blocking_by_timeline(neon_env_builder: NeonEnvBuilder, sharded: bool
tenant_after = http.tenant_status(env.initial_tenant)
assert tenant_before != tenant_after
gc_blocking = tenant_after["gc_blocking"]
assert (
gc_blocking
== "BlockingReasons { tenant_blocked_by_lsn_lease_deadline: false, timelines: 1, reasons: EnumSet(Manual) }"
)
assert gc_blocking == "BlockingReasons { timelines: 1, reasons: EnumSet(Manual) }"
wait_for_another_gc_round()
pss.assert_log_contains(gc_skipped_line)

View File

@@ -26,8 +26,7 @@ def test_truncate(neon_env_builder: NeonEnvBuilder, zenbenchmark):
}
)
env.neon_cli.create_timeline("test_truncate", tenant_id=tenant)
endpoint = env.endpoints.create_start("test_truncate", tenant_id=tenant)
endpoint = env.endpoints.create_start("main", tenant_id=tenant)
cur = endpoint.connect().cursor()
cur.execute("create table t1(x integer)")
cur.execute(f"insert into t1 values (generate_series(1,{n_records}))")

View File

@@ -772,7 +772,7 @@ class ProposerPostgres(PgProtocol):
def initdb(self):
"""Run initdb"""
args = ["initdb", "-U", "cloud_admin", "-D", self.pg_data_dir_path()]
args = ["initdb", "--username", "cloud_admin", "--pgdata", self.pg_data_dir_path()]
self.pg_bin.run(args)
def start(self):
@@ -2084,8 +2084,13 @@ def test_timeline_copy(neon_env_builder: NeonEnvBuilder, insert_rows: int):
endpoint.safe_psql("create table t(key int, value text)")
timeline_status = env.safekeepers[0].http_client().timeline_status(tenant_id, timeline_id)
timeline_start_lsn = timeline_status.timeline_start_lsn
# Note: currently timelines on sks are created by compute and commit of
# transaction above is finished when 2/3 sks received it, so there is a
# small chance that timeline on this sk is not created/initialized yet,
# hence the usage of waiting function to prevent flakiness.
timeline_start_lsn = (
env.safekeepers[0].http_client().get_non_zero_timeline_start_lsn(tenant_id, timeline_id)
)
log.info(f"Timeline start LSN: {timeline_start_lsn}")
current_percent = 0.0

View File

@@ -31,7 +31,6 @@ camino = { version = "1", default-features = false, features = ["serde1"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
clap = { version = "4", features = ["derive", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] }
crossbeam-utils = { version = "0.8" }
crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] }
der = { version = "0.7", default-features = false, features = ["oid", "pem", "std"] }
deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] }
@@ -51,7 +50,7 @@ hex = { version = "0.4", features = ["serde"] }
hmac = { version = "0.12", default-features = false, features = ["reset"] }
hyper = { version = "0.14", features = ["full"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12", default-features = false, features = ["use_std"] }
itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" }
itertools-93f6ce9d446188ac = { package = "itertools", version = "0.10" }
lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] }
libc = { version = "0.2", features = ["extra_traits", "use_std"] }
@@ -68,8 +67,7 @@ rand = { version = "0.8", features = ["small_rng"] }
regex = { version = "1" }
regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }
regex-syntax = { version = "0.8" }
reqwest-5ef9efb8ec2df382 = { package = "reqwest", version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] }
reqwest-a6292c17cd707f01 = { package = "reqwest", version = "0.11", default-features = false, features = ["blocking", "rustls-tls", "stream"] }
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] }
rustls = { version = "0.21", features = ["dangerous_configuration"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
@@ -86,12 +84,9 @@ tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net",
tokio-rustls = { version = "0.24" }
tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
toml_edit = { version = "0.22", features = ["serde"] }
tonic = { version = "0.9", features = ["tls-roots"] }
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
tracing = { version = "0.1", features = ["log"] }
tracing-core = { version = "0.1" }
tracing-log = { version = "0.1", default-features = false, features = ["log-tracer", "std"] }
tracing-subscriber = { version = "0.3", default-features = false, features = ["env-filter", "fmt", "json", "smallvec", "tracing-log"] }
url = { version = "2", features = ["serde"] }
uuid = { version = "1", features = ["serde", "v4", "v7"] }
zeroize = { version = "1", features = ["derive", "serde"] }
@@ -110,7 +105,7 @@ getrandom = { version = "0.2", default-features = false, features = ["std"] }
half = { version = "2", default-features = false, features = ["num-traits"] }
hashbrown = { version = "0.14", features = ["raw"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12", default-features = false, features = ["use_std"] }
itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" }
itertools-93f6ce9d446188ac = { package = "itertools", version = "0.10" }
lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] }
libc = { version = "0.2", features = ["extra_traits", "use_std"] }