Compare commits


43 Commits

Author SHA1 Message Date
Anna Khanova
e39f0a2347 Change default timeout 2024-02-10 11:49:45 +01:00
Anna Khanova
8bf1aacb24 Create timeout on proxy<->cplane communication 2024-02-10 00:45:36 +01:00
Christian Schwarz
5779c7908a revert two recent heavier_once_cell changes (#6704)
This PR reverts

- https://github.com/neondatabase/neon/pull/6589
- https://github.com/neondatabase/neon/pull/6652

because there's a performance regression that's particularly visible at
high layer counts.

Most likely it's because the switch to RwLock inflates the 

```
    inner: heavier_once_cell::OnceCell<ResidentOrWantedEvicted>,
```

size from 48 to 88 bytes, which by itself almost doubles the cache
footprint; the fact that it's now larger than a cache line probably
doesn't help either.
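For illustration, a minimal sketch of guarding against this kind of size regression; `LayerInner` here is a hypothetical stand-in, not the real pageserver struct, and the exact sizes depend on the real field layout:

```rust
// Hypothetical sketch: a quick size guard against the kind of regression described
// above, where a hot struct grows past a 64-byte cache line.
const CACHE_LINE: usize = 64;

#[allow(dead_code)]
struct LayerInner {
    // stand-in for `heavier_once_cell::OnceCell<ResidentOrWantedEvicted>` plus other fields
    inner: [u8; 48],
}

fn main() {
    let sz = std::mem::size_of::<LayerInner>();
    println!("LayerInner is {sz} bytes");
    assert!(
        sz <= CACHE_LINE,
        "LayerInner ({sz} bytes) no longer fits in a single cache line"
    );
}
```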

See this chat on the Neon discord for more context:

https://discord.com/channels/1176467419317940276/1204714372295958548/1205541184634617906

I'm reverting 6652 as well because it might also have perf implications,
and we're getting close to the next release. We should re-do its changes
after the next release, though.

cc @koivunej 
cc @ivaxer
2024-02-09 22:22:40 +00:00
Sasha Krassovsky
1a4dd58b70 Grant pg_monitor to neon_superuser (#6691)
## Problem
The people want pg_monitor
https://github.com/neondatabase/neon/issues/6682
## Summary of changes
Gives the people pg_monitor
2024-02-09 20:22:53 +00:00
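For reference, a hedged sketch of the grant expressed through the `postgres` client that compute_ctl uses elsewhere in this compare; the actual change ships as a compute migration, and the exact statement appears in the compute_tools diff further down this page:

```rust
// Sketch only: the real change is a compute migration, not an ad-hoc client call.
use postgres::{Client, NoTls};

fn grant_pg_monitor(connstr: &str) -> Result<(), postgres::Error> {
    let mut client = Client::connect(connstr, NoTls)?;
    // Same statement that the migration in this compare adds:
    client.simple_query("GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION")?;
    Ok(())
}
```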
Conrad Ludgate
cbd3a32d4d proxy: decode username and password (#6700)
## Problem

Usernames and passwords can be URL ('percent') encoded in the connection
string provided by the serverless driver.

## Summary of changes

Decode the parameters when getting conn info
2024-02-09 19:22:23 +00:00
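A small sketch of the decoding step using the `urlencoding` crate that this compare adds to the workspace; the helper name and surrounding plumbing are illustrative, not the proxy's actual API:

```rust
use urlencoding::decode;

// Percent-decode a single connection-string parameter, e.g. a username or password.
fn decode_param(raw: &str) -> Result<String, std::string::FromUtf8Error> {
    decode(raw).map(|cow| cow.into_owned())
}

fn main() {
    assert_eq!(decode_param("my%40user").unwrap(), "my@user");
    assert_eq!(decode_param("p%40ss%3Aword").unwrap(), "p@ss:word");
    println!("decoded ok");
}
```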
Christian Schwarz
ca818c8bd7 fix(test_ondemand_download_timetravel): occasionally fails with slightly higher physical size (#6687) 2024-02-09 20:09:37 +01:00
Arseny Sher
1bb9abebf2 Remove WAL segments from s3 in batches.
Do list-delete operations in batches instead of doing full list first, to ensure
deletion makes progress even if there are a lot of files to remove.

To this end, add max_keys limit to remote storage list_files.
2024-02-09 22:11:53 +04:00
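A minimal sketch of the list-delete batching pattern described above; the `Storage` trait and its method names are stand-ins, not the real `RemoteStorage` interface:

```rust
// Stand-in trait: list at most `max_keys` objects under a prefix, and delete a set of keys.
trait Storage {
    fn list_files(&self, prefix: &str, max_keys: usize) -> Vec<String>;
    fn delete_objects(&mut self, keys: &[String]);
}

fn delete_in_batches(storage: &mut dyn Storage, prefix: &str, max_keys: usize) {
    loop {
        let batch = storage.list_files(prefix, max_keys);
        if batch.is_empty() {
            break; // nothing left under this prefix
        }
        // Delete before listing again, so progress is made even when the prefix
        // holds far more segments than a single listing can return.
        storage.delete_objects(&batch);
    }
}
```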
Conrad Ludgate
96d89cde51 Proxy error reworking (#6453)
## Problem

Taking my ideas from https://github.com/neondatabase/neon/pull/6283 and
making a bit less radical changes, in smaller commits.

We currently don't report error classifications in proxy, as the current
error handling makes it hard to do so.

## Summary of changes

1. Add a `ReportableError` trait that all errors will implement. This
provides the error classification functionality.
2. Handling client requests now returns a strongly typed error
    * this error is a `ReportableError` and is logged appropriately
3. The handle-client error only has a few possible error types, to
account for the fact that at this point errors should be returned to the
user.
2024-02-09 15:50:51 +00:00
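A minimal sketch of the error-classification idea: the `ReportableError` name matches the PR above, but the `ErrorKind` variants and the example error type here are assumptions, not the proxy's real definitions:

```rust
use std::fmt;

// Assumed classification buckets, for illustration only.
#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
enum ErrorKind {
    User,         // the client did something wrong (bad password, bad request)
    Service,      // proxy-internal failure
    ControlPlane, // upstream control-plane failure
}

trait ReportableError: std::error::Error {
    fn get_error_kind(&self) -> ErrorKind;
}

#[derive(Debug)]
enum HandshakeError {
    Auth,
    Cplane(String),
}

impl fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HandshakeError::Auth => write!(f, "authentication failed"),
            HandshakeError::Cplane(msg) => write!(f, "control plane request failed: {msg}"),
        }
    }
}

impl std::error::Error for HandshakeError {}

impl ReportableError for HandshakeError {
    fn get_error_kind(&self) -> ErrorKind {
        // Each concrete error maps to exactly one classification, so metrics can
        // distinguish user mistakes from service or control-plane failures.
        match self {
            HandshakeError::Auth => ErrorKind::User,
            HandshakeError::Cplane(_) => ErrorKind::ControlPlane,
        }
    }
}
```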
John Spray
89a5c654bf control_plane: follow up for embedded migrations (#6647)
## Problem

In https://github.com/neondatabase/neon/pull/6637, we remove the need to
run migrations externally, but for compat tests to work we can't remove
those invocations from the neon_local binary.

Once that previous PR merges, we can make the followup changes without
upsetting compat tests.
2024-02-09 14:26:50 +00:00
Heikki Linnakangas
5239cdc29f Fix test_vm_bit_clear_on_heap_lock test
The test was supposed to reproduce the bug fixed in commit 66fa176cc8,
i.e. that the clearing of the VM bit was not replayed in the
pageserver on HEAP_LOCK records. But it was broken in many ways and
failed to reproduce the original problem if you reverted the fix:

- The comparison of XIDs was broken. The test read the XID into a
  variable in Python, but it was treated as a string rather than an
  integer. As a result, e.g. "999" > "1000".

- The test accessed the locked tuple too early, in the loop. Accessing
  it early, before the pg_xact page had been removed, set the hint bits.
  That masked the problem on subsequent accesses.

- The on-demand SLRU download that was introduced in commit 9a9d9beaee
  hid the issue. Even though an SLRU segment was removed by Postgres,
  when it later tried to access it, it could still download it from
  the pageserver. To ensure that doesn't happen, shorten the GC period
  and compact and GC aggressively in the test.

I also added a more direct check that the VM page is updated, using
the get_page_at_lsn() debugging function. Right after locking the row,
we now fetch the VM page from pageserver and directly compare it with
the VM page in the page cache. They should match. That assertion is
more robust to things like on-demand SLRU download that could mask the
bug.
2024-02-09 15:56:41 +02:00
Heikki Linnakangas
84a0e7b022 tests: Allow setting shutdown mode separately from 'destroy' flag
In neon_local, the default mode is now always 'fast', regardless of
'destroy'. You can override it with the "neon_local endpoint stop
--mode=immediate" flag.

In python tests, we still default to 'immediate' mode when using the
stop_and_destroy() function, and 'fast' with plain stop(). I kept that
to avoid changing behavior in existing tests. I don't think existing
tests depend on it, but I wasn't 100% certain.
2024-02-09 15:56:41 +02:00
John Spray
8d98981fe5 tests: deflake test_sharding_split_unsharded (#6699)
## Problem

This test was a subset of the larger sharding test, and it missed the
validate() call on the workload that implicitly waited for the tenant to
become active before trying to split it. It could therefore fail to
split because the tenant was not yet active.

## Summary of changes

- Insert .validate() call, and move the Workload setup to after the
check of shard ID (as the shard ID check should pass immediately)
2024-02-09 13:20:04 +00:00
Joonas Koivunen
eb919cab88 prepare to move timeouts and cancellation handling to remote_storage (#6696)
This PR is preliminary cleanups and refactoring around `remote_storage`
for next PR which will move the timeouts and cancellation into
`remote_storage`.

Summary:
- smaller drive-by fixes
- code simplification
- refactor common parts like `DownloadError::is_permanent`
- align the error types of `RemoteStorage::list_*` to make more use of the
`download_retry` helper

Cc: #6096
2024-02-09 12:52:58 +00:00
Anastasia Lubennikova
eec1e1a192 Pre-install anon extension from compute_ctl
if anon is in shared_preload_libraries.
Users cannot install it themselves, because superuser is required.

GRANT all privileges needed to use it to db_owner.

We use the neon fork of the extension, because a small change to the SQL
file is needed to allow db_owner to use it.

This feature is behind a feature flag AnonExtension,
so it is not enabled by default.
2024-02-09 12:32:07 +00:00
Conrad Ludgate
ea089dc977 proxy: add per query array mode flag (#6678)
## Problem

Drizzle needs to be able to configure the array_mode flag per query.

## Summary of changes

Adds an array_mode flag to the query data JSON; when absent, it
defaults to the header flag.
2024-02-09 10:29:20 +00:00
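A sketch of how a per-query override can fall back to the header flag, using serde; the struct and field names here are illustrative, not the proxy's exact payload definition:

```rust
use serde::Deserialize;

// Request body with an optional per-query override.
#[derive(Deserialize)]
#[allow(dead_code)]
struct QueryData {
    query: String,
    #[serde(default)]
    params: Vec<serde_json::Value>,
    // `None` means "fall back to the header-level flag".
    #[serde(default)]
    array_mode: Option<bool>,
}

fn effective_array_mode(query: &QueryData, header_flag: bool) -> bool {
    query.array_mode.unwrap_or(header_flag)
}
```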
John Spray
951c9bf4ca control_plane: fix shard splitting on unsharded tenant (#6689)
## Problem

The previous test started with a new-style TenantShardId with a non-zero
ShardCount. We also need to handle the case of a zero ShardCount (aka
`unsharded`) parent shard.

**A followup PR will refactor ShardCount to make its inner value private
and thereby make this kind of mistake harder**

## Summary of changes

- Fix a place where we were incorrectly treating a ShardCount as a number of
shards, rather than as a thing that can be either zero or the number of shards.
- Add a test for this case.
2024-02-09 10:12:40 +00:00
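A sketch of the distinction described above: a ShardCount of zero means "unsharded", which still corresponds to exactly one physical shard. The newtype here is a stand-in for the real pageserver_api type:

```rust
#[derive(Clone, Copy)]
struct ShardCount(u8);

// 0 (unsharded) and 1 both mean a single shard on disk.
fn literal_shard_count(count: ShardCount) -> u8 {
    std::cmp::max(1, count.0)
}

fn main() {
    assert_eq!(literal_shard_count(ShardCount(0)), 1); // unsharded
    assert_eq!(literal_shard_count(ShardCount(4)), 4);
}
```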
Heikki Linnakangas
568f91420a tests: try to make restored-datadir comparison tests not flaky (#6666)
This test occasionally fails with a difference in the "pg_xact/0000" file
between the local and restored datadirs. My hypothesis is that something
changed in the database between the last explicit checkpoint and the
shutdown. I suspect autovacuum; it could certainly create transactions.

To fix, be more precise about the point in time that we compare. Shut
down the endpoint first, then read the last LSN (i.e. the shutdown
checkpoint's LSN), from the local disk with pg_controldata. And use
exactly that LSN in the basebackup.

Closes #559.

I'm proposing this as an alternative to
https://github.com/neondatabase/neon/pull/6662.
2024-02-09 11:34:15 +02:00
Joonas Koivunen
a18aa14754 test: shutdown endpoints before deletion (#6619)
This sometimes avoids a page_service error in the log. Keeping the
endpoint running while deleting serves no purpose for this test.
2024-02-09 09:01:07 +00:00
Konstantin Knizhnik
529a79d263 Increment generation when LFC is disabled by assigning 0 to neon.file_cache_size_limit (#6692)
## Problem

test_lfc_resize sometimes failed with an assertion failure when acquiring
the lock in a write operation:

```
	if (lfc_ctl->generation == generation)
	{
		Assert(LFC_ENABLED());
```

## Summary of changes

Increment generation when 0 is assigned to neon.file_cache_size_limit

## Checklist before requesting a review

- [ ] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Do not forget to reformat commit message to not include the above
checklist

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2024-02-09 08:14:41 +02:00
Joonas Koivunen
c09993396e fix: secondary tenant relative order eviction (#6491)
Calculate the `relative_last_activity` using the total evicted and
resident layers similar to what we originally planned.

Cc: #5331
2024-02-09 00:37:57 +02:00
Joonas Koivunen
9a31311990 fix(heavier_once_cell): assertion failure can be hit (#6652)
@problame noticed that the `tokio::sync::AcquireError` branch assertion
can be hit, as demonstrated in the first commit. We haven't seen this yet in
production, but I'd prefer not to see it there. `take_and_deinit` is being
used there, but this race must be quite timing-sensitive.
2024-02-08 22:40:14 +02:00
Arpad Müller
c0e0fc8151 Update Rust to 1.76.0 (#6683)
[Release notes](https://github.com/rust-lang/rust/releases/tag/1.76.0).
2024-02-08 19:57:02 +01:00
John Spray
e8d2843df6 storage controller: improved handling of node availability on restart (#6658)
- Automatically set a node's availability to Active if it is responsive
in startup_reconcile
- Impose a 5s timeout of HTTP request to list location conf, so that an
unresponsive node can't hang it for minutes
- Do several retries if the request fails with a retryable error, to be
tolerant of concurrent pageserver & storage controller restarts
- Add a readiness hook for use with k8s so that we can tell when the
startup reconciliation is done and the service is fully ready to do work.
- Add /metrics to the list of unauthenticated endpoints (this is
unrelated, but we're touching the line in this PR already, and it fixes
auth error spam in the deployed container).
- A test for the above.

Closes: #6670
2024-02-08 18:00:53 +00:00
John Spray
af91a28936 pageserver: shard splitting (#6379)
## Problem

One doesn't know at tenant creation time how large the tenant will grow.
We need to be able to dynamically adjust the shard count at runtime.
This is implemented as "splitting" of shards into smaller child shards,
which cover a subset of the keyspace that the parent covered.

Refer to RFC: https://github.com/neondatabase/neon/pull/6358

Part of epic: #6278

## Summary of changes

This PR implements the happy path (does not cleanly recover from a crash
mid-split, although won't lose any data), without any optimizations
(e.g. child shards re-download their own copies of layers that the
parent shard already had on local disk)

- Add `/v1/tenant/:tenant_shard_id/shard_split` API to pageserver: this
copies the shard's index to the child shards' paths, instantiates child
`Tenant` object, and tears down parent `Tenant` object.
- Add a `splitting` column to the `tenant_shards` table. This is written into
an existing migration because we haven't deployed yet, so we don't need a
clean upgrade path.
- Add `/control/v1/tenant/:tenant_id/shard_split` API to
attachment_service,
- Add `test_sharding_split_smoke` test. This covers the happy path:
future PRs will add tests that exercise failure cases.
2024-02-08 15:35:13 +00:00
Konstantin Knizhnik
43eae17f0d Drop unused replication slots (#6655)
## Problem

See #6626

If there is an inactive replication slot, then Postgres will not be able to
shrink the WAL and delete unused snapshots.
If some other active subscription is present, then the snapshots created every
15 seconds will overflow AUX_DIR.

Setting `max_slot_wal_keep_size` doesn't solve the problem, because even a
small WAL segment will be enough to overflow AUX_DIR if there is no
other activity on the system.

## Summary of changes

If there are active subscriptions and some logical replication slots have
not been used during the `neon.logical_replication_max_time_lag` interval,
then the unused slots are dropped.

## Checklist before requesting a review

- [ ] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Do not forget to reformat commit message to not include the above
checklist

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2024-02-08 17:31:15 +02:00
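A hedged sketch of the cleanup described above, expressed as plain SQL run through the `postgres` client; the real logic lives inside the neon extension, and the `neon.logical_replication_max_time_lag` threshold check is omitted here, so this only illustrates the shape of the operation:

```rust
use postgres::{Client, NoTls};

// Drop logical replication slots that currently have no active consumer.
fn drop_inactive_slots(connstr: &str) -> Result<(), postgres::Error> {
    let mut client = Client::connect(connstr, NoTls)?;
    for row in client.query(
        "SELECT slot_name FROM pg_replication_slots \
         WHERE slot_type = 'logical' AND NOT active",
        &[],
    )? {
        let slot: String = row.get(0);
        // Dropping the slot lets Postgres recycle WAL and unused snapshots again.
        client.execute("SELECT pg_drop_replication_slot($1)", &[&slot])?;
    }
    Ok(())
}
```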
Anna Khanova
6c34d4cd14 Proxy: set timeout on establishing connection (#6679)
## Problem

There is no timeout on the handshake.

## Summary of changes

Set a timeout on establishing the connection.
2024-02-08 13:52:04 +00:00
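A minimal sketch of the idea with tokio's `timeout`; the 10-second value and the plain TCP connect are placeholders, not the proxy's actual handshake path:

```rust
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;

// Bound the connection-establishment step so a stalled peer can't hang us forever.
async fn connect_with_timeout(addr: &str) -> anyhow::Result<TcpStream> {
    let stream = timeout(Duration::from_secs(10), TcpStream::connect(addr))
        .await
        .map_err(|_| anyhow::anyhow!("connect timed out"))??;
    Ok(stream)
}
```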
Anna Khanova
c63e3e7e84 Proxy: improve http-pool (#6577)
## Problem

The password check logic for sql-over-http is a bit non-intuitive.

## Summary of changes

1. Perform scram auth using the same logic as for websocket cleartext
password.
2. Split establish connection logic and connection pool.
3. Parallelize param parsing logic with authentication + wake compute.
4. Limit the total number of clients
2024-02-08 12:57:05 +01:00
Christian Schwarz
c52495774d tokio-epoll-uring: expose its metrics in pageserver's /metrics (#6672)
context: https://github.com/neondatabase/neon/issues/6667
2024-02-07 23:58:54 +00:00
Andreas Scherbaum
9a017778a9 Update copyright notice, set it to current year (#6671)
## Problem

Copyright notice is outdated

## Summary of changes

Replace the initial year `2022` with `2022 - 2024`, after brief
discussion with Stas about the format

Co-authored-by: Andreas Scherbaum <andreas@neon.tech>
2024-02-08 00:48:31 +01:00
Christian Schwarz
c561ad4e2e feat: expose locked memory in pageserver /metrics (#6669)
context: https://github.com/neondatabase/neon/issues/6667
2024-02-07 19:39:52 +00:00
John Spray
3bd2a4fd56 control_plane: avoid feedback loop with /location_config if compute hook fails. (#6668)
## Problem

The existing behavior isn't exactly incorrect, but is operationally
risky: if the control plane compute hook breaks, then all the control
plane operations trying to call /location_config will end up retrying
forever, which could put more load on the system.

## Summary of changes

- Treat 404s as fatal errors to do fewer retries: a 404 either indicates
we have the wrong URL, or some control plane bug is failing to recognize
our tenant ID as existing.
- Do not return an error on reconciliation errors in a non-creating
/location_config response: this allows the control plane to finish its
Operation (and we will eventually retry the compute notification later)
2024-02-07 19:14:18 +00:00
Tristan Partin
128fae7054 Update Postgres 16 to 16.2 2024-02-07 11:10:48 -08:00
Tristan Partin
5541244dc4 Update Postgres 15 to 15.6 2024-02-07 11:10:48 -08:00
Tristan Partin
2e9b1f7aaf Update Postgres 14 to 14.11 2024-02-07 11:10:48 -08:00
Christian Schwarz
51f9385b1b live-reconfigurable virtual_file::IoEngine (#6552)
This PR adds an API to live-reconfigure the VirtualFile io engine.

It also adds a flag to `pagebench get-page-latest-lsn`, which is where I
found this functionality to be useful: it helps compare the io engines
in a benchmark without re-compiling a release build, which took ~50s on
the i3en.3xlarge where I was doing the benchmark.

Switching the IO engine is completely safe at runtime.
2024-02-07 17:47:55 +00:00
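An illustrative sketch of live engine switching via an atomic, assuming a simple two-variant selector; the real `virtual_file::IoEngine` plumbing in the pageserver is more involved than this:

```rust
use std::sync::atomic::{AtomicU8, Ordering};

// Stand-in for the configurable engine kinds.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
enum IoEngineKind {
    StdFs = 0,
    TokioEpollUring = 1,
}

// Process-wide current engine; reads are cheap, so I/O paths can consult it per call.
static CURRENT: AtomicU8 = AtomicU8::new(IoEngineKind::StdFs as u8);

fn set_engine(kind: IoEngineKind) {
    CURRENT.store(kind as u8, Ordering::Relaxed);
}

fn current_engine() -> IoEngineKind {
    match CURRENT.load(Ordering::Relaxed) {
        0 => IoEngineKind::StdFs,
        _ => IoEngineKind::TokioEpollUring,
    }
}

fn main() {
    println!("default: {:?}", current_engine());
    set_engine(IoEngineKind::TokioEpollUring); // e.g. triggered by a management API call
    println!("after reconfigure: {:?}", current_engine());
}
```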
Sasha Krassovsky
7b49e5e5c3 Remove compute migrations feature flag (#6653) 2024-02-07 07:55:55 -09:00
Abhijeet Patil
75f1a01d4a Optimise e2e run (#6513)
## Problem
We have a finite number of runners, and intermediate results are often
wanted before a PR is ready for merging. Currently all PRs get e2e tests
run, which creates a lot of throwaway e2e results that may or may not
get to start or complete before a new push.

## Summary of changes

1. Skip e2e tests when the PR is in draft mode
2. Run e2e tests when the PR status changes from draft to ready for review
(change this to having its trigger in the PR below, and update the results
of build and test)
3. Abstract the e2e tests into a separate workflow and call it from the main
workflow
4. Add a label; if that label (run-e2e-tests-in-draft) is present, run e2e
tests on draft PRs
5. Make the approved-for-ci-run workflow automatically add the
run-e2e-tests-in-draft label, so e2e tests run in draft for external
contributors' PRs
6. Document the new label changes and the above behaviour

Draft PR  : https://github.com/neondatabase/neon/actions/runs/7729128470
Ready To Review :
https://github.com/neondatabase/neon/actions/runs/7733779916
Draft PR with label :
https://github.com/neondatabase/neon/actions/runs/7725691012/job/21062432342
and https://github.com/neondatabase/neon/actions/runs/7733854028

## Checklist before requesting a review

- [x] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Do not forget to reformat commit message to not include the above
checklist

---------

Co-authored-by: Alexander Bayandin <alexander@neon.tech>
2024-02-07 16:14:10 +00:00
John Spray
090a789408 storage controller: use PUT instead of POST (#6659)
This was a typo, the server expects PUT.
2024-02-07 13:24:10 +00:00
John Spray
3d4fe205ba control_plane/attachment_service: database connection pool (#6622)
## Problem

This is mainly to limit our concurrency, rather than to speed up
requests (I was doing some sanity checks on performance of the service
with thousands of shards)

## Summary of changes

- Enable the diesel `r2d2` feature, which provides a connection
pool
- Acquire a connection before entering spawn_blocking for a database
transaction (recall that diesel's interface is sync)
- Set a connection pool size of 99 to fit within the default Postgres limit
(100)
- Also set the tokio blocking thread count to accommodate the same number
of blocking tasks (the only thing we use spawn_blocking for is database
calls).
2024-02-07 13:08:09 +00:00
Arpad Müller
f7516df6c1 Pass timestamp as a datetime (#6656)
This saves some repetition. I did this in #6533 for
`tenant_time_travel_remote_storage` already.
2024-02-07 12:56:53 +01:00
Konstantin Knizhnik
f3d7d23805 Some small WAL records can write a lot of data to KV storage, so perform checkpoint check more frequently (#6639)
## Problem

See
https://neondb.slack.com/archives/C04DGM6SMTM/p1707149618314539?thread_ts=1707081520.140049&cid=C04DGM6SMTM

## Summary of changes


Perform the checkpoint check after processing every `ingest_batch_size`
(default 100) WAL records.

## Checklist before requesting a review

- [ ] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Do not forget to reformat commit message to not include the above
checklist

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2024-02-07 08:47:19 +02:00
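A sketch of the idea in the commit above: count processed WAL records and run the checkpoint check every `ingest_batch_size` records instead of only at batch boundaries. `WalRecord` and `maybe_checkpoint` are placeholders, not pageserver types:

```rust
struct WalRecord;

fn maybe_checkpoint() {
    // decide whether enough data has accumulated to flush an in-memory layer
}

fn ingest(records: impl IntoIterator<Item = WalRecord>, ingest_batch_size: usize) {
    let mut since_last_check = 0usize;
    for _record in records {
        // ... apply the record to the KV storage ...
        since_last_check += 1;
        if since_last_check >= ingest_batch_size {
            // A single small WAL record can expand into a lot of KV writes,
            // so check more often than once per ingest batch.
            maybe_checkpoint();
            since_last_check = 0;
        }
    }
}
```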
Alexander Bayandin
9f75da7c0a test_lazy_startup: fix statement_timeout setting (#6654)
## Problem
Test `test_lazy_startup` is flaky [0]; sometimes (pretty frequently) it
fails with `canceling statement due to statement timeout`.

- [0]
https://neon-github-public-dev.s3.amazonaws.com/reports/main/7803316870/index.html#suites/355b1a7a5b1e740b23ea53728913b4fa/7263782d30986c50/history

## Summary of changes
- Fix the `statement_timeout` setting by reusing a connection for
all queries.
- Also fix the label (`lazy`, `eager`) assignment.
- Split `test_lazy_startup` into two by `slru` laziness, and make the tests smaller.
2024-02-07 00:31:26 +00:00
Alexander Bayandin
f4cc7cae14 CI(build-tools): Update Python from 3.9.2 to 3.9.18 (#6615)
## Problem

We use an outdated version of Python (3.9.2)

## Summary of changes
- Update Python to the latest patch version (3.9.18)
- Unify the usage of python caches where possible
2024-02-06 20:30:43 +00:00
127 changed files with 4311 additions and 1558 deletions

View File

@@ -179,6 +179,12 @@ runs:
aws s3 rm "s3://${BUCKET}/${LOCK_FILE}"
fi
- name: Cache poetry deps
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v2-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
- name: Store Allure test stat in the DB (new)
if: ${{ !cancelled() && inputs.store-test-results-into-db == 'true' }}
shell: bash -euxo pipefail {0}

View File

@@ -86,11 +86,10 @@ runs:
fetch-depth: 1
- name: Cache poetry deps
id: cache_poetry
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v1-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
key: v2-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
shell: bash -euxo pipefail {0}

View File

@@ -93,6 +93,7 @@ jobs:
--body-file "body.md" \
--head "${BRANCH}" \
--base "main" \
--label "run-e2e-tests-in-draft" \
--draft
fi

View File

@@ -22,7 +22,7 @@ env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
# A concurrency group that we use for e2e-tests runs, matches `concurrency.group` above with `github.repository` as a prefix
E2E_CONCURRENCY_GROUP: ${{ github.repository }}-${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
E2E_CONCURRENCY_GROUP: ${{ github.repository }}-e2e-tests-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
jobs:
check-permissions:
@@ -112,11 +112,10 @@ jobs:
fetch-depth: 1
- name: Cache poetry deps
id: cache_poetry
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v1-codestyle-python-deps-${{ hashFiles('poetry.lock') }}
key: v2-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
run: ./scripts/pysync
@@ -693,50 +692,10 @@ jobs:
})
trigger-e2e-tests:
if: ${{ !github.event.pull_request.draft || contains( github.event.pull_request.labels.*.name, 'run-e2e-tests-in-draft') || github.ref_name == 'main' || github.ref_name == 'release' }}
needs: [ check-permissions, promote-images, tag ]
runs-on: [ self-hosted, gen3, small ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
options: --init
steps:
- name: Set PR's status to pending and request a remote CI test
run: |
# For pull requests, GH Actions set "github.sha" variable to point at a fake merge commit
# but we need to use a real sha of a latest commit in the PR's branch for the e2e job,
# to place a job run status update later.
COMMIT_SHA=${{ github.event.pull_request.head.sha }}
# For non-PR kinds of runs, the above will produce an empty variable, pick the original sha value for those
COMMIT_SHA=${COMMIT_SHA:-${{ github.sha }}}
REMOTE_REPO="${{ github.repository_owner }}/cloud"
curl -f -X POST \
https://api.github.com/repos/${{ github.repository }}/statuses/$COMMIT_SHA \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"state\": \"pending\",
\"context\": \"neon-cloud-e2e\",
\"description\": \"[$REMOTE_REPO] Remote CI job is about to start\"
}"
curl -f -X POST \
https://api.github.com/repos/$REMOTE_REPO/actions/workflows/testing.yml/dispatches \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"ref\": \"main\",
\"inputs\": {
\"ci_job_name\": \"neon-cloud-e2e\",
\"commit_hash\": \"$COMMIT_SHA\",
\"remote_repo\": \"${{ github.repository }}\",
\"storage_image_tag\": \"${{ needs.tag.outputs.build-tag }}\",
\"compute_image_tag\": \"${{ needs.tag.outputs.build-tag }}\",
\"concurrency_group\": \"${{ env.E2E_CONCURRENCY_GROUP }}\"
}
}"
uses: ./.github/workflows/trigger-e2e-tests.yml
secrets: inherit
neon-image:
needs: [ check-permissions, build-buildtools-image, tag ]

View File

@@ -38,11 +38,10 @@ jobs:
uses: snok/install-poetry@v1
- name: Cache poetry deps
id: cache_poetry
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v1-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
key: v2-${{ runner.os }}-python-deps-ubunutu-latest-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
shell: bash -euxo pipefail {0}

.github/workflows/trigger-e2e-tests.yml vendored Normal file
View File

@@ -0,0 +1,118 @@
name: Trigger E2E Tests
on:
pull_request:
types:
- ready_for_review
workflow_call:
defaults:
run:
shell: bash -euxo pipefail {0}
env:
# A concurrency group that we use for e2e-tests runs, matches `concurrency.group` above with `github.repository` as a prefix
E2E_CONCURRENCY_GROUP: ${{ github.repository }}-e2e-tests-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
jobs:
cancel-previous-e2e-tests:
if: github.event_name == 'pull_request'
runs-on: ubuntu-latest
steps:
- name: Cancel previous e2e-tests runs for this PR
env:
GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }}
run: |
gh workflow --repo neondatabase/cloud \
run cancel-previous-in-concurrency-group.yml \
--field concurrency_group="${{ env.E2E_CONCURRENCY_GROUP }}"
tag:
runs-on: [ ubuntu-latest ]
outputs:
build-tag: ${{ steps.build-tag.outputs.tag }}
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Get build tag
env:
GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }}
CURRENT_BRANCH: ${{ github.head_ref || github.ref_name }}
CURRENT_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
run: |
if [[ "$GITHUB_REF_NAME" == "main" ]]; then
echo "tag=$(git rev-list --count HEAD)" | tee -a $GITHUB_OUTPUT
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
echo "tag=release-$(git rev-list --count HEAD)" | tee -a $GITHUB_OUTPUT
else
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
BUILD_AND_TEST_RUN_ID=$(gh run list -b $CURRENT_BRANCH -c $CURRENT_SHA -w 'Build and Test' -L 1 --json databaseId --jq '.[].databaseId')
echo "tag=$BUILD_AND_TEST_RUN_ID" | tee -a $GITHUB_OUTPUT
fi
id: build-tag
trigger-e2e-tests:
needs: [ tag ]
runs-on: [ self-hosted, gen3, small ]
env:
TAG: ${{ needs.tag.outputs.build-tag }}
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
options: --init
steps:
- name: check if ecr image are present
run: |
for REPO in neon compute-tools compute-node-v14 vm-compute-node-v14 compute-node-v15 vm-compute-node-v15 compute-node-v16 vm-compute-node-v16; do
OUTPUT=$(aws ecr describe-images --repository-name ${REPO} --region eu-central-1 --query "imageDetails[?imageTags[?contains(@, '${TAG}')]]" --output text)
if [ "$OUTPUT" == "" ]; then
echo "$REPO with image tag $TAG not found" >> $GITHUB_OUTPUT
exit 1
fi
done
- name: Set PR's status to pending and request a remote CI test
run: |
# For pull requests, GH Actions set "github.sha" variable to point at a fake merge commit
# but we need to use a real sha of a latest commit in the PR's branch for the e2e job,
# to place a job run status update later.
COMMIT_SHA=${{ github.event.pull_request.head.sha }}
# For non-PR kinds of runs, the above will produce an empty variable, pick the original sha value for those
COMMIT_SHA=${COMMIT_SHA:-${{ github.sha }}}
REMOTE_REPO="${{ github.repository_owner }}/cloud"
curl -f -X POST \
https://api.github.com/repos/${{ github.repository }}/statuses/$COMMIT_SHA \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"state\": \"pending\",
\"context\": \"neon-cloud-e2e\",
\"description\": \"[$REMOTE_REPO] Remote CI job is about to start\"
}"
curl -f -X POST \
https://api.github.com/repos/$REMOTE_REPO/actions/workflows/testing.yml/dispatches \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"ref\": \"main\",
\"inputs\": {
\"ci_job_name\": \"neon-cloud-e2e\",
\"commit_hash\": \"$COMMIT_SHA\",
\"remote_repo\": \"${{ github.repository }}\",
\"storage_image_tag\": \"${TAG}\",
\"compute_image_tag\": \"${TAG}\",
\"concurrency_group\": \"${{ env.E2E_CONCURRENCY_GROUP }}\"
}
}"

View File

@@ -54,6 +54,9 @@ _An instruction for maintainers_
- If and only if it looks **safe** (i.e. it doesn't contain any malicious code which could expose secrets or harm the CI), then:
- Press the "Approve and run" button in GitHub UI
- Add the `approved-for-ci-run` label to the PR
- Currently draft PR will skip e2e test (only for internal contributors). After turning the PR 'Ready to Review' CI will trigger e2e test
- Add `run-e2e-tests-in-draft` label to run e2e test in draft PR (override above behaviour)
- The `approved-for-ci-run` workflow will add `run-e2e-tests-in-draft` automatically to run e2e test for external contributors
Repeat all steps after any change to the PR.
- When the changes are ready to get merged — merge the original PR (not the internal one)

Cargo.lock generated
View File

@@ -289,6 +289,7 @@ dependencies = [
"pageserver_api",
"pageserver_client",
"postgres_connection",
"r2d2",
"reqwest",
"serde",
"serde_json",
@@ -1328,8 +1329,6 @@ dependencies = [
"clap",
"comfy-table",
"compute_api",
"diesel",
"diesel_migrations",
"futures",
"git-version",
"hex",
@@ -1651,6 +1650,7 @@ dependencies = [
"diesel_derives",
"itoa",
"pq-sys",
"r2d2",
"serde_json",
]
@@ -2867,6 +2867,7 @@ dependencies = [
"chrono",
"libc",
"once_cell",
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr",
@@ -3984,6 +3985,8 @@ checksum = "b1de8dacb0873f77e6aefc6d71e044761fcc68060290f5b1089fcdf84626bb69"
dependencies = [
"bitflags 1.3.2",
"byteorder",
"chrono",
"flate2",
"hex",
"lazy_static",
"rustix 0.36.16",
@@ -4074,6 +4077,7 @@ dependencies = [
"clap",
"consumption_metrics",
"dashmap",
"env_logger",
"futures",
"git-version",
"hashbrown 0.13.2",
@@ -4121,6 +4125,7 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"smallvec",
"smol_str",
"socket2 0.5.5",
"sync_wrapper",
@@ -4139,6 +4144,7 @@ dependencies = [
"tracing-subscriber",
"tracing-utils",
"url",
"urlencoding",
"utils",
"uuid",
"walkdir",
@@ -4166,6 +4172,17 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r2d2"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93"
dependencies = [
"log",
"parking_lot 0.12.1",
"scheduled-thread-pool",
]
[[package]]
name = "rand"
version = "0.7.3"
@@ -4879,6 +4896,15 @@ dependencies = [
"windows-sys 0.42.0",
]
[[package]]
name = "scheduled-thread-pool"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19"
dependencies = [
"parking_lot 0.12.1",
]
[[package]]
name = "scopeguard"
version = "1.1.0"
@@ -5714,7 +5740,7 @@ dependencies = [
[[package]]
name = "tokio-epoll-uring"
version = "0.1.0"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#0e1af4ccddf2f01805cfc9eaefa97ee13c04b52d"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
dependencies = [
"futures",
"nix 0.26.4",
@@ -6239,7 +6265,7 @@ dependencies = [
[[package]]
name = "uring-common"
version = "0.1.0"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#0e1af4ccddf2f01805cfc9eaefa97ee13c04b52d"
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
dependencies = [
"io-uring",
"libc",
@@ -6806,7 +6832,6 @@ dependencies = [
"clap",
"clap_builder",
"crossbeam-utils",
"diesel",
"either",
"fail",
"futures-channel",

View File

@@ -113,6 +113,7 @@ parquet = { version = "49.0.0", default-features = false, features = ["zstd"] }
parquet_derive = "49.0.0"
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pin-project-lite = "0.2"
procfs = "0.14"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"
rand = "0.8"
@@ -170,6 +171,7 @@ tracing-opentelemetry = "0.20.0"
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
twox-hash = { version = "1.6.3", default-features = false }
url = "2.2"
urlencoding = "2.1"
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
webpki-roots = "0.25"

View File

@@ -100,6 +100,11 @@ RUN mkdir -p /data/.neon/ && chown -R neon:neon /data/.neon/ \
-c "listen_pg_addr='0.0.0.0:6400'" \
-c "listen_http_addr='0.0.0.0:9898'"
# When running a binary that links with libpq, default to using our most recent postgres version. Binaries
# that want a particular postgres version will select it explicitly: this is just a default.
ENV LD_LIBRARY_PATH /usr/local/v16/lib
VOLUME ["/data"]
USER neon
EXPOSE 6400

View File

@@ -111,7 +111,7 @@ USER nonroot:nonroot
WORKDIR /home/nonroot
# Python
ENV PYTHON_VERSION=3.9.2 \
ENV PYTHON_VERSION=3.9.18 \
PYENV_ROOT=/home/nonroot/.pyenv \
PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH
RUN set -e \
@@ -135,7 +135,7 @@ WORKDIR /home/nonroot
# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
ENV RUSTC_VERSION=1.75.0
ENV RUSTC_VERSION=1.76.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \

View File

@@ -639,8 +639,8 @@ FROM build-deps AS pg-anon-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ENV PATH "/usr/local/pgsql/bin/:$PATH"
RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/1.1.0/postgresql_anonymizer-1.1.0.tar.gz -O pg_anon.tar.gz && \
echo "08b09d2ff9b962f96c60db7e6f8e79cf7253eb8772516998fc35ece08633d3ad pg_anon.tar.gz" | sha256sum --check && \
RUN wget https://github.com/neondatabase/postgresql_anonymizer/archive/refs/tags/neon_1.1.1.tar.gz -O pg_anon.tar.gz && \
echo "321ea8d5c1648880aafde850a2c576e4a9e7b9933a34ce272efc839328999fa9 pg_anon.tar.gz" | sha256sum --check && \
mkdir pg_anon-src && cd pg_anon-src && tar xvzf ../pg_anon.tar.gz --strip-components=1 -C . && \
find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt &&\
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -809,6 +809,7 @@ COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-semver-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
COPY --from=pg-anon-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \

NOTICE
View File

@@ -1,5 +1,5 @@
Neon
Copyright 2022 Neon Inc.
Copyright 2022 - 2024 Neon Inc.
The PostgreSQL submodules in vendor/ are licensed under the PostgreSQL license.
See vendor/postgres-vX/COPYRIGHT for details.

View File

@@ -765,7 +765,12 @@ impl ComputeNode {
handle_roles(spec, &mut client)?;
handle_databases(spec, &mut client)?;
handle_role_deletions(spec, connstr.as_str(), &mut client)?;
handle_grants(spec, &mut client, connstr.as_str())?;
handle_grants(
spec,
&mut client,
connstr.as_str(),
self.has_feature(ComputeFeature::AnonExtension),
)?;
handle_extensions(spec, &mut client)?;
handle_extension_neon(&mut client)?;
create_availability_check_data(&mut client)?;
@@ -773,12 +778,11 @@ impl ComputeNode {
// 'Close' connection
drop(client);
if self.has_feature(ComputeFeature::Migrations) {
thread::spawn(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
handle_migrations(&mut client)
});
}
// Run migrations separately to not hold up cold starts
thread::spawn(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
handle_migrations(&mut client)
});
Ok(())
}
@@ -840,7 +844,12 @@ impl ComputeNode {
handle_roles(&spec, &mut client)?;
handle_databases(&spec, &mut client)?;
handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?;
handle_grants(&spec, &mut client, self.connstr.as_str())?;
handle_grants(
&spec,
&mut client,
self.connstr.as_str(),
self.has_feature(ComputeFeature::AnonExtension),
)?;
handle_extensions(&spec, &mut client)?;
handle_extension_neon(&mut client)?;
// We can skip handle_migrations here because a new migration can only appear

View File

@@ -264,9 +264,10 @@ pub fn wait_for_postgres(pg: &mut Child, pgdata: &Path) -> Result<()> {
// case we miss some events for some reason. Not strictly necessary, but
// better safe than sorry.
let (tx, rx) = std::sync::mpsc::channel();
let (mut watcher, rx): (Box<dyn Watcher>, _) = match notify::recommended_watcher(move |res| {
let watcher_res = notify::recommended_watcher(move |res| {
let _ = tx.send(res);
}) {
});
let (mut watcher, rx): (Box<dyn Watcher>, _) = match watcher_res {
Ok(watcher) => (Box::new(watcher), rx),
Err(e) => {
match e.kind {

View File

@@ -581,7 +581,12 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
/// Grant CREATE ON DATABASE to the database owner and do some other alters and grants
/// to allow users creating trusted extensions and re-creating `public` schema, for example.
#[instrument(skip_all)]
pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> Result<()> {
pub fn handle_grants(
spec: &ComputeSpec,
client: &mut Client,
connstr: &str,
enable_anon_extension: bool,
) -> Result<()> {
info!("modifying database permissions");
let existing_dbs = get_existing_dbs(client)?;
@@ -678,6 +683,11 @@ pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) ->
inlinify(&grant_query)
);
db_client.simple_query(&grant_query)?;
// it is important to run this after all grants
if enable_anon_extension {
handle_extension_anon(spec, &db.owner, &mut db_client, false)?;
}
}
Ok(())
@@ -766,6 +776,7 @@ BEGIN
END IF;
END
$$;"#,
"GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION",
];
let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
@@ -809,5 +820,125 @@ $$;"#,
"Ran {} migrations",
(migrations.len() - starting_migration_id)
);
Ok(())
}
/// Connect to the database as superuser and pre-create anon extension
/// if it is present in shared_preload_libraries
#[instrument(skip_all)]
pub fn handle_extension_anon(
spec: &ComputeSpec,
db_owner: &str,
db_client: &mut Client,
grants_only: bool,
) -> Result<()> {
info!("handle extension anon");
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
if libs.contains("anon") {
if !grants_only {
// check if extension is already initialized using anon.is_initialized()
let query = "SELECT anon.is_initialized()";
match db_client.query(query, &[]) {
Ok(rows) => {
if !rows.is_empty() {
let is_initialized: bool = rows[0].get(0);
if is_initialized {
info!("anon extension is already initialized");
return Ok(());
}
}
}
Err(e) => {
warn!(
"anon extension is_installed check failed with expected error: {}",
e
);
}
};
// Create anon extension if this compute needs it
// Users cannot create it themselves, because superuser is required.
let mut query = "CREATE EXTENSION IF NOT EXISTS anon CASCADE";
info!("creating anon extension with query: {}", query);
match db_client.query(query, &[]) {
Ok(_) => {}
Err(e) => {
error!("anon extension creation failed with error: {}", e);
return Ok(());
}
}
// check that extension is installed
query = "SELECT extname FROM pg_extension WHERE extname = 'anon'";
let rows = db_client.query(query, &[])?;
if rows.is_empty() {
error!("anon extension is not installed");
return Ok(());
}
// Initialize anon extension
// This also requires superuser privileges, so users cannot do it themselves.
query = "SELECT anon.init()";
match db_client.query(query, &[]) {
Ok(_) => {}
Err(e) => {
error!("anon.init() failed with error: {}", e);
return Ok(());
}
}
}
// check that extension is installed, if not bail early
let query = "SELECT extname FROM pg_extension WHERE extname = 'anon'";
match db_client.query(query, &[]) {
Ok(rows) => {
if rows.is_empty() {
error!("anon extension is not installed");
return Ok(());
}
}
Err(e) => {
error!("anon extension check failed with error: {}", e);
return Ok(());
}
};
let query = format!("GRANT ALL ON SCHEMA anon TO {}", db_owner);
info!("granting anon extension permissions with query: {}", query);
db_client.simple_query(&query)?;
// Grant permissions to db_owner to use anon extension functions
let query = format!("GRANT ALL ON ALL FUNCTIONS IN SCHEMA anon TO {}", db_owner);
info!("granting anon extension permissions with query: {}", query);
db_client.simple_query(&query)?;
// This is needed, because some functions are defined as SECURITY DEFINER.
// In Postgres SECURITY DEFINER functions are executed with the privileges
// of the owner.
// In anon extension this it is needed to access some GUCs, which are only accessible to
// superuser. But we've patched postgres to allow db_owner to access them as well.
// So we need to change owner of these functions to db_owner.
let query = format!("
SELECT 'ALTER FUNCTION '||nsp.nspname||'.'||p.proname||'('||pg_get_function_identity_arguments(p.oid)||') OWNER TO {};'
from pg_proc p
join pg_namespace nsp ON p.pronamespace = nsp.oid
where nsp.nspname = 'anon';", db_owner);
info!("change anon extension functions owner to db owner");
db_client.simple_query(&query)?;
// affects views as well
let query = format!("GRANT ALL ON ALL TABLES IN SCHEMA anon TO {}", db_owner);
info!("granting anon extension permissions with query: {}", query);
db_client.simple_query(&query)?;
let query = format!("GRANT ALL ON ALL SEQUENCES IN SCHEMA anon TO {}", db_owner);
info!("granting anon extension permissions with query: {}", query);
db_client.simple_query(&query)?;
}
}
Ok(())
}

View File

@@ -10,8 +10,6 @@ async-trait.workspace = true
camino.workspace = true
clap.workspace = true
comfy-table.workspace = true
diesel = { version = "2.1.4", features = ["postgres"]}
diesel_migrations = { version = "2.1.0", features = ["postgres"]}
futures.workspace = true
git-version.workspace = true
nix.workspace = true

View File

@@ -24,8 +24,9 @@ tokio.workspace = true
tokio-util.workspace = true
tracing.workspace = true
diesel = { version = "2.1.4", features = ["serde_json", "postgres"] }
diesel = { version = "2.1.4", features = ["serde_json", "postgres", "r2d2"] }
diesel_migrations = { version = "2.1.0" }
r2d2 = { version = "0.8.10" }
utils = { path = "../../libs/utils/" }
metrics = { path = "../../libs/metrics/" }

View File

@@ -7,6 +7,7 @@ CREATE TABLE tenant_shards (
generation INTEGER NOT NULL,
generation_pageserver BIGINT NOT NULL,
placement_policy VARCHAR NOT NULL,
splitting SMALLINT NOT NULL,
-- config is JSON encoded, opaque to the database.
config TEXT NOT NULL
);

View File

@@ -170,7 +170,7 @@ impl ComputeHook {
reconfigure_request: &ComputeHookNotifyRequest,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
let req = client.request(Method::POST, url);
let req = client.request(Method::PUT, url);
let req = if let Some(value) = &self.authorization_header {
req.header(reqwest::header::AUTHORIZATION, value)
} else {
@@ -240,7 +240,7 @@ impl ComputeHook {
let client = reqwest::Client::new();
backoff::retry(
|| self.do_notify_iteration(&client, url, &reconfigure_request, cancel),
|e| matches!(e, NotifyError::Fatal(_)),
|e| matches!(e, NotifyError::Fatal(_) | NotifyError::Unexpected(_)),
3,
10,
"Send compute notification",

View File

@@ -3,7 +3,8 @@ use crate::service::{Service, STARTUP_RECONCILE_TIMEOUT};
use hyper::{Body, Request, Response};
use hyper::{StatusCode, Uri};
use pageserver_api::models::{
TenantCreateRequest, TenantLocationConfigRequest, TimelineCreateRequest,
TenantCreateRequest, TenantLocationConfigRequest, TenantShardSplitRequest,
TimelineCreateRequest,
};
use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api;
@@ -41,7 +42,7 @@ pub struct HttpState {
impl HttpState {
pub fn new(service: Arc<crate::service::Service>, auth: Option<Arc<SwappableJwtAuth>>) -> Self {
let allowlist_routes = ["/status"]
let allowlist_routes = ["/status", "/ready", "/metrics"]
.iter()
.map(|v| v.parse().unwrap())
.collect::<Vec<_>>();
@@ -292,6 +293,19 @@ async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>,
json_response(StatusCode::OK, state.service.node_configure(config_req)?)
}
async fn handle_tenant_shard_split(
service: Arc<Service>,
mut req: Request<Body>,
) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let split_req = json_request::<TenantShardSplitRequest>(&mut req).await?;
json_response(
StatusCode::OK,
service.tenant_shard_split(tenant_id, split_req).await?,
)
}
async fn handle_tenant_shard_migrate(
service: Arc<Service>,
mut req: Request<Body>,
@@ -311,6 +325,17 @@ async fn handle_status(_req: Request<Body>) -> Result<Response<Body>, ApiError>
json_response(StatusCode::OK, ())
}
/// Readiness endpoint indicates when we're done doing startup I/O (e.g. reconciling
/// with remote pageserver nodes). This is intended for use as a kubernetes readiness probe.
async fn handle_ready(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let state = get_state(&req);
if state.service.startup_complete.is_ready() {
json_response(StatusCode::OK, ())
} else {
json_response(StatusCode::SERVICE_UNAVAILABLE, ())
}
}
impl From<ReconcileError> for ApiError {
fn from(value: ReconcileError) -> Self {
ApiError::Conflict(format!("Reconciliation error: {}", value))
@@ -366,6 +391,7 @@ pub fn make_router(
.data(Arc::new(HttpState::new(service, auth)))
// Non-prefixed generic endpoints (status, metrics)
.get("/status", |r| request_span(r, handle_status))
.get("/ready", |r| request_span(r, handle_ready))
// Upcalls for the pageserver: point the pageserver's `control_plane_api` config to this prefix
.post("/upcall/v1/re-attach", |r| {
request_span(r, handle_re_attach)
@@ -391,6 +417,9 @@ pub fn make_router(
.put("/control/v1/tenant/:tenant_shard_id/migrate", |r| {
tenant_service_handler(r, handle_tenant_shard_migrate)
})
.put("/control/v1/tenant/:tenant_id/shard_split", |r| {
tenant_service_handler(r, handle_tenant_shard_split)
})
// Tenant operations
// The ^/v1/ endpoints act as a "Virtual Pageserver", enabling shard-naive clients to call into
// this service to manage tenants that actually consist of many tenant shards, as if they are a single entity.

View File

@@ -170,6 +170,7 @@ impl Secrets {
}
}
/// Execute the diesel migrations that are built into this binary
async fn migration_run(database_url: &str) -> anyhow::Result<()> {
use diesel::PgConnection;
use diesel_migrations::{HarnessWithOutput, MigrationHarness};
@@ -183,8 +184,18 @@ async fn migration_run(database_url: &str) -> anyhow::Result<()> {
Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
fn main() -> anyhow::Result<()> {
tokio::runtime::Builder::new_current_thread()
// We use spawn_blocking for database operations, so require approximately
// as many blocking threads as we will open database connections.
.max_blocking_threads(Persistence::MAX_CONNECTIONS as usize)
.enable_all()
.build()
.unwrap()
.block_on(async_main())
}
async fn async_main() -> anyhow::Result<()> {
let launch_ts = Box::leak(Box::new(LaunchTimestamp::generate()));
logging::init(

View File

@@ -1,6 +1,9 @@
pub(crate) mod split_state;
use std::collections::HashMap;
use std::str::FromStr;
use std::time::Duration;
use self::split_state::SplitState;
use camino::Utf8Path;
use camino::Utf8PathBuf;
use control_plane::attachment_service::{NodeAvailability, NodeSchedulingPolicy};
@@ -44,7 +47,7 @@ use crate::PlacementPolicy;
/// updated, and reads of nodes are always from memory, not the database. We only require that
/// we can UPDATE a node's scheduling mode reasonably quickly to mark a bad node offline.
pub struct Persistence {
database_url: String,
connection_pool: diesel::r2d2::Pool<diesel::r2d2::ConnectionManager<PgConnection>>,
// In test environments, we support loading+saving a JSON file. This is temporary, for the benefit of
// test_compatibility.py, so that we don't have to commit to making the database contents fully backward/forward
@@ -64,6 +67,8 @@ pub(crate) enum DatabaseError {
Query(#[from] diesel::result::Error),
#[error(transparent)]
Connection(#[from] diesel::result::ConnectionError),
#[error(transparent)]
ConnectionPool(#[from] r2d2::Error),
#[error("Logical error: {0}")]
Logical(String),
}
@@ -71,9 +76,31 @@ pub(crate) enum DatabaseError {
pub(crate) type DatabaseResult<T> = Result<T, DatabaseError>;
impl Persistence {
// The default postgres connection limit is 100. We use up to 99, to leave one free for a human admin under
// normal circumstances. This assumes we have exclusive use of the database cluster to which we connect.
pub const MAX_CONNECTIONS: u32 = 99;
// We don't want to keep a lot of connections alive: close them down promptly if they aren't being used.
const IDLE_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
const MAX_CONNECTION_LIFETIME: Duration = Duration::from_secs(60);
pub fn new(database_url: String, json_path: Option<Utf8PathBuf>) -> Self {
let manager = diesel::r2d2::ConnectionManager::<PgConnection>::new(database_url);
// We will use a connection pool: this is primarily to _limit_ our connection count, rather than to optimize time
// to execute queries (database queries are not generally on latency-sensitive paths).
let connection_pool = diesel::r2d2::Pool::builder()
.max_size(Self::MAX_CONNECTIONS)
.max_lifetime(Some(Self::MAX_CONNECTION_LIFETIME))
.idle_timeout(Some(Self::IDLE_CONNECTION_TIMEOUT))
// Always keep at least one connection ready to go
.min_idle(Some(1))
.test_on_check_out(true)
.build(manager)
.expect("Could not build connection pool");
Self {
database_url,
connection_pool,
json_path,
}
}
@@ -84,14 +111,10 @@ impl Persistence {
F: Fn(&mut PgConnection) -> DatabaseResult<R> + Send + 'static,
R: Send + 'static,
{
let database_url = self.database_url.clone();
tokio::task::spawn_blocking(move || -> DatabaseResult<R> {
// TODO: connection pooling, such as via diesel::r2d2
let mut conn = PgConnection::establish(&database_url)?;
func(&mut conn)
})
.await
.expect("Task panic")
let mut conn = self.connection_pool.get()?;
tokio::task::spawn_blocking(move || -> DatabaseResult<R> { func(&mut conn) })
.await
.expect("Task panic")
}
/// When a node is first registered, persist it before using it for anything
@@ -342,19 +365,107 @@ impl Persistence {
Ok(())
}
// TODO: when we start shard splitting, we must durably mark the tenant so that
// on restart, we know that we must go through recovery (list shards that exist
// and pick up where we left off and/or revert to parent shards).
// When we start shard splitting, we must durably mark the tenant so that
// on restart, we know that we must go through recovery.
//
// We create the child shards here, so that they will be available for increment_generation calls
// if some pageserver holding a child shard needs to restart before the overall tenant split is complete.
#[allow(dead_code)]
pub(crate) async fn begin_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> {
todo!();
pub(crate) async fn begin_shard_split(
&self,
old_shard_count: ShardCount,
split_tenant_id: TenantId,
parent_to_children: Vec<(TenantShardId, Vec<TenantShardPersistence>)>,
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> DatabaseResult<()> {
// Mark parent shards as splitting
let expect_parent_records = std::cmp::max(1, old_shard_count.0);
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.eq(old_shard_count.0 as i32))
.set((splitting.eq(1),))
.execute(conn)?;
if u8::try_from(updated)
.map_err(|_| DatabaseError::Logical(
format!("Overflow existing shard count {} while splitting", updated))
)? != expect_parent_records {
// Perhaps a deletion or another split raced with this attempt to split, mutating
// the parent shards that we intend to split. In this case the split request should fail.
return Err(DatabaseError::Logical(
format!("Unexpected existing shard count {updated} when preparing tenant for split (expected {expect_parent_records})")
));
}
// FIXME: spurious clone to sidestep closure move rules
let parent_to_children = parent_to_children.clone();
// Insert child shards
for (parent_shard_id, children) in parent_to_children {
let mut parent = crate::schema::tenant_shards::table
.filter(tenant_id.eq(parent_shard_id.tenant_id.to_string()))
.filter(shard_number.eq(parent_shard_id.shard_number.0 as i32))
.filter(shard_count.eq(parent_shard_id.shard_count.0 as i32))
.load::<TenantShardPersistence>(conn)?;
let parent = if parent.len() != 1 {
return Err(DatabaseError::Logical(format!(
"Parent shard {parent_shard_id} not found"
)));
} else {
parent.pop().unwrap()
};
for mut shard in children {
// Carry the parent's generation into the child
shard.generation = parent.generation;
debug_assert!(shard.splitting == SplitState::Splitting);
diesel::insert_into(tenant_shards)
.values(shard)
.execute(conn)?;
}
}
Ok(())
})?;
Ok(())
})
.await
}
// TODO: when we finish shard splitting, we must atomically clean up the old shards
// When we finish shard splitting, we must atomically clean up the old shards
// and insert the new shards, and clear the splitting marker.
#[allow(dead_code)]
pub(crate) async fn complete_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> {
todo!();
pub(crate) async fn complete_shard_split(
&self,
split_tenant_id: TenantId,
old_shard_count: ShardCount,
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> QueryResult<()> {
// Drop parent shards
diesel::delete(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.eq(old_shard_count.0 as i32))
.execute(conn)?;
// Clear sharding flag
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.set((splitting.eq(0),))
.execute(conn)?;
debug_assert!(updated > 0);
Ok(())
})?;
Ok(())
})
.await
}
}
@@ -382,6 +493,8 @@ pub(crate) struct TenantShardPersistence {
#[serde(default)]
pub(crate) placement_policy: String,
#[serde(default)]
pub(crate) splitting: SplitState,
#[serde(default)]
pub(crate) config: String,
}

View File

@@ -0,0 +1,46 @@
use diesel::pg::{Pg, PgValue};
use diesel::{
deserialize::FromSql, deserialize::FromSqlRow, expression::AsExpression, serialize::ToSql,
sql_types::Int2,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, FromSqlRow, AsExpression)]
#[diesel(sql_type = SplitStateSQLRepr)]
#[derive(Deserialize, Serialize)]
pub enum SplitState {
Idle = 0,
Splitting = 1,
}
impl Default for SplitState {
fn default() -> Self {
Self::Idle
}
}
type SplitStateSQLRepr = Int2;
impl ToSql<SplitStateSQLRepr, Pg> for SplitState {
fn to_sql<'a>(
&'a self,
out: &'a mut diesel::serialize::Output<Pg>,
) -> diesel::serialize::Result {
let raw_value: i16 = *self as i16;
let mut new_out = out.reborrow();
ToSql::<SplitStateSQLRepr, Pg>::to_sql(&raw_value, &mut new_out)
}
}
impl FromSql<SplitStateSQLRepr, Pg> for SplitState {
fn from_sql(pg_value: PgValue) -> diesel::deserialize::Result<Self> {
match FromSql::<SplitStateSQLRepr, Pg>::from_sql(pg_value).map(|v| match v {
0 => Some(Self::Idle),
1 => Some(Self::Splitting),
_ => None,
})? {
Some(v) => Ok(v),
None => Err(format!("Invalid SplitState value, was: {:?}", pg_value.as_bytes()).into()),
}
}
}
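A minimal standalone sketch (not part of the diff) of the i16 mapping that the `ToSql`/`FromSql` impls above rely on, assuming it sits in the same module as `SplitState`; any discriminant other than 0 or 1 is rejected at deserialization time:
```
#[test]
fn split_state_i16_round_trip() {
    // ToSql path: the C-like enum casts straight to its discriminant.
    let raw = SplitState::Splitting as i16;
    assert_eq!(raw, 1);
    // FromSql path: only known discriminants are accepted.
    let parsed = match raw {
        0 => Some(SplitState::Idle),
        1 => Some(SplitState::Splitting),
        _ => None,
    };
    assert_eq!(parsed, Some(SplitState::Splitting));
}
```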

View File

@@ -20,6 +20,7 @@ diesel::table! {
generation -> Int4,
generation_pageserver -> Int8,
placement_policy -> Varchar,
splitting -> Int2,
config -> Text,
}
}

View File

@@ -1,5 +1,6 @@
use std::{
collections::{BTreeMap, HashMap},
cmp::Ordering,
collections::{BTreeMap, HashMap, HashSet},
str::FromStr,
sync::Arc,
time::{Duration, Instant},
@@ -23,13 +24,14 @@ use pageserver_api::{
models::{
LocationConfig, LocationConfigMode, ShardParameters, TenantConfig, TenantCreateRequest,
TenantLocationConfigRequest, TenantLocationConfigResponse, TenantShardLocation,
TimelineCreateRequest, TimelineInfo,
TenantShardSplitRequest, TenantShardSplitResponse, TimelineCreateRequest, TimelineInfo,
},
shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize, TenantShardId},
};
use pageserver_client::mgmt_api;
use tokio_util::sync::CancellationToken;
use utils::{
backoff,
completion::Barrier,
generation::Generation,
http::error::ApiError,
@@ -40,7 +42,11 @@ use utils::{
use crate::{
compute_hook::{self, ComputeHook},
node::Node,
persistence::{DatabaseError, NodePersistence, Persistence, TenantShardPersistence},
persistence::{
split_state::SplitState, DatabaseError, NodePersistence, Persistence,
TenantShardPersistence,
},
reconciler::attached_location_conf,
scheduler::Scheduler,
tenant_state::{
IntentState, ObservedState, ObservedStateLocation, ReconcileResult, ReconcileWaitError,
@@ -103,7 +109,9 @@ impl From<DatabaseError> for ApiError {
match err {
DatabaseError::Query(e) => ApiError::InternalServerError(e.into()),
// FIXME: ApiError doesn't have an Unavailable variant, but ShuttingDown maps to 503.
DatabaseError::Connection(_e) => ApiError::ShuttingDown,
DatabaseError::Connection(_) | DatabaseError::ConnectionPool(_) => {
ApiError::ShuttingDown
}
DatabaseError::Logical(reason) => {
ApiError::InternalServerError(anyhow::anyhow!(reason))
}
@@ -143,31 +151,71 @@ impl Service {
// indeterminate, same as in [`ObservedStateLocation`])
let mut observed = HashMap::new();
let nodes = {
let locked = self.inner.read().unwrap();
locked.nodes.clone()
};
let mut nodes_online = HashSet::new();
// TODO: give Service a cancellation token for clean shutdown
let cancel = CancellationToken::new();
// TODO: issue these requests concurrently
for node in nodes.values() {
let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
{
let nodes = {
let locked = self.inner.read().unwrap();
locked.nodes.clone()
};
for node in nodes.values() {
let http_client = reqwest::ClientBuilder::new()
.timeout(Duration::from_secs(5))
.build()
.expect("Failed to construct HTTP client");
let client = mgmt_api::Client::from_client(
http_client,
node.base_url(),
self.config.jwt_token.as_deref(),
);
tracing::info!("Scanning shards on node {}...", node.id);
match client.list_location_config().await {
Err(e) => {
tracing::warn!("Could not contact pageserver {} ({e})", node.id);
// TODO: be more tolerant, apply a generous 5-10 second timeout with retries, in case
// pageserver is being restarted at the same time as we are
fn is_fatal(e: &mgmt_api::Error) -> bool {
use mgmt_api::Error::*;
match e {
ReceiveBody(_) | ReceiveErrorBody(_) => false,
ApiError(StatusCode::SERVICE_UNAVAILABLE, _)
| ApiError(StatusCode::GATEWAY_TIMEOUT, _)
| ApiError(StatusCode::REQUEST_TIMEOUT, _) => false,
ApiError(_, _) => true,
}
}
Ok(listing) => {
tracing::info!(
"Received {} shard statuses from pageserver {}, setting it to Active",
listing.tenant_shards.len(),
node.id
);
for (tenant_shard_id, conf_opt) in listing.tenant_shards {
observed.insert(tenant_shard_id, (node.id, conf_opt));
let list_response = backoff::retry(
|| client.list_location_config(),
is_fatal,
1,
5,
"Location config listing",
&cancel,
)
.await;
let Some(list_response) = list_response else {
tracing::info!("Shutdown during startup_reconcile");
return;
};
tracing::info!("Scanning shards on node {}...", node.id);
match list_response {
Err(e) => {
tracing::warn!("Could not contact pageserver {} ({e})", node.id);
// TODO: be more tolerant, do some retries, in case
// pageserver is being restarted at the same time as we are
}
Ok(listing) => {
tracing::info!(
"Received {} shard statuses from pageserver {}, setting it to Active",
listing.tenant_shards.len(),
node.id
);
nodes_online.insert(node.id);
for (tenant_shard_id, conf_opt) in listing.tenant_shards {
observed.insert(tenant_shard_id, (node.id, conf_opt));
}
}
}
}
@@ -178,8 +226,19 @@ impl Service {
let mut compute_notifications = Vec::new();
// Populate intent and observed states for all tenants, based on reported state on pageservers
let shard_count = {
let (shard_count, nodes) = {
let mut locked = self.inner.write().unwrap();
// Mark nodes online if they responded to us: nodes are offline by default after a restart.
let mut nodes = (*locked.nodes).clone();
for (node_id, node) in nodes.iter_mut() {
if nodes_online.contains(node_id) {
node.availability = NodeAvailability::Active;
}
}
locked.nodes = Arc::new(nodes);
let nodes = locked.nodes.clone();
for (tenant_shard_id, (node_id, observed_loc)) in observed {
let Some(tenant_state) = locked.tenants.get_mut(&tenant_shard_id) else {
cleanup.push((tenant_shard_id, node_id));
@@ -211,7 +270,7 @@ impl Service {
}
}
locked.tenants.len()
(locked.tenants.len(), nodes)
};
// TODO: if any tenant's intent now differs from its loaded generation_pageserver, we should clear that
@@ -272,9 +331,8 @@ impl Service {
let stream = futures::stream::iter(compute_notifications.into_iter())
.map(|(tenant_shard_id, node_id)| {
let compute_hook = compute_hook.clone();
let cancel = cancel.clone();
async move {
// TODO: give Service a cancellation token for clean shutdown
let cancel = CancellationToken::new();
if let Err(e) = compute_hook.notify(tenant_shard_id, node_id, &cancel).await {
tracing::error!(
tenant_shard_id=%tenant_shard_id,
@@ -380,7 +438,7 @@ impl Service {
))),
config,
persistence,
startup_complete,
startup_complete: startup_complete.clone(),
});
let result_task_this = this.clone();
@@ -474,6 +532,7 @@ impl Service {
generation_pageserver: i64::MAX,
placement_policy: serde_json::to_string(&PlacementPolicy::default()).unwrap(),
config: serde_json::to_string(&TenantConfig::default()).unwrap(),
splitting: SplitState::default(),
};
match self.persistence.insert_tenant_shards(vec![tsp]).await {
@@ -716,6 +775,7 @@ impl Service {
generation_pageserver: i64::MAX,
placement_policy: serde_json::to_string(&placement_policy).unwrap(),
config: serde_json::to_string(&create_req.config).unwrap(),
splitting: SplitState::default(),
})
.collect();
self.persistence
@@ -975,6 +1035,10 @@ impl Service {
}
};
// TODO: if we timeout/fail on reconcile, we should still succeed this request,
// because otherwise a broken compute hook causes a feedback loop where
// location_config returns 500 and gets retried forever.
if let Some(create_req) = maybe_create {
let create_resp = self.tenant_create(create_req).await?;
result.shards = create_resp
@@ -987,7 +1051,15 @@ impl Service {
.collect();
} else {
// This was an update, wait for reconciliation
self.await_waiters(waiters).await?;
if let Err(e) = self.await_waiters(waiters).await {
// Do not treat a reconcile error as fatal: we have already applied any requested
// Intent changes, and the reconcile can fail for external reasons like unavailable
// compute notification API. In these cases, it is important that we do not
// cause the cloud control plane to retry forever on this API.
tracing::warn!(
"Failed to reconcile after /location_config: {e}, returning success anyway"
);
}
}
Ok(result)
@@ -1090,6 +1162,7 @@ impl Service {
self.ensure_attached_wait(tenant_id).await?;
// TODO: refuse to do this if shard splitting is in progress
// (https://github.com/neondatabase/neon/issues/6676)
let targets = {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
@@ -1170,6 +1243,7 @@ impl Service {
self.ensure_attached_wait(tenant_id).await?;
// TODO: refuse to do this if shard splitting is in progress
// (https://github.com/neondatabase/neon/issues/6676)
let targets = {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
@@ -1342,6 +1416,326 @@ impl Service {
})
}
pub(crate) async fn tenant_shard_split(
&self,
tenant_id: TenantId,
split_req: TenantShardSplitRequest,
) -> Result<TenantShardSplitResponse, ApiError> {
let mut policy = None;
let mut shard_ident = None;
// TODO: put a cancellation token on Service for clean shutdown
let cancel = CancellationToken::new();
// A parent shard which will be split
struct SplitTarget {
parent_id: TenantShardId,
node: Node,
child_ids: Vec<TenantShardId>,
}
// Validate input, and calculate which shards we will create
let (old_shard_count, targets, compute_hook) = {
let locked = self.inner.read().unwrap();
let pageservers = locked.nodes.clone();
let mut targets = Vec::new();
// In case this is a retry, count how many already-split shards we found
let mut children_found = Vec::new();
let mut old_shard_count = None;
for (tenant_shard_id, shard) in
locked.tenants.range(TenantShardId::tenant_range(tenant_id))
{
match shard.shard.count.0.cmp(&split_req.new_shard_count) {
Ordering::Equal => {
// Already split this
children_found.push(*tenant_shard_id);
continue;
}
Ordering::Greater => {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"Requested count {} but already have shards at count {}",
split_req.new_shard_count,
shard.shard.count.0
)));
}
Ordering::Less => {
// Fall through: this shard has lower count than requested,
// is a candidate for splitting.
}
}
match old_shard_count {
None => old_shard_count = Some(shard.shard.count),
Some(old_shard_count) => {
if old_shard_count != shard.shard.count {
// We may hit this case if a caller asked for two splits to
// different sizes, before the first one is complete.
// e.g. 1->2, 2->4, where the 4 call comes while we have a mixture
// of shard_count=1 and shard_count=2 shards in the map.
return Err(ApiError::Conflict(
"Cannot split, currently mid-split".to_string(),
));
}
}
}
if policy.is_none() {
policy = Some(shard.policy.clone());
}
if shard_ident.is_none() {
shard_ident = Some(shard.shard);
}
if tenant_shard_id.shard_count == ShardCount(split_req.new_shard_count) {
tracing::info!(
"Tenant shard {} already has shard count {}",
tenant_shard_id,
split_req.new_shard_count
);
continue;
}
let node_id =
shard
.intent
.attached
.ok_or(ApiError::BadRequest(anyhow::anyhow!(
"Cannot split a tenant that is not attached"
)))?;
let node = pageservers
.get(&node_id)
.expect("Pageservers may not be deleted while referenced");
// TODO: if any reconciliation is currently in progress for this shard, wait for it.
targets.push(SplitTarget {
parent_id: *tenant_shard_id,
node: node.clone(),
child_ids: tenant_shard_id.split(ShardCount(split_req.new_shard_count)),
});
}
if targets.is_empty() {
if children_found.len() == split_req.new_shard_count as usize {
return Ok(TenantShardSplitResponse {
new_shards: children_found,
});
} else {
// No shards found to split, and no existing children found: the
// tenant doesn't exist at all.
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant {} not found", tenant_id).into(),
));
}
}
(old_shard_count, targets, locked.compute_hook.clone())
};
// unwrap safety: we would have returned above if we didn't find at least one shard to split
let old_shard_count = old_shard_count.unwrap();
let shard_ident = shard_ident.unwrap();
let policy = policy.unwrap();
// FIXME: we have dropped self.inner lock, and not yet written anything to the database: another
// request could occur here, deleting or mutating the tenant. begin_shard_split checks that the
// parent shards exist as expected, but it would be neater to do the above pre-checks within the
// same database transaction rather than pre-check in-memory and then maybe-fail the database write.
// (https://github.com/neondatabase/neon/issues/6676)
// Before creating any new child shards in memory or on the pageservers, persist them: this
// enables us to ensure that we will always be able to clean up if something goes wrong. This also
// acts as the protection against two concurrent attempts to split: one of them will get a database
// error trying to insert the child shards.
let mut child_tsps = Vec::new();
for target in &targets {
let mut this_child_tsps = Vec::new();
for child in &target.child_ids {
let mut child_shard = shard_ident;
child_shard.number = child.shard_number;
child_shard.count = child.shard_count;
this_child_tsps.push(TenantShardPersistence {
tenant_id: child.tenant_id.to_string(),
shard_number: child.shard_number.0 as i32,
shard_count: child.shard_count.0 as i32,
shard_stripe_size: shard_ident.stripe_size.0 as i32,
// Note: this generation is a placeholder, [`Persistence::begin_shard_split`] will
// populate the correct generation as part of its transaction, to protect us
// against racing with changes in the state of the parent.
generation: 0,
generation_pageserver: target.node.id.0 as i64,
placement_policy: serde_json::to_string(&policy).unwrap(),
// TODO: get the config out of the map
config: serde_json::to_string(&TenantConfig::default()).unwrap(),
splitting: SplitState::Splitting,
});
}
child_tsps.push((target.parent_id, this_child_tsps));
}
if let Err(e) = self
.persistence
.begin_shard_split(old_shard_count, tenant_id, child_tsps)
.await
{
match e {
DatabaseError::Query(diesel::result::Error::DatabaseError(
DatabaseErrorKind::UniqueViolation,
_,
)) => {
// Inserting a child shard violated a unique constraint: we raced with another call to
// this function
tracing::warn!("Conflicting attempt to split {tenant_id}: {e}");
return Err(ApiError::Conflict("Tenant is already splitting".into()));
}
_ => return Err(ApiError::InternalServerError(e.into())),
}
}
// FIXME: we have now committed the shard split state to the database, so any subsequent
// failure needs to roll it back. We will later wrap this function in logic to roll back
// the split if it fails.
// (https://github.com/neondatabase/neon/issues/6676)
// TODO: issue split calls concurrently (this only matters once we're splitting
// N>1 shards into M shards -- initially we're usually splitting 1 shard into N).
for target in &targets {
let SplitTarget {
parent_id,
node,
child_ids,
} = target;
let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
let response = client
.tenant_shard_split(
*parent_id,
TenantShardSplitRequest {
new_shard_count: split_req.new_shard_count,
},
)
.await
.map_err(|e| ApiError::Conflict(format!("Failed to split {}: {}", parent_id, e)))?;
tracing::info!(
"Split {} into {}",
parent_id,
response
.new_shards
.iter()
.map(|s| format!("{:?}", s))
.collect::<Vec<_>>()
.join(",")
);
if &response.new_shards != child_ids {
// This should never happen: the pageserver should agree with us on how shard splits work.
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"Splitting shard {} resulted in unexpected IDs: {:?} (expected {:?})",
parent_id,
response.new_shards,
child_ids
)));
}
}
// TODO: if the pageserver restarted concurrently with our split API call,
// the actual generation of the child shard might differ from the generation
// we expect it to have. In order for our in-database generation to end up
// correct, we should carry the child generation back in the response and apply it here
// in complete_shard_split (and apply the correct generation in memory)
// (or, we can carry generation in the request and reject the request if
// it doesn't match, but that requires more retry logic on this side)
self.persistence
.complete_shard_split(tenant_id, old_shard_count)
.await?;
// Replace all the shards we just split with their children
let mut response = TenantShardSplitResponse {
new_shards: Vec::new(),
};
let mut child_locations = Vec::new();
{
let mut locked = self.inner.write().unwrap();
for target in targets {
let SplitTarget {
parent_id,
node: _node,
child_ids,
} = target;
let (pageserver, generation, config) = {
let old_state = locked
.tenants
.remove(&parent_id)
.expect("It was present, we just split it");
(
old_state.intent.attached.unwrap(),
old_state.generation,
old_state.config.clone(),
)
};
locked.tenants.remove(&parent_id);
for child in child_ids {
let mut child_shard = shard_ident;
child_shard.number = child.shard_number;
child_shard.count = child.shard_count;
let mut child_observed: HashMap<NodeId, ObservedStateLocation> = HashMap::new();
child_observed.insert(
pageserver,
ObservedStateLocation {
conf: Some(attached_location_conf(generation, &child_shard, &config)),
},
);
let mut child_state = TenantState::new(child, child_shard, policy.clone());
child_state.intent = IntentState::single(Some(pageserver));
child_state.observed = ObservedState {
locations: child_observed,
};
child_state.generation = generation;
child_state.config = config.clone();
child_locations.push((child, pageserver));
locked.tenants.insert(child, child_state);
response.new_shards.push(child);
}
}
}
// Send compute notifications for all the new shards
let mut failed_notifications = Vec::new();
for (child_id, child_ps) in child_locations {
if let Err(e) = compute_hook.notify(child_id, child_ps, &cancel).await {
tracing::warn!("Failed to update compute of {}->{} during split, proceeding anyway to complete split ({e})",
child_id, child_ps);
failed_notifications.push(child_id);
}
}
// If we failed any compute notifications, make a note to retry later.
if !failed_notifications.is_empty() {
let mut locked = self.inner.write().unwrap();
for failed in failed_notifications {
if let Some(shard) = locked.tenants.get_mut(&failed) {
shard.pending_compute_notification = true;
}
}
}
Ok(response)
}
pub(crate) async fn tenant_shard_migrate(
&self,
tenant_shard_id: TenantShardId,

View File

@@ -193,6 +193,13 @@ impl IntentState {
result
}
pub(crate) fn single(node_id: Option<NodeId>) -> Self {
Self {
attached: node_id,
secondary: vec![],
}
}
/// When a node goes offline, we update intents to avoid using it
/// as their attached pageserver.
///
@@ -286,6 +293,9 @@ impl TenantState {
// self.intent refers to pageservers that are offline, and pick other
// pageservers if so.
// TODO: respect the splitting bit on tenants: if they are currently splitting then we may not
// change their attach location.
// Build the set of pageservers already in use by this tenant, to avoid scheduling
// more work on the same pageservers we're already using.
let mut used_pageservers = self.intent.all_pageservers();

View File

@@ -1,20 +1,17 @@
use crate::{background_process, local_env::LocalEnv};
use camino::{Utf8Path, Utf8PathBuf};
use diesel::{
backend::Backend,
query_builder::{AstPass, QueryFragment, QueryId},
Connection, PgConnection, QueryResult, RunQueryDsl,
};
use diesel_migrations::{HarnessWithOutput, MigrationHarness};
use hyper::Method;
use pageserver_api::{
models::{ShardParameters, TenantCreateRequest, TimelineCreateRequest, TimelineInfo},
models::{
ShardParameters, TenantCreateRequest, TenantShardSplitRequest, TenantShardSplitResponse,
TimelineCreateRequest, TimelineInfo,
},
shard::TenantShardId,
};
use pageserver_client::mgmt_api::ResponseErrorMessageExt;
use postgres_backend::AuthType;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{env, str::FromStr};
use std::str::FromStr;
use tokio::process::Command;
use tracing::instrument;
use url::Url;
@@ -270,37 +267,6 @@ impl AttachmentService {
.expect("non-Unicode path")
}
/// In order to access database migrations, we need to find the Neon source tree
async fn find_source_root(&self) -> anyhow::Result<Utf8PathBuf> {
// We assume that either our cwd or our binary is in the source tree. The former is usually
// true for automated test runners, the latter is usually true for developer workstations. Often
// both are true, which is fine.
let candidate_start_points = [
// Current working directory
Utf8PathBuf::from_path_buf(std::env::current_dir()?).unwrap(),
// Directory containing the binary we're running inside
Utf8PathBuf::from_path_buf(env::current_exe()?.parent().unwrap().to_owned()).unwrap(),
];
// For each candidate start point, search through ancestors looking for a neon.git source tree root
for start_point in &candidate_start_points {
// Start from the build dir: assumes we are running out of a built neon source tree
for path in start_point.ancestors() {
// A crude approximation: the root of the source tree is whatever contains a "control_plane"
// subdirectory.
let control_plane = path.join("control_plane");
if tokio::fs::try_exists(&control_plane).await? {
return Ok(path.to_owned());
}
}
}
// Fall-through
Err(anyhow::anyhow!(
"Could not find control_plane src dir, after searching ancestors of {candidate_start_points:?}"
))
}
/// Find the directory containing postgres binaries, such as `initdb` and `pg_ctl`
///
/// This usually uses ATTACHMENT_SERVICE_POSTGRES_VERSION of postgres, but will fall back
@@ -340,69 +306,32 @@ impl AttachmentService {
///
/// Returns the database url
pub async fn setup_database(&self) -> anyhow::Result<String> {
let database_url = format!(
"postgresql://localhost:{}/attachment_service",
self.postgres_port
);
println!("Running attachment service database setup...");
fn change_database_of_url(database_url: &str, default_database: &str) -> (String, String) {
let base = ::url::Url::parse(database_url).unwrap();
let database = base.path_segments().unwrap().last().unwrap().to_owned();
let mut new_url = base.join(default_database).unwrap();
new_url.set_query(base.query());
(database, new_url.into())
}
const DB_NAME: &str = "attachment_service";
let database_url = format!("postgresql://localhost:{}/{DB_NAME}", self.postgres_port);
#[derive(Debug, Clone)]
pub struct CreateDatabaseStatement {
db_name: String,
}
let pg_bin_dir = self.get_pg_bin_dir().await?;
let createdb_path = pg_bin_dir.join("createdb");
let output = Command::new(&createdb_path)
.args([
"-h",
"localhost",
"-p",
&format!("{}", self.postgres_port),
&DB_NAME,
])
.output()
.await
.expect("Failed to spawn createdb");
impl CreateDatabaseStatement {
pub fn new(db_name: &str) -> Self {
CreateDatabaseStatement {
db_name: db_name.to_owned(),
}
if !output.status.success() {
let stderr = String::from_utf8(output.stderr).expect("Non-UTF8 output from createdb");
if stderr.contains("already exists") {
tracing::info!("Database {DB_NAME} already exists");
} else {
anyhow::bail!("createdb failed with status {}: {stderr}", output.status);
}
}
impl<DB: Backend> QueryFragment<DB> for CreateDatabaseStatement {
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
out.push_sql("CREATE DATABASE ");
out.push_identifier(&self.db_name)?;
Ok(())
}
}
impl<Conn> RunQueryDsl<Conn> for CreateDatabaseStatement {}
impl QueryId for CreateDatabaseStatement {
type QueryId = ();
const HAS_STATIC_QUERY_ID: bool = false;
}
if PgConnection::establish(&database_url).is_err() {
let (database, postgres_url) = change_database_of_url(&database_url, "postgres");
println!("Creating database: {database}");
let mut conn = PgConnection::establish(&postgres_url)?;
CreateDatabaseStatement::new(&database).execute(&mut conn)?;
}
let mut conn = PgConnection::establish(&database_url)?;
let migrations_dir = self
.find_source_root()
.await?
.join("control_plane/attachment_service/migrations");
let migrations = diesel_migrations::FileBasedMigrations::from_path(migrations_dir)?;
println!("Running migrations in {}", migrations.path().display());
HarnessWithOutput::write_to_stdout(&mut conn)
.run_pending_migrations(migrations)
.map(|_| ())
.map_err(|e| anyhow::anyhow!(e))?;
println!("Migrations complete");
Ok(database_url)
}
@@ -648,7 +577,7 @@ impl AttachmentService {
) -> anyhow::Result<TenantShardMigrateResponse> {
self.dispatch(
Method::PUT,
format!("tenant/{tenant_shard_id}/migrate"),
format!("control/v1/tenant/{tenant_shard_id}/migrate"),
Some(TenantShardMigrateRequest {
tenant_shard_id,
node_id,
@@ -657,6 +586,20 @@ impl AttachmentService {
.await
}
#[instrument(skip(self), fields(%tenant_id, %new_shard_count))]
pub async fn tenant_split(
&self,
tenant_id: TenantId,
new_shard_count: u8,
) -> anyhow::Result<TenantShardSplitResponse> {
self.dispatch(
Method::PUT,
format!("control/v1/tenant/{tenant_id}/shard_split"),
Some(TenantShardSplitRequest { new_shard_count }),
)
.await
}
#[instrument(skip_all, fields(node_id=%req.node_id))]
pub async fn node_register(&self, req: NodeRegisterRequest) -> anyhow::Result<()> {
self.dispatch::<_, ()>(Method::POST, "control/v1/node".to_string(), Some(req))

View File

@@ -72,7 +72,6 @@ where
let log_path = datadir.join(format!("{process_name}.log"));
let process_log_file = fs::OpenOptions::new()
.create(true)
.write(true)
.append(true)
.open(&log_path)
.with_context(|| {

View File

@@ -575,6 +575,26 @@ async fn handle_tenant(
println!("{tenant_table}");
println!("{shard_table}");
}
Some(("shard-split", matches)) => {
let tenant_id = get_tenant_id(matches, env)?;
let shard_count: u8 = matches.get_one::<u8>("shard-count").cloned().unwrap_or(0);
let attachment_service = AttachmentService::from_env(env);
let result = attachment_service
.tenant_split(tenant_id, shard_count)
.await?;
println!(
"Split tenant {} into shards {}",
tenant_id,
result
.new_shards
.iter()
.map(|s| format!("{:?}", s))
.collect::<Vec<_>>()
.join(",")
);
}
Some((sub_name, _)) => bail!("Unexpected tenant subcommand '{}'", sub_name),
None => bail!("no tenant subcommand provided"),
}
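For illustration, with these additions a split would be driven from the CLI roughly as `neon_local tenant shard-split --tenant-id <tenant_id> --shard-count 4`; the exact spelling of the tenant-id flag comes from `tenant_id_arg`, which is defined outside this hunk, so treat that flag name as an assumption.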
@@ -994,12 +1014,13 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
.get_one::<String>("endpoint_id")
.ok_or_else(|| anyhow!("No endpoint ID was provided to stop"))?;
let destroy = sub_args.get_flag("destroy");
let mode = sub_args.get_one::<String>("mode").expect("has a default");
let endpoint = cplane
.endpoints
.get(endpoint_id.as_str())
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
endpoint.stop(destroy)?;
endpoint.stop(mode, destroy)?;
}
_ => bail!("Unexpected endpoint subcommand '{sub_name}'"),
@@ -1283,7 +1304,7 @@ async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) {
match ComputeControlPlane::load(env.clone()) {
Ok(cplane) => {
for (_k, node) in cplane.endpoints {
if let Err(e) = node.stop(false) {
if let Err(e) = node.stop(if immediate { "immediate" } else { "fast" }, false) {
eprintln!("postgres stop failed: {e:#}");
}
}
@@ -1524,6 +1545,11 @@ fn cli() -> Command {
.subcommand(Command::new("status")
.about("Human readable summary of the tenant's shards and attachment locations")
.arg(tenant_id_arg.clone()))
.subcommand(Command::new("shard-split")
.about("Increase the number of shards in the tenant")
.arg(tenant_id_arg.clone())
.arg(Arg::new("shard-count").value_parser(value_parser!(u8)).long("shard-count").action(ArgAction::Set).help("Number of shards in the new tenant (default 1)"))
)
)
.subcommand(
Command::new("pageserver")
@@ -1627,7 +1653,16 @@ fn cli() -> Command {
.long("destroy")
.action(ArgAction::SetTrue)
.required(false)
)
)
.arg(
Arg::new("mode")
.help("Postgres shutdown mode, passed to \"pg_ctl -m <mode>\"")
.long("mode")
.action(ArgAction::Set)
.required(false)
.value_parser(["smart", "fast", "immediate"])
.default_value("fast")
)
)
)

View File

@@ -761,22 +761,8 @@ impl Endpoint {
}
}
pub fn stop(&self, destroy: bool) -> Result<()> {
// If we are going to destroy data directory,
// use immediate shutdown mode, otherwise,
// shutdown gracefully to leave the data directory sane.
//
// Postgres is always started from scratch, so stop
// without destroy only used for testing and debugging.
//
self.pg_ctl(
if destroy {
&["-m", "immediate", "stop"]
} else {
&["stop"]
},
&None,
)?;
pub fn stop(&self, mode: &str, destroy: bool) -> Result<()> {
self.pg_ctl(&["-m", mode, "stop"], &None)?;
// Also wait for the compute_ctl process to die. It might have some
// cleanup work to do after postgres stops, like syncing safekeepers,

View File

@@ -90,8 +90,8 @@ pub enum ComputeFeature {
/// track short-lived connections as user activity.
ActivityMonitorExperimental,
/// Enable running migrations
Migrations,
/// Pre-install and initialize anon extension for every database in the cluster
AnonExtension,
/// This is a special feature flag that is used to represent unknown feature flags.
/// Basically all unknown to enum flags are represented as this one. See unit test

View File

@@ -13,6 +13,9 @@ twox-hash.workspace = true
workspace_hack.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
procfs.workspace = true
[dev-dependencies]
rand = "0.8"
rand_distr = "0.4.3"

View File

@@ -31,6 +31,8 @@ pub use wrappers::{CountedReader, CountedWriter};
mod hll;
pub mod metric_vec_duration;
pub use hll::{HyperLogLog, HyperLogLogVec};
#[cfg(target_os = "linux")]
pub mod more_process_metrics;
pub type UIntGauge = GenericGauge<AtomicU64>;
pub type UIntGaugeVec = GenericGaugeVec<AtomicU64>;

View File

@@ -0,0 +1,54 @@
//! Process metrics that the [`::prometheus`] crate doesn't provide.
// This module is heavily inspired by the prometheus crate's `process_collector.rs`.
use crate::UIntGauge;
pub struct Collector {
descs: Vec<prometheus::core::Desc>,
vmlck: crate::UIntGauge,
}
const NMETRICS: usize = 1;
impl prometheus::core::Collector for Collector {
fn desc(&self) -> Vec<&prometheus::core::Desc> {
self.descs.iter().collect()
}
fn collect(&self) -> Vec<prometheus::proto::MetricFamily> {
let Ok(myself) = procfs::process::Process::myself() else {
return vec![];
};
let mut mfs = Vec::with_capacity(NMETRICS);
if let Ok(status) = myself.status() {
if let Some(vmlck) = status.vmlck {
self.vmlck.set(vmlck);
mfs.extend(self.vmlck.collect())
}
}
mfs
}
}
impl Collector {
pub fn new() -> Self {
let mut descs = Vec::new();
let vmlck =
UIntGauge::new("libmetrics_process_status_vmlck", "/proc/self/status vmlck").unwrap();
descs.extend(
prometheus::core::Collector::desc(&vmlck)
.into_iter()
.cloned(),
);
Self { descs, vmlck }
}
}
impl Default for Collector {
fn default() -> Self {
Self::new()
}
}
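A hedged sketch of registering the new collector with the default registry from the `prometheus` crate; the `metrics::more_process_metrics` path assumes this file lands as the module declared in the lib.rs hunk above:
```
fn register_vmlck_collector() {
    let collector = metrics::more_process_metrics::Collector::new();
    prometheus::default_registry()
        .register(Box::new(collector))
        .expect("descriptors built in Collector::new() are valid");
}
```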

View File

@@ -192,6 +192,16 @@ pub struct TimelineCreateRequest {
pub pg_version: Option<u32>,
}
#[derive(Serialize, Deserialize)]
pub struct TenantShardSplitRequest {
pub new_shard_count: u8,
}
#[derive(Serialize, Deserialize)]
pub struct TenantShardSplitResponse {
pub new_shards: Vec<TenantShardId>,
}
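Since these are plain serde derives, the wire format is straightforward JSON; a small sketch, assuming `serde_json` is available and the request struct above is in scope:
```
fn main() {
    let req = TenantShardSplitRequest { new_shard_count: 4 };
    assert_eq!(
        serde_json::to_string(&req).unwrap(),
        r#"{"new_shard_count":4}"#
    );
}
```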
/// Parameters that apply to all shards in a tenant. Used during tenant creation.
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
@@ -649,6 +659,27 @@ pub struct WalRedoManagerStatus {
pub pid: Option<u32>,
}
pub mod virtual_file {
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Hash,
strum_macros::EnumString,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
Debug,
)]
#[strum(serialize_all = "kebab-case")]
pub enum IoEngineKind {
StdFs,
#[cfg(target_os = "linux")]
TokioEpollUring,
}
}
// Wrapped in libpq CopyData
#[derive(PartialEq, Eq, Debug)]
pub enum PagestreamFeMessage {

View File

@@ -88,12 +88,36 @@ impl TenantShardId {
pub fn is_unsharded(&self) -> bool {
self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
}
/// Convenience for dropping the tenant_id and just getting the ShardIndex: this
/// is useful when logging from code that is already in a span that includes tenant ID, to
/// keep messages reasonably terse.
pub fn to_index(&self) -> ShardIndex {
ShardIndex {
shard_number: self.shard_number,
shard_count: self.shard_count,
}
}
/// Calculate the children of this TenantShardId when splitting the overall tenant into
/// the given number of shards.
pub fn split(&self, new_shard_count: ShardCount) -> Vec<TenantShardId> {
let effective_old_shard_count = std::cmp::max(self.shard_count.0, 1);
let mut child_shards = Vec::new();
for shard_number in 0..ShardNumber(new_shard_count.0).0 {
// Key mapping is based on a round robin mapping of key hash modulo shard count,
// so our child shards are the ones which the same keys would map to.
if shard_number % effective_old_shard_count == self.shard_number.0 {
child_shards.push(TenantShardId {
tenant_id: self.tenant_id,
shard_number: ShardNumber(shard_number),
shard_count: new_shard_count,
})
}
}
child_shards
}
}
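As a worked example of the round-robin rule above: a parent with shard_number=1 and shard_count=2 splitting to count 8 receives exactly the children whose numbers are congruent to 1 mod 2, i.e. shards 1, 3, 5 and 7 of 8, which is what the `shard_id_split` test below asserts.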
/// Formatting helper
@@ -793,4 +817,108 @@ mod tests {
let shard = key_to_shard_number(ShardCount(10), DEFAULT_STRIPE_SIZE, &key);
assert_eq!(shard, ShardNumber(8));
}
#[test]
fn shard_id_split() {
let tenant_id = TenantId::generate();
let parent = TenantShardId::unsharded(tenant_id);
// Unsharded into 2
assert_eq!(
parent.split(ShardCount(2)),
vec![
TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(0)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(1)
}
]
);
// Unsharded into 4
assert_eq!(
parent.split(ShardCount(4)),
vec![
TenantShardId {
tenant_id,
shard_count: ShardCount(4),
shard_number: ShardNumber(0)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(4),
shard_number: ShardNumber(1)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(4),
shard_number: ShardNumber(2)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(4),
shard_number: ShardNumber(3)
}
]
);
// count=1 into 2 (check this works the same as unsharded.)
let parent = TenantShardId {
tenant_id,
shard_count: ShardCount(1),
shard_number: ShardNumber(0),
};
assert_eq!(
parent.split(ShardCount(2)),
vec![
TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(0)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(1)
}
]
);
// count=2 into count=8
let parent = TenantShardId {
tenant_id,
shard_count: ShardCount(2),
shard_number: ShardNumber(1),
};
assert_eq!(
parent.split(ShardCount(8)),
vec![
TenantShardId {
tenant_id,
shard_count: ShardCount(8),
shard_number: ShardNumber(1)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(8),
shard_number: ShardNumber(3)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(8),
shard_number: ShardNumber(5)
},
TenantShardId {
tenant_id,
shard_count: ShardCount(8),
shard_number: ShardNumber(7)
},
]
);
}
}

View File

@@ -191,6 +191,7 @@ impl RemoteStorage for AzureBlobStorage {
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> anyhow::Result<Listing, DownloadError> {
// get the passed prefix or if it is not set use prefix_in_bucket value
let list_prefix = prefix
@@ -223,6 +224,8 @@ impl RemoteStorage for AzureBlobStorage {
let mut response = builder.into_stream();
let mut res = Listing::default();
// NonZeroU32 doesn't support subtraction apparently
let mut max_keys = max_keys.map(|mk| mk.get());
while let Some(l) = response.next().await {
let entry = l.map_err(to_download_error)?;
let prefix_iter = entry
@@ -235,7 +238,18 @@ impl RemoteStorage for AzureBlobStorage {
.blobs
.blobs()
.map(|k| self.name_to_relative_path(&k.name));
res.keys.extend(blob_iter);
for key in blob_iter {
res.keys.push(key);
if let Some(mut mk) = max_keys {
assert!(mk > 0);
mk -= 1;
if mk == 0 {
return Ok(res); // limit reached
}
max_keys = Some(mk);
}
}
}
Ok(res)
}

View File

@@ -13,9 +13,15 @@ mod azure_blob;
mod local_fs;
mod s3_bucket;
mod simulate_failures;
mod support;
use std::{
collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc, time::SystemTime,
collections::HashMap,
fmt::Debug,
num::{NonZeroU32, NonZeroUsize},
pin::Pin,
sync::Arc,
time::SystemTime,
};
use anyhow::{bail, Context};
@@ -154,7 +160,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
prefix: Option<&RemotePath>,
) -> Result<Vec<RemotePath>, DownloadError> {
let result = self
.list(prefix, ListingMode::WithDelimiter)
.list(prefix, ListingMode::WithDelimiter, None)
.await?
.prefixes;
Ok(result)
@@ -170,8 +176,17 @@ pub trait RemoteStorage: Send + Sync + 'static {
/// whereas,
/// list_prefixes("foo/bar/") = ["cat", "dog"]
/// See `test_real_s3.rs` for more details.
async fn list_files(&self, prefix: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
let result = self.list(prefix, ListingMode::NoDelimiter).await?.keys;
///
/// `max_keys` limits the maximum number of keys returned; `None` means unlimited.
async fn list_files(
&self,
prefix: Option<&RemotePath>,
max_keys: Option<NonZeroU32>,
) -> Result<Vec<RemotePath>, DownloadError> {
let result = self
.list(prefix, ListingMode::NoDelimiter, max_keys)
.await?
.keys;
Ok(result)
}
@@ -179,7 +194,8 @@ pub trait RemoteStorage: Send + Sync + 'static {
&self,
prefix: Option<&RemotePath>,
_mode: ListingMode,
) -> anyhow::Result<Listing, DownloadError>;
max_keys: Option<NonZeroU32>,
) -> Result<Listing, DownloadError>;
/// Streams the local file contents into remote into the remote storage entry.
async fn upload(
@@ -269,6 +285,19 @@ impl std::fmt::Display for DownloadError {
impl std::error::Error for DownloadError {}
impl DownloadError {
/// Returns true if the error should not be retried with backoff
pub fn is_permanent(&self) -> bool {
use DownloadError::*;
match self {
BadInput(_) => true,
NotFound => true,
Cancelled => true,
Other(_) => false,
}
}
}
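A hedged sketch of how a caller might use `is_permanent()` to bound retries; `GenericRemoteStorage::download` is the existing API, while the retry policy, delays and the tokio dependency are illustrative assumptions:
```
use remote_storage::{Download, DownloadError, GenericRemoteStorage, RemotePath};

async fn download_with_retries(
    storage: &GenericRemoteStorage,
    path: &RemotePath,
) -> Result<Download, DownloadError> {
    let mut attempt = 0u32;
    loop {
        match storage.download(path).await {
            Ok(d) => return Ok(d),
            // Give up immediately on errors that backoff cannot fix, or after a few tries.
            Err(e) if e.is_permanent() || attempt >= 3 => return Err(e),
            Err(_) => {
                attempt += 1;
                tokio::time::sleep(std::time::Duration::from_millis(100 << attempt)).await;
            }
        }
    }
}
```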
#[derive(Debug)]
pub enum TimeTravelError {
/// Validation or other error happened due to user input.
@@ -324,24 +353,31 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> anyhow::Result<Listing, DownloadError> {
match self {
Self::LocalFs(s) => s.list(prefix, mode).await,
Self::AwsS3(s) => s.list(prefix, mode).await,
Self::AzureBlob(s) => s.list(prefix, mode).await,
Self::Unreliable(s) => s.list(prefix, mode).await,
Self::LocalFs(s) => s.list(prefix, mode, max_keys).await,
Self::AwsS3(s) => s.list(prefix, mode, max_keys).await,
Self::AzureBlob(s) => s.list(prefix, mode, max_keys).await,
Self::Unreliable(s) => s.list(prefix, mode, max_keys).await,
}
}
// A function for listing all the files in a "directory"
// Example:
// list_files("foo/bar") = ["foo/bar/a.txt", "foo/bar/b.txt"]
pub async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
//
// `max_keys` limits the maximum number of keys returned; `None` means unlimited.
pub async fn list_files(
&self,
folder: Option<&RemotePath>,
max_keys: Option<NonZeroU32>,
) -> Result<Vec<RemotePath>, DownloadError> {
match self {
Self::LocalFs(s) => s.list_files(folder).await,
Self::AwsS3(s) => s.list_files(folder).await,
Self::AzureBlob(s) => s.list_files(folder).await,
Self::Unreliable(s) => s.list_files(folder).await,
Self::LocalFs(s) => s.list_files(folder, max_keys).await,
Self::AwsS3(s) => s.list_files(folder, max_keys).await,
Self::AzureBlob(s) => s.list_files(folder, max_keys).await,
Self::Unreliable(s) => s.list_files(folder, max_keys).await,
}
}
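A hedged usage sketch of the new `max_keys` parameter on `list_files`; the type paths are assumptions about this crate's public exports, and passing `NonZeroU32::new(0)` (i.e. `None`) keeps the old unlimited behaviour:
```
use std::num::NonZeroU32;

use remote_storage::{DownloadError, GenericRemoteStorage, RemotePath};

/// List at most `limit` keys under `prefix`; a limit of 0 means "no limit".
async fn list_some(
    storage: &GenericRemoteStorage,
    prefix: &RemotePath,
    limit: u32,
) -> Result<Vec<RemotePath>, DownloadError> {
    storage
        .list_files(Some(prefix), NonZeroU32::new(limit))
        .await
}
```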

View File

@@ -4,7 +4,9 @@
//! This storage used in tests, but can also be used in cases when a certain persistent
//! volume is mounted to the local FS.
use std::{borrow::Cow, future::Future, io::ErrorKind, pin::Pin, time::SystemTime};
use std::{
borrow::Cow, future::Future, io::ErrorKind, num::NonZeroU32, pin::Pin, time::SystemTime,
};
use anyhow::{bail, ensure, Context};
use bytes::Bytes;
@@ -18,9 +20,7 @@ use tokio_util::{io::ReaderStream, sync::CancellationToken};
use tracing::*;
use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty};
use crate::{
Download, DownloadError, DownloadStream, Listing, ListingMode, RemotePath, TimeTravelError,
};
use crate::{Download, DownloadError, Listing, ListingMode, RemotePath, TimeTravelError};
use super::{RemoteStorage, StorageMetadata};
@@ -164,6 +164,7 @@ impl RemoteStorage for LocalFs {
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> Result<Listing, DownloadError> {
let mut result = Listing::default();
@@ -180,6 +181,9 @@ impl RemoteStorage for LocalFs {
!path.is_dir()
})
.collect();
if let Some(max_keys) = max_keys {
result.keys.truncate(max_keys.get() as usize);
}
return Ok(result);
}
@@ -365,27 +369,33 @@ impl RemoteStorage for LocalFs {
format!("Failed to open source file {target_path:?} to use in the download")
})
.map_err(DownloadError::Other)?;
let len = source
.metadata()
.await
.context("query file length")
.map_err(DownloadError::Other)?
.len();
source
.seek(io::SeekFrom::Start(start_inclusive))
.await
.context("Failed to seek to the range start in a local storage file")
.map_err(DownloadError::Other)?;
let metadata = self
.read_storage_metadata(&target_path)
.await
.map_err(DownloadError::Other)?;
let download_stream: DownloadStream = match end_exclusive {
Some(end_exclusive) => Box::pin(ReaderStream::new(
source.take(end_exclusive - start_inclusive),
)),
None => Box::pin(ReaderStream::new(source)),
};
let source = source.take(end_exclusive.unwrap_or(len) - start_inclusive);
let source = ReaderStream::new(source);
Ok(Download {
metadata,
last_modified: None,
etag: None,
download_stream,
download_stream: Box::pin(source),
})
} else {
Err(DownloadError::NotFound)
@@ -514,10 +524,8 @@ mod fs_tests {
use futures_util::Stream;
use std::{collections::HashMap, io::Write};
async fn read_and_assert_remote_file_contents(
async fn read_and_check_metadata(
storage: &LocalFs,
#[allow(clippy::ptr_arg)]
// have to use &Utf8PathBuf due to `storage.local_path` parameter requirements
remote_storage_path: &RemotePath,
expected_metadata: Option<&StorageMetadata>,
) -> anyhow::Result<String> {
@@ -596,7 +604,7 @@ mod fs_tests {
let upload_name = "upload_1";
let upload_target = upload_dummy_file(&storage, upload_name, None).await?;
let contents = read_and_assert_remote_file_contents(&storage, &upload_target, None).await?;
let contents = read_and_check_metadata(&storage, &upload_target, None).await?;
assert_eq!(
dummy_contents(upload_name),
contents,
@@ -618,7 +626,7 @@ mod fs_tests {
let upload_target = upload_dummy_file(&storage, upload_name, None).await?;
let full_range_download_contents =
read_and_assert_remote_file_contents(&storage, &upload_target, None).await?;
read_and_check_metadata(&storage, &upload_target, None).await?;
assert_eq!(
dummy_contents(upload_name),
full_range_download_contents,
@@ -660,6 +668,22 @@ mod fs_tests {
"Second part bytes should be returned when requested"
);
let suffix_bytes = storage
.download_byte_range(&upload_target, 13, None)
.await?
.download_stream;
let suffix_bytes = aggregate(suffix_bytes).await?;
let suffix = std::str::from_utf8(&suffix_bytes)?;
assert_eq!(upload_name, suffix);
let all_bytes = storage
.download_byte_range(&upload_target, 0, None)
.await?
.download_stream;
let all_bytes = aggregate(all_bytes).await?;
let all_bytes = std::str::from_utf8(&all_bytes)?;
assert_eq!(dummy_contents("upload_1"), all_bytes);
Ok(())
}
@@ -736,7 +760,7 @@ mod fs_tests {
upload_dummy_file(&storage, upload_name, Some(metadata.clone())).await?;
let full_range_download_contents =
read_and_assert_remote_file_contents(&storage, &upload_target, Some(&metadata)).await?;
read_and_check_metadata(&storage, &upload_target, Some(&metadata)).await?;
assert_eq!(
dummy_contents(upload_name),
full_range_download_contents,
@@ -772,12 +796,12 @@ mod fs_tests {
let child = upload_dummy_file(&storage, "grandparent/parent/child", None).await?;
let uncle = upload_dummy_file(&storage, "grandparent/uncle", None).await?;
let listing = storage.list(None, ListingMode::NoDelimiter).await?;
let listing = storage.list(None, ListingMode::NoDelimiter, None).await?;
assert!(listing.prefixes.is_empty());
assert_eq!(listing.keys, [uncle.clone(), child.clone()].to_vec());
// Delimiter: should only go one deep
let listing = storage.list(None, ListingMode::WithDelimiter).await?;
let listing = storage.list(None, ListingMode::WithDelimiter, None).await?;
assert_eq!(
listing.prefixes,
@@ -790,6 +814,7 @@ mod fs_tests {
.list(
Some(&RemotePath::from_string("timelines/some_timeline/grandparent").unwrap()),
ListingMode::WithDelimiter,
None,
)
.await?;
assert_eq!(

View File

@@ -7,6 +7,7 @@
use std::{
borrow::Cow,
collections::HashMap,
num::NonZeroU32,
pin::Pin,
sync::Arc,
task::{Context, Poll},
@@ -45,8 +46,9 @@ use utils::backoff;
use super::StorageMetadata;
use crate::{
ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath, RemoteStorage,
S3Config, TimeTravelError, MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR,
support::PermitCarrying, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode,
RemotePath, RemoteStorage, S3Config, TimeTravelError, MAX_KEYS_PER_DELETE,
REMOTE_STORAGE_PREFIX_SEPARATOR,
};
pub(super) mod metrics;
@@ -63,7 +65,6 @@ pub struct S3Bucket {
concurrency_limiter: ConcurrencyLimiter,
}
#[derive(Default)]
struct GetObjectRequest {
bucket: String,
key: String,
@@ -232,24 +233,8 @@ impl S3Bucket {
let started_at = ScopeGuard::into_inner(started_at);
match get_object {
Ok(object_output) => {
let metadata = object_output.metadata().cloned().map(StorageMetadata);
let etag = object_output.e_tag.clone();
let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok());
let body = object_output.body;
let body = ByteStreamAsStream::from(body);
let body = PermitCarrying::new(permit, body);
let body = TimedDownload::new(started_at, body);
Ok(Download {
metadata,
etag,
last_modified,
download_stream: Box::pin(body),
})
}
let object_output = match get_object {
Ok(object_output) => object_output,
Err(SdkError::ServiceError(e)) if matches!(e.err(), GetObjectError::NoSuchKey(_)) => {
// Count this in the AttemptOutcome::Ok bucket, because 404 is not
// an error: we expect to sometimes fetch an object and find it missing,
@@ -259,7 +244,7 @@ impl S3Bucket {
AttemptOutcome::Ok,
started_at,
);
Err(DownloadError::NotFound)
return Err(DownloadError::NotFound);
}
Err(e) => {
metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
@@ -268,11 +253,27 @@ impl S3Bucket {
started_at,
);
Err(DownloadError::Other(
return Err(DownloadError::Other(
anyhow::Error::new(e).context("download s3 object"),
))
));
}
}
};
let metadata = object_output.metadata().cloned().map(StorageMetadata);
let etag = object_output.e_tag;
let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok());
let body = object_output.body;
let body = ByteStreamAsStream::from(body);
let body = PermitCarrying::new(permit, body);
let body = TimedDownload::new(started_at, body);
Ok(Download {
metadata,
etag,
last_modified,
download_stream: Box::pin(body),
})
}
async fn delete_oids(
@@ -354,33 +355,6 @@ impl Stream for ByteStreamAsStream {
// sense and Stream::size_hint does not really
}
pin_project_lite::pin_project! {
/// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
struct PermitCarrying<S> {
permit: tokio::sync::OwnedSemaphorePermit,
#[pin]
inner: S,
}
}
impl<S> PermitCarrying<S> {
fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
Self { permit, inner }
}
}
impl<S: Stream<Item = std::io::Result<Bytes>>> Stream for PermitCarrying<S> {
type Item = <S as Stream>::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
pin_project_lite::pin_project! {
/// Times and tracks the outcome of the request.
struct TimedDownload<S> {
@@ -435,8 +409,11 @@ impl RemoteStorage for S3Bucket {
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> Result<Listing, DownloadError> {
let kind = RequestKind::List;
// s3 sdk wants i32
let mut max_keys = max_keys.map(|mk| mk.get() as i32);
let mut result = Listing::default();
// get the passed prefix or if it is not set use prefix_in_bucket value
@@ -460,13 +437,20 @@ impl RemoteStorage for S3Bucket {
let _guard = self.permit(kind).await;
let started_at = start_measuring_requests(kind);
// Min of two Options, returning the Some value when the other side is None
// (plain min doesn't work because None compares as smaller than any Some).
let request_max_keys = self
.max_keys_per_list_response
.into_iter()
.chain(max_keys.into_iter())
.min();
let mut request = self
.client
.list_objects_v2()
.bucket(self.bucket_name.clone())
.set_prefix(list_prefix.clone())
.set_continuation_token(continuation_token)
.set_max_keys(self.max_keys_per_list_response);
.set_max_keys(request_max_keys);
if let ListingMode::WithDelimiter = mode {
request = request.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
@@ -496,6 +480,14 @@ impl RemoteStorage for S3Bucket {
let object_path = object.key().expect("response does not contain a key");
let remote_path = self.s3_object_to_relative_path(object_path);
result.keys.push(remote_path);
if let Some(mut mk) = max_keys {
assert!(mk > 0);
mk -= 1;
if mk == 0 {
return Ok(result); // limit reached
}
max_keys = Some(mk);
}
}
result.prefixes.extend(
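The "min of two Options" comment above works by chaining both options into one iterator; a standalone illustration (not part of the diff):
```
fn min_opt(a: Option<i32>, b: Option<i32>) -> Option<i32> {
    // Plain a.min(b) would return None whenever either side is None,
    // because Option's Ord treats None as smaller than any Some.
    a.into_iter().chain(b).min()
}

fn main() {
    assert_eq!(min_opt(Some(1000), Some(3)), Some(3));
    assert_eq!(min_opt(None, Some(3)), Some(3));
    assert_eq!(min_opt(Some(3), None), Some(3));
    assert_eq!(min_opt(None, None), None);
}
```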

View File

@@ -4,6 +4,7 @@
use bytes::Bytes;
use futures::stream::Stream;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Mutex;
use std::time::SystemTime;
use std::{collections::hash_map::Entry, sync::Arc};
@@ -60,7 +61,7 @@ impl UnreliableWrapper {
/// On the first attempts of this operation, return an error. After 'attempts_to_fail'
/// attempts, let the operation go ahead, and clear the counter.
///
fn attempt(&self, op: RemoteOp) -> Result<u64, DownloadError> {
fn attempt(&self, op: RemoteOp) -> anyhow::Result<u64> {
let mut attempts = self.attempts.lock().unwrap();
match attempts.entry(op) {
@@ -78,13 +79,13 @@ impl UnreliableWrapper {
} else {
let error =
anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());
Err(DownloadError::Other(error))
Err(error)
}
}
Entry::Vacant(e) => {
let error = anyhow::anyhow!("simulated failure of remote operation {:?}", e.key());
e.insert(1);
Err(DownloadError::Other(error))
Err(error)
}
}
}
@@ -105,22 +106,30 @@ impl RemoteStorage for UnreliableWrapper {
&self,
prefix: Option<&RemotePath>,
) -> Result<Vec<RemotePath>, DownloadError> {
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?;
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))
.map_err(DownloadError::Other)?;
self.inner.list_prefixes(prefix).await
}
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
self.attempt(RemoteOp::ListPrefixes(folder.cloned()))?;
self.inner.list_files(folder).await
async fn list_files(
&self,
folder: Option<&RemotePath>,
max_keys: Option<NonZeroU32>,
) -> Result<Vec<RemotePath>, DownloadError> {
self.attempt(RemoteOp::ListPrefixes(folder.cloned()))
.map_err(DownloadError::Other)?;
self.inner.list_files(folder, max_keys).await
}
async fn list(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
) -> Result<Listing, DownloadError> {
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?;
self.inner.list(prefix, mode).await
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))
.map_err(DownloadError::Other)?;
self.inner.list(prefix, mode, max_keys).await
}
async fn upload(
@@ -137,7 +146,8 @@ impl RemoteStorage for UnreliableWrapper {
}
async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
self.attempt(RemoteOp::Download(from.clone()))?;
self.attempt(RemoteOp::Download(from.clone()))
.map_err(DownloadError::Other)?;
self.inner.download(from).await
}
@@ -150,7 +160,8 @@ impl RemoteStorage for UnreliableWrapper {
// Note: We treat any download_byte_range as an "attempt" of the same
// operation. We don't pay attention to the ranges. That's good enough
// for now.
self.attempt(RemoteOp::Download(from.clone()))?;
self.attempt(RemoteOp::Download(from.clone()))
.map_err(DownloadError::Other)?;
self.inner
.download_byte_range(from, start_inclusive, end_exclusive)
.await
@@ -193,7 +204,7 @@ impl RemoteStorage for UnreliableWrapper {
cancel: &CancellationToken,
) -> Result<(), TimeTravelError> {
self.attempt(RemoteOp::TimeTravelRecover(prefix.map(|p| p.to_owned())))
.map_err(|e| TimeTravelError::Other(anyhow::Error::new(e)))?;
.map_err(TimeTravelError::Other)?;
self.inner
.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await

View File

@@ -0,0 +1,33 @@
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures_util::Stream;
pin_project_lite::pin_project! {
/// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
pub(crate) struct PermitCarrying<S> {
permit: tokio::sync::OwnedSemaphorePermit,
#[pin]
inner: S,
}
}
impl<S> PermitCarrying<S> {
pub(crate) fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
Self { permit, inner }
}
}
impl<S: Stream> Stream for PermitCarrying<S> {
type Item = <S as Stream>::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}

View File

@@ -1,8 +1,8 @@
use anyhow::Context;
use camino::Utf8Path;
use remote_storage::RemotePath;
use std::collections::HashSet;
use std::sync::Arc;
use std::{collections::HashSet, num::NonZeroU32};
use test_context::test_context;
use tracing::debug;
@@ -103,7 +103,7 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a
let base_prefix =
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
let root_files = test_client
.list_files(None)
.list_files(None, None)
.await
.context("client list root files failure")?
.into_iter()
@@ -113,8 +113,17 @@ async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> a
ctx.remote_blobs.clone(),
"remote storage list_files on root mismatches with the uploads."
);
// Test that max_keys limit works. In total there are about 21 files (see
// upload_simple_remote_data call in test_real_s3.rs).
let limited_root_files = test_client
.list_files(None, Some(NonZeroU32::new(2).unwrap()))
.await
.context("client list root files failure")?;
assert_eq!(limited_root_files.len(), 2);
let nested_remote_files = test_client
.list_files(Some(&base_prefix))
.list_files(Some(&base_prefix), None)
.await
.context("client list nested files failure")?
.into_iter()

View File

@@ -70,7 +70,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
}
async fn list_files(client: &Arc<GenericRemoteStorage>) -> anyhow::Result<HashSet<RemotePath>> {
Ok(retry(|| client.list_files(None))
Ok(retry(|| client.list_files(None, None))
.await
.context("list root files failure")?
.into_iter()

View File

@@ -27,6 +27,11 @@ impl Barrier {
b.wait().await
}
}
/// Return true if a call to wait() would complete immediately
pub fn is_ready(&self) -> bool {
futures::future::FutureExt::now_or_never(self.0.wait()).is_some()
}
}
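`is_ready` leans on `FutureExt::now_or_never`, which polls a future exactly once with a no-op waker; a standalone illustration of that semantics (not using `Barrier` itself):
```
use futures::future::FutureExt;

fn main() {
    // An already-ready future resolves on the first poll...
    assert_eq!(async { 42 }.now_or_never(), Some(42));
    // ...while one that would block reports None instead of waiting.
    assert!(futures::future::pending::<()>().now_or_never().is_none());
}
```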
impl PartialEq for Barrier {

View File

@@ -1,6 +1,6 @@
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
Arc, Mutex, MutexGuard,
};
use tokio::sync::Semaphore;
@@ -12,7 +12,7 @@ use tokio::sync::Semaphore;
///
/// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
pub struct OnceCell<T> {
inner: tokio::sync::RwLock<Inner<T>>,
inner: Mutex<Inner<T>>,
initializers: AtomicUsize,
}
@@ -50,7 +50,7 @@ impl<T> OnceCell<T> {
let sem = Semaphore::new(1);
sem.close();
Self {
inner: tokio::sync::RwLock::new(Inner {
inner: Mutex::new(Inner {
init_semaphore: Arc::new(sem),
value: Some(value),
}),
@@ -61,18 +61,18 @@ impl<T> OnceCell<T> {
/// Returns a guard to an existing initialized value, or uniquely initializes the value before
/// returning the guard.
///
/// Initializing might wait on any existing [`GuardMut::take_and_deinit`] deinitialization.
/// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
///
/// Initialization is panic-safe and cancellation-safe.
pub async fn get_mut_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardMut<'_, T>, E>
pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
where
F: FnOnce(InitPermit) -> Fut,
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
{
let sem = {
let guard = self.inner.write().await;
let guard = self.inner.lock().unwrap();
if guard.value.is_some() {
return Ok(GuardMut(guard));
return Ok(Guard(guard));
}
guard.init_semaphore.clone()
};
@@ -88,72 +88,29 @@ impl<T> OnceCell<T> {
let permit = InitPermit(permit);
let (value, _permit) = factory(permit).await?;
let guard = self.inner.write().await;
let guard = self.inner.lock().unwrap();
Ok(Self::set0(value, guard))
}
Err(_closed) => {
let guard = self.inner.write().await;
let guard = self.inner.lock().unwrap();
assert!(
guard.value.is_some(),
"semaphore got closed, must be initialized"
);
return Ok(GuardMut(guard));
return Ok(Guard(guard));
}
}
}
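A hedged usage sketch of the factory contract in `get_or_init`: the closure receives an `InitPermit` and must return it alongside the value. The module path, error type, and borrowed-cell setup here are illustrative assumptions:
```
use utils::sync::heavier_once_cell::OnceCell;

async fn demo(cell: &OnceCell<u64>) -> u64 {
    let guard = cell
        .get_or_init(|permit| async move {
            // Pretend this is an expensive, fallible initialization.
            Ok::<_, std::convert::Infallible>((42u64, permit))
        })
        .await
        .expect("initialization is infallible in this sketch");
    *guard
}
```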
/// Returns a guard to an existing initialized value, or uniquely initializes the value before
/// returning the guard.
///
/// Initialization is panic-safe and cancellation-safe.
pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<GuardRef<'_, T>, E>
where
F: FnOnce(InitPermit) -> Fut,
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
{
let sem = {
let guard = self.inner.read().await;
if guard.value.is_some() {
return Ok(GuardRef(guard));
}
guard.init_semaphore.clone()
};
let permit = {
// increment the count for the duration of queued
let _guard = CountWaitingInitializers::start(self);
sem.acquire_owned().await
};
match permit {
Ok(permit) => {
let permit = InitPermit(permit);
let (value, _permit) = factory(permit).await?;
let guard = self.inner.write().await;
Ok(Self::set0(value, guard).downgrade())
}
Err(_closed) => {
let guard = self.inner.read().await;
assert!(
guard.value.is_some(),
"semaphore got closed, must be initialized"
);
return Ok(GuardRef(guard));
}
}
}
/// Assuming a permit is held after previous call to [`GuardMut::take_and_deinit`], it can be used
/// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
/// to complete initializing the inner value.
///
/// # Panics
///
/// If the inner has already been initialized.
pub async fn set(&self, value: T, _permit: InitPermit) -> GuardMut<'_, T> {
let guard = self.inner.write().await;
pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
let guard = self.inner.lock().unwrap();
// cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
// give more permits right now.
@@ -165,31 +122,21 @@ impl<T> OnceCell<T> {
Self::set0(value, guard)
}
fn set0(value: T, mut guard: tokio::sync::RwLockWriteGuard<'_, Inner<T>>) -> GuardMut<'_, T> {
fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
if guard.value.is_some() {
drop(guard);
unreachable!("we won permit, must not be initialized");
}
guard.value = Some(value);
guard.init_semaphore.close();
GuardMut(guard)
Guard(guard)
}
/// Returns a guard to an existing initialized value, if any.
pub async fn get_mut(&self) -> Option<GuardMut<'_, T>> {
let guard = self.inner.write().await;
pub fn get(&self) -> Option<Guard<'_, T>> {
let guard = self.inner.lock().unwrap();
if guard.value.is_some() {
Some(GuardMut(guard))
} else {
None
}
}
/// Returns a guard to an existing initialized value, if any.
pub async fn get(&self) -> Option<GuardRef<'_, T>> {
let guard = self.inner.read().await;
if guard.value.is_some() {
Some(GuardRef(guard))
Some(Guard(guard))
} else {
None
}
@@ -221,9 +168,9 @@ impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
/// Uninteresting guard object to allow short-lived access to inspect or clone the held,
/// initialized value.
#[derive(Debug)]
pub struct GuardMut<'a, T>(tokio::sync::RwLockWriteGuard<'a, Inner<T>>);
pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
impl<T> std::ops::Deref for GuardMut<'_, T> {
impl<T> std::ops::Deref for Guard<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
@@ -234,7 +181,7 @@ impl<T> std::ops::Deref for GuardMut<'_, T> {
}
}
impl<T> std::ops::DerefMut for GuardMut<'_, T> {
impl<T> std::ops::DerefMut for Guard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0
.value
@@ -243,7 +190,7 @@ impl<T> std::ops::DerefMut for GuardMut<'_, T> {
}
}
impl<'a, T> GuardMut<'a, T> {
impl<'a, T> Guard<'a, T> {
/// Take the current value, and a new permit for its deinitialization.
///
/// The permit will be on a semaphore part of the new internal value, and any following
@@ -261,24 +208,6 @@ impl<'a, T> GuardMut<'a, T> {
.map(|v| (v, InitPermit(permit)))
.expect("guard is not created unless value has been initialized")
}
pub fn downgrade(self) -> GuardRef<'a, T> {
GuardRef(self.0.downgrade())
}
}
#[derive(Debug)]
pub struct GuardRef<'a, T>(tokio::sync::RwLockReadGuard<'a, Inner<T>>);
impl<T> std::ops::Deref for GuardRef<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.0
.value
.as_ref()
.expect("guard is not created unless value has been initialized")
}
}
/// Type held by OnceCell (de)initializing task.
@@ -319,7 +248,7 @@ mod tests {
barrier.wait().await;
let won = {
let g = cell
.get_mut_or_init(|permit| {
.get_or_init(|permit| {
counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
async {
counters.future_polled.fetch_add(1, Ordering::Relaxed);
@@ -366,11 +295,7 @@ mod tests {
let cell = cell.clone();
let deinitialization_started = deinitialization_started.clone();
async move {
let (answer, _permit) = cell
.get_mut()
.await
.expect("initialized to value")
.take_and_deinit();
let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
assert_eq!(answer, initial);
deinitialization_started.wait().await;
@@ -381,7 +306,7 @@ mod tests {
deinitialization_started.wait().await;
let started_at = tokio::time::Instant::now();
cell.get_mut_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
.await
.unwrap();
@@ -393,21 +318,21 @@ mod tests {
jh.await.unwrap();
assert_eq!(*cell.get_mut().await.unwrap(), reinit);
assert_eq!(*cell.get().unwrap(), reinit);
}
#[tokio::test]
async fn reinit_with_deinit_permit() {
#[test]
fn reinit_with_deinit_permit() {
let cell = Arc::new(OnceCell::new(42));
let (mol, permit) = cell.get_mut().await.unwrap().take_and_deinit();
cell.set(5, permit).await;
assert_eq!(*cell.get_mut().await.unwrap(), 5);
let (mol, permit) = cell.get().unwrap().take_and_deinit();
cell.set(5, permit);
assert_eq!(*cell.get().unwrap(), 5);
let (five, permit) = cell.get_mut().await.unwrap().take_and_deinit();
let (five, permit) = cell.get().unwrap().take_and_deinit();
assert_eq!(5, five);
cell.set(mol, permit).await;
assert_eq!(*cell.get_mut().await.unwrap(), 42);
cell.set(mol, permit);
assert_eq!(*cell.get().unwrap(), 42);
}
#[tokio::test]
@@ -415,13 +340,13 @@ mod tests {
let cell = OnceCell::default();
for _ in 0..10 {
cell.get_mut_or_init(|_permit| async { Err("whatever error") })
cell.get_or_init(|_permit| async { Err("whatever error") })
.await
.unwrap_err();
}
let g = cell
.get_mut_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
.get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
.await
.unwrap();
assert_eq!(*g, "finally success");
@@ -433,7 +358,7 @@ mod tests {
let barrier = tokio::sync::Barrier::new(2);
let initializer = cell.get_mut_or_init(|permit| async {
let initializer = cell.get_or_init(|permit| async {
barrier.wait().await;
futures::future::pending::<()>().await;
@@ -447,10 +372,10 @@ mod tests {
// now initializer is dropped
assert!(cell.get_mut().await.is_none());
assert!(cell.get().is_none());
let g = cell
.get_mut_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
.get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
.await
.unwrap();
assert_eq!(*g, "now initialized");

View File

@@ -453,9 +453,12 @@ mod tests {
event_mask: 0,
}),
expected_messages: vec![
// Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160001, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 })
// TODO: When updating Postgres versions, this test will cause
// problems. Postgres version in message needs updating.
//
// Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160002, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 })
vec![
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 1, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 76, 143, 54, 6, 60, 108, 110,
147, 188, 32, 214, 90, 130, 15, 61, 158, 76, 143, 54, 6, 60, 108, 110, 147,
188, 32, 214, 90, 130, 15, 61, 1, 0, 0, 0, 0, 0, 0, 1,

View File

@@ -56,10 +56,18 @@ pub enum ForceAwaitLogicalSize {
impl Client {
pub fn new(mgmt_api_endpoint: String, jwt: Option<&str>) -> Self {
Self::from_client(reqwest::Client::new(), mgmt_api_endpoint, jwt)
}
pub fn from_client(
client: reqwest::Client,
mgmt_api_endpoint: String,
jwt: Option<&str>,
) -> Self {
Self {
mgmt_api_endpoint,
authorization_header: jwt.map(|jwt| format!("Bearer {jwt}")),
client: reqwest::Client::new(),
client,
}
}
@@ -310,6 +318,22 @@ impl Client {
.map_err(Error::ReceiveBody)
}
pub async fn tenant_shard_split(
&self,
tenant_shard_id: TenantShardId,
req: TenantShardSplitRequest,
) -> Result<TenantShardSplitResponse> {
let uri = format!(
"{}/v1/tenant/{}/shard_split",
self.mgmt_api_endpoint, tenant_shard_id
);
self.request(Method::PUT, &uri, req)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_list(
&self,
tenant_shard_id: &TenantShardId,
@@ -339,4 +363,16 @@ impl Client {
.await
.map_err(Error::ReceiveBody)
}
pub async fn put_io_engine(
&self,
engine: &pageserver_api::models::virtual_file::IoEngineKind,
) -> Result<()> {
let uri = format!("{}/v1/io_engine", self.mgmt_api_endpoint);
self.request(Method::PUT, uri, engine)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
}
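A hedged sketch of driving the two new management-API calls from a client-side tool. The crate path pageserver_client::mgmt_api, the endpoint address, and the assumption that TenantShardSplitRequest carries only the new shard count are guesses made for illustration.

use pageserver_api::models::virtual_file::IoEngineKind;
use pageserver_api::models::TenantShardSplitRequest;
use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api::Client;

async fn split_and_retune(tenant_shard_id: TenantShardId) -> anyhow::Result<()> {
    let client = Client::new("http://127.0.0.1:9898".to_string(), None);

    // PUT /v1/tenant/:tenant_shard_id/shard_split
    let resp = client
        .tenant_shard_split(tenant_shard_id, TenantShardSplitRequest { new_shard_count: 4 })
        .await?;
    for shard in resp.new_shards {
        println!("new child shard: {shard}");
    }

    // PUT /v1/io_engine: live-switch the VirtualFile backend.
    client.put_io_engine(&IoEngineKind::StdFs).await?;
    Ok(())
}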

View File

@@ -142,7 +142,7 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> {
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
// Initialize virtual_file (file descriptor cache) and page cache, which are needed to access the layer's persistent B-Tree.
pageserver::virtual_file::init(10, virtual_file::IoEngineKind::StdFs);
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
pageserver::page_cache::init(100);
let mut total_delta_layers = 0usize;

View File

@@ -59,7 +59,7 @@ pub(crate) enum LayerCmd {
async fn read_delta_file(path: impl AsRef<Path>, ctx: &RequestContext) -> Result<()> {
let path = Utf8Path::from_path(path.as_ref()).expect("non-Unicode path");
virtual_file::init(10, virtual_file::IoEngineKind::StdFs);
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
page_cache::init(100);
let file = FileBlockReader::new(VirtualFile::open(path).await?);
let summary_blk = file.read_blk(0, ctx).await?;
@@ -187,7 +187,7 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> {
new_tenant_id,
new_timeline_id,
} => {
pageserver::virtual_file::init(10, virtual_file::IoEngineKind::StdFs);
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
pageserver::page_cache::init(100);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);

View File

@@ -123,7 +123,7 @@ fn read_pg_control_file(control_file_path: &Utf8Path) -> anyhow::Result<()> {
async fn print_layerfile(path: &Utf8Path) -> anyhow::Result<()> {
// Basic initialization of things that don't change after startup
virtual_file::init(10, virtual_file::IoEngineKind::StdFs);
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs);
page_cache::init(100);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
dump_layerfile_from_path(path, true, &ctx).await

View File

@@ -51,6 +51,10 @@ pub(crate) struct Args {
/// It doesn't get invalidated if the keyspace changes under the hood, e.g., due to new ingested data or compaction.
#[clap(long)]
keyspace_cache: Option<Utf8PathBuf>,
/// Before starting the benchmark, live-reconfigure the pageserver to use the given
/// [`pageserver_api::models::virtual_file::IoEngineKind`].
#[clap(long)]
set_io_engine: Option<pageserver_api::models::virtual_file::IoEngineKind>,
targets: Option<Vec<TenantTimelineId>>,
}
@@ -109,6 +113,10 @@ async fn main_impl(
args.pageserver_jwt.as_deref(),
));
if let Some(engine_str) = &args.set_io_engine {
mgmt_api_client.put_io_engine(engine_str).await?;
}
// discover targets
let timelines: Vec<TenantTimelineId> = crate::util::cli::targets::discover(
&mgmt_api_client,

View File

@@ -272,6 +272,12 @@ fn start_pageserver(
);
set_build_info_metric(GIT_VERSION, BUILD_TAG);
set_launch_timestamp_metric(launch_ts);
#[cfg(target_os = "linux")]
metrics::register_internal(Box::new(metrics::more_process_metrics::Collector::new())).unwrap();
metrics::register_internal(Box::new(
pageserver::metrics::tokio_epoll_uring::Collector::new(),
))
.unwrap();
pageserver::preinitialize_metrics();
// If any failpoints were set from FAILPOINTS environment variable,

View File

@@ -623,6 +623,7 @@ impl std::fmt::Display for EvictionLayer {
}
}
#[derive(Default)]
pub(crate) struct DiskUsageEvictionInfo {
/// Timeline's largest layer (remote or resident)
pub max_layer_size: Option<u64>,
@@ -854,19 +855,27 @@ async fn collect_eviction_candidates(
let total = tenant_candidates.len();
for (i, mut candidate) in tenant_candidates.into_iter().enumerate() {
// as we iterate this reverse sorted list, the most recently accessed layer will always
// be 1.0; this is for us to evict it last.
candidate.relative_last_activity = eviction_order.relative_last_activity(total, i);
let tenant_candidates =
tenant_candidates
.into_iter()
.enumerate()
.map(|(i, mut candidate)| {
// as we iterate this reverse sorted list, the most recently accessed layer will always
// be 1.0; this is for us to evict it last.
candidate.relative_last_activity =
eviction_order.relative_last_activity(total, i);
let partition = if cumsum > min_resident_size as i128 {
MinResidentSizePartition::Above
} else {
MinResidentSizePartition::Below
};
cumsum += i128::from(candidate.layer.get_file_size());
candidates.push((partition, candidate));
}
let partition = if cumsum > min_resident_size as i128 {
MinResidentSizePartition::Above
} else {
MinResidentSizePartition::Below
};
cumsum += i128::from(candidate.layer.get_file_size());
(partition, candidate)
});
candidates.extend(tenant_candidates);
}
// Note: the same tenant ID might be hit twice, if it transitions from attached to
@@ -882,21 +891,41 @@ async fn collect_eviction_candidates(
);
for secondary_tenant in secondary_tenants {
let mut layer_info = secondary_tenant.get_layers_for_eviction();
// for secondary tenants we use a sum of on_disk layers and already evicted layers. this is
// to prevent repeated disk-usage-based evictions from completely draining secondaries that
// update less often.
let (mut layer_info, total_layers) = secondary_tenant.get_layers_for_eviction();
debug_assert!(
total_layers >= layer_info.resident_layers.len(),
"total_layers ({total_layers}) must be at least the resident_layers.len() ({})",
layer_info.resident_layers.len()
);
layer_info
.resident_layers
.sort_unstable_by_key(|layer_info| std::cmp::Reverse(layer_info.last_activity_ts));
candidates.extend(layer_info.resident_layers.into_iter().map(|candidate| {
(
// Secondary locations' layers are always considered above the min resident size,
// i.e. secondary locations are permitted to be trimmed to zero layers if all
// the layers have sufficiently old access times.
MinResidentSizePartition::Above,
candidate,
)
}));
let tenant_candidates =
layer_info
.resident_layers
.into_iter()
.enumerate()
.map(|(i, mut candidate)| {
candidate.relative_last_activity =
eviction_order.relative_last_activity(total_layers, i);
(
// Secondary locations' layers are always considered above the min resident size,
// i.e. secondary locations are permitted to be trimmed to zero layers if all
// the layers have sufficiently old access times.
MinResidentSizePartition::Above,
candidate,
)
});
candidates.extend(tenant_candidates);
tokio::task::yield_now().await;
}
debug_assert!(MinResidentSizePartition::Above < MinResidentSizePartition::Below,

View File

@@ -19,11 +19,14 @@ use pageserver_api::models::ShardParameters;
use pageserver_api::models::TenantDetails;
use pageserver_api::models::TenantLocationConfigResponse;
use pageserver_api::models::TenantShardLocation;
use pageserver_api::models::TenantShardSplitRequest;
use pageserver_api::models::TenantShardSplitResponse;
use pageserver_api::models::TenantState;
use pageserver_api::models::{
DownloadRemoteLayersTaskSpawnRequest, LocationConfigMode, TenantAttachRequest,
TenantLoadRequest, TenantLocationConfigRequest,
};
use pageserver_api::shard::ShardCount;
use pageserver_api::shard::TenantShardId;
use remote_storage::GenericRemoteStorage;
use remote_storage::TimeTravelError;
@@ -875,7 +878,7 @@ async fn tenant_reset_handler(
let state = get_state(&request);
state
.tenant_manager
.reset_tenant(tenant_shard_id, drop_cache.unwrap_or(false), ctx)
.reset_tenant(tenant_shard_id, drop_cache.unwrap_or(false), &ctx)
.await
.map_err(ApiError::InternalServerError)?;
@@ -1104,6 +1107,25 @@ async fn tenant_size_handler(
)
}
async fn tenant_shard_split_handler(
mut request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let req: TenantShardSplitRequest = json_request(&mut request).await?;
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
let state = get_state(&request);
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Warn);
let new_shards = state
.tenant_manager
.shard_split(tenant_shard_id, ShardCount(req.new_shard_count), &ctx)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, TenantShardSplitResponse { new_shards })
}
async fn layer_map_info_handler(
request: Request<Body>,
_cancel: CancellationToken,
@@ -1908,6 +1930,15 @@ async fn post_tracing_event_handler(
json_response(StatusCode::OK, ())
}
async fn put_io_engine_handler(
mut r: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let kind: crate::virtual_file::IoEngineKind = json_request(&mut r).await?;
crate::virtual_file::io_engine::set(kind);
json_response(StatusCode::OK, ())
}
/// Common functionality of all the HTTP API handlers.
///
/// - Adds a tracing span to each request (by `request_span`)
@@ -2054,6 +2085,9 @@ pub fn make_router(
.put("/v1/tenant/config", |r| {
api_handler(r, update_tenant_config_handler)
})
.put("/v1/tenant/:tenant_shard_id/shard_split", |r| {
api_handler(r, tenant_shard_split_handler)
})
.get("/v1/tenant/:tenant_shard_id/config", |r| {
api_handler(r, get_tenant_config_handler)
})
@@ -2165,5 +2199,6 @@ pub fn make_router(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/keyspace",
|r| testing_api_handler("read out the keyspace", r, timeline_collect_keyspace),
)
.put("/v1/io_engine", |r| api_handler(r, put_io_engine_handler))
.any(handler_404))
}

View File

@@ -2400,6 +2400,72 @@ impl<F: Future<Output = Result<O, E>>, O, E> Future for MeasuredRemoteOp<F> {
}
}
pub mod tokio_epoll_uring {
use metrics::UIntGauge;
pub struct Collector {
descs: Vec<metrics::core::Desc>,
systems_created: UIntGauge,
systems_destroyed: UIntGauge,
}
const NMETRICS: usize = 2;
impl metrics::core::Collector for Collector {
fn desc(&self) -> Vec<&metrics::core::Desc> {
self.descs.iter().collect()
}
fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
let mut mfs = Vec::with_capacity(NMETRICS);
let tokio_epoll_uring::metrics::Metrics {
systems_created,
systems_destroyed,
} = tokio_epoll_uring::metrics::global();
self.systems_created.set(systems_created);
mfs.extend(self.systems_created.collect());
self.systems_destroyed.set(systems_destroyed);
mfs.extend(self.systems_destroyed.collect());
mfs
}
}
impl Collector {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let mut descs = Vec::new();
let systems_created = UIntGauge::new(
"pageserver_tokio_epoll_uring_systems_created",
"counter of tokio-epoll-uring systems that were created",
)
.unwrap();
descs.extend(
metrics::core::Collector::desc(&systems_created)
.into_iter()
.cloned(),
);
let systems_destroyed = UIntGauge::new(
"pageserver_tokio_epoll_uring_systems_destroyed",
"counter of tokio-epoll-uring systems that were destroyed",
)
.unwrap();
descs.extend(
metrics::core::Collector::desc(&systems_destroyed)
.into_iter()
.cloned(),
);
Self {
descs,
systems_created,
systems_destroyed,
}
}
}
}
pub fn preinitialize_metrics() {
// Python tests need these and on some we do alerting.
//

View File

@@ -576,8 +576,8 @@ pub fn shutdown_token() -> CancellationToken {
/// Has the current task been requested to shut down?
pub fn is_shutdown_requested() -> bool {
if let Ok(cancel) = SHUTDOWN_TOKEN.try_with(|t| t.clone()) {
cancel.is_cancelled()
if let Ok(true_or_false) = SHUTDOWN_TOKEN.try_with(|t| t.is_cancelled()) {
true_or_false
} else {
if !cfg!(test) {
warn!("is_shutdown_requested() called in an unexpected task or thread");

View File

@@ -53,6 +53,7 @@ use self::metadata::TimelineMetadata;
use self::mgr::GetActiveTenantError;
use self::mgr::GetTenantError;
use self::mgr::TenantsMap;
use self::remote_timeline_client::upload::upload_index_part;
use self::remote_timeline_client::RemoteTimelineClient;
use self::timeline::uninit::TimelineExclusionError;
use self::timeline::uninit::TimelineUninitMark;
@@ -1376,7 +1377,7 @@ impl Tenant {
async move {
debug!("starting index part download");
let index_part = client.download_index_file(cancel_clone).await;
let index_part = client.download_index_file(&cancel_clone).await;
debug!("finished index part download");
@@ -2397,6 +2398,67 @@ impl Tenant {
pub(crate) fn get_generation(&self) -> Generation {
self.generation
}
/// This function partially shuts down the tenant (it shuts down the Timelines) and is fallible,
/// and can leave the tenant in a bad state if it fails. The caller is responsible for
/// resetting this tenant to a valid state if we fail.
pub(crate) async fn split_prepare(
&self,
child_shards: &Vec<TenantShardId>,
) -> anyhow::Result<()> {
let timelines = self.timelines.lock().unwrap().clone();
for timeline in timelines.values() {
let Some(tl_client) = &timeline.remote_client else {
anyhow::bail!("Remote storage is mandatory");
};
let Some(remote_storage) = &self.remote_storage else {
anyhow::bail!("Remote storage is mandatory");
};
// We do not block timeline creation/deletion during splits inside the pageserver: it is up to higher levels
// to ensure that they do not start a split while timeline creation or deletion is in progress.
// Upload an index from the parent: this is partly to provide freshness for the
// child tenants that will copy it, and partly for general ease-of-debugging: there will
// always be a parent shard index in the same generation as we wrote the child shard index.
tl_client.schedule_index_upload_for_file_changes()?;
tl_client.wait_completion().await?;
// Shut down the timeline's remote client: this means that the indices we write
// for child shards will not be invalidated by the parent shard deleting layers.
tl_client.shutdown().await?;
// Download methods can still be used after shutdown, as they don't flow through the remote client's
// queue. In principle the RemoteTimelineClient could provide this without downloading it, but this
// operation is rare, so it's simpler to just download it (and doing so robustly guarantees that
// the index we use here really is the remotely persistent one).
let result = tl_client
.download_index_file(&self.cancel)
.instrument(info_span!("download_index_file", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), timeline_id=%timeline.timeline_id))
.await?;
let index_part = match result {
MaybeDeletedIndexPart::Deleted(_) => {
anyhow::bail!("Timeline deletion happened concurrently with split")
}
MaybeDeletedIndexPart::IndexPart(p) => p,
};
for child_shard in child_shards {
upload_index_part(
remote_storage,
child_shard,
&timeline.timeline_id,
self.generation,
&index_part,
&self.cancel,
)
.await?;
}
}
Ok(())
}
}
/// Given a Vec of timelines and their ancestors (timeline_id, ancestor_id),
@@ -3732,6 +3794,10 @@ impl Tenant {
Ok(())
}
pub(crate) fn get_tenant_conf(&self) -> TenantConfOpt {
self.tenant_conf.read().unwrap().tenant_conf
}
}
fn remove_timeline_and_uninit_mark(

View File

@@ -2,6 +2,7 @@
//! page server.
use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf};
use itertools::Itertools;
use pageserver_api::key::Key;
use pageserver_api::models::ShardParameters;
use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, TenantShardId};
@@ -22,7 +23,7 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use remote_storage::GenericRemoteStorage;
use utils::crashsafe;
use utils::{completion, crashsafe};
use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
@@ -644,8 +645,6 @@ pub(crate) async fn shutdown_all_tenants() {
}
async fn shutdown_all_tenants0(tenants: &std::sync::RwLock<TenantsMap>) {
use utils::completion;
let mut join_set = JoinSet::new();
// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
@@ -1200,7 +1199,7 @@ impl TenantManager {
&self,
tenant_shard_id: TenantShardId,
drop_cache: bool,
ctx: RequestContext,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let Some(old_slot) = slot_guard.get_old_value() else {
@@ -1253,7 +1252,7 @@ impl TenantManager {
None,
self.tenants,
SpawnMode::Normal,
&ctx,
ctx,
)?;
slot_guard.upsert(TenantSlot::Attached(tenant))?;
@@ -1375,6 +1374,164 @@ impl TenantManager {
slot_guard.revert();
result
}
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), new_shard_count=%new_shard_count.0))]
pub(crate) async fn shard_split(
&self,
tenant_shard_id: TenantShardId,
new_shard_count: ShardCount,
ctx: &RequestContext,
) -> anyhow::Result<Vec<TenantShardId>> {
let tenant = get_tenant(tenant_shard_id, true)?;
// Plan: identify what the new child shards will be
let effective_old_shard_count = std::cmp::max(tenant_shard_id.shard_count.0, 1);
if new_shard_count <= ShardCount(effective_old_shard_count) {
anyhow::bail!("Requested shard count is not an increase");
}
let expansion_factor = new_shard_count.0 / effective_old_shard_count;
if !expansion_factor.is_power_of_two() {
anyhow::bail!("Requested split is not a power of two");
}
let parent_shard_identity = tenant.shard_identity;
let parent_tenant_conf = tenant.get_tenant_conf();
let parent_generation = tenant.generation;
let child_shards = tenant_shard_id.split(new_shard_count);
tracing::info!(
"Shard {} splits into: {}",
tenant_shard_id.to_index(),
child_shards
.iter()
.map(|id| format!("{}", id.to_index()))
.join(",")
);
// Phase 1: Write out child shards' remote index files, in the parent tenant's current generation
if let Err(e) = tenant.split_prepare(&child_shards).await {
// If [`Tenant::split_prepare`] fails, we must reload the tenant, because it might
// have been left in a partially-shut-down state.
tracing::warn!("Failed to prepare for split: {e}, reloading Tenant before returning");
self.reset_tenant(tenant_shard_id, false, ctx).await?;
return Err(e);
}
self.resources.deletion_queue_client.flush_advisory();
// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
drop(tenant);
let mut parent_slot_guard =
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let parent = match parent_slot_guard.get_old_value() {
Some(TenantSlot::Attached(t)) => t,
Some(TenantSlot::Secondary(_)) => anyhow::bail!("Tenant location in secondary mode"),
Some(TenantSlot::InProgress(_)) => {
// tenant_map_acquire_slot never returns InProgress; if a slot were InProgress
// it would return an error.
unreachable!()
}
None => {
// We don't actually need the parent shard to still be attached to do our work, but it's
// a weird enough situation that the caller probably didn't want us to continue working
// if they had detached the tenant they requested the split on.
anyhow::bail!("Detached parent shard in the middle of split!")
}
};
// TODO: hardlink layers from the parent into the child shard directories so that they don't immediately re-download
// TODO: erase the dentries from the parent
// Take a snapshot of where the parent's WAL ingest had got to: we will wait for
// child shards to reach this point.
let mut target_lsns = HashMap::new();
for timeline in parent.timelines.lock().unwrap().clone().values() {
target_lsns.insert(timeline.timeline_id, timeline.get_last_record_lsn());
}
// TODO: we should have the parent shard stop its WAL ingest here, it's a waste of resources
// and could slow down the children trying to catch up.
// Phase 3: Spawn the child shards
for child_shard in &child_shards {
let mut child_shard_identity = parent_shard_identity;
child_shard_identity.count = child_shard.shard_count;
child_shard_identity.number = child_shard.shard_number;
let child_location_conf = LocationConf {
mode: LocationMode::Attached(AttachedLocationConfig {
generation: parent_generation,
attach_mode: AttachmentMode::Single,
}),
shard: child_shard_identity,
tenant_conf: parent_tenant_conf,
};
self.upsert_location(
*child_shard,
child_location_conf,
None,
SpawnMode::Normal,
ctx,
)
.await?;
}
// Phase 4: wait for the child shards' WAL ingest to catch up to the target LSN
for child_shard_id in &child_shards {
let child_shard = {
let locked = TENANTS.read().unwrap();
let peek_slot =
tenant_map_peek_slot(&locked, child_shard_id, TenantSlotPeekMode::Read)?;
peek_slot.and_then(|s| s.get_attached()).cloned()
};
if let Some(t) = child_shard {
let timelines = t.timelines.lock().unwrap().clone();
for timeline in timelines.values() {
let Some(target_lsn) = target_lsns.get(&timeline.timeline_id) else {
continue;
};
tracing::info!(
"Waiting for child shard {}/{} to reach target lsn {}...",
child_shard_id,
timeline.timeline_id,
target_lsn
);
if let Err(e) = timeline.wait_lsn(*target_lsn, ctx).await {
// Failure here might mean shutdown; in any case this part is an optimization
// and we shouldn't hold up the split operation.
tracing::warn!(
"Failed to wait for timeline {} to reach lsn {target_lsn}: {e}",
timeline.timeline_id
);
} else {
tracing::info!(
"Child shard {}/{} reached target lsn {}",
child_shard_id,
timeline.timeline_id,
target_lsn
);
}
}
}
}
// Phase 5: Shut down the parent shard.
let (_guard, progress) = completion::channel();
match parent.shutdown(progress, false).await {
Ok(()) => {}
Err(other) => {
other.wait().await;
}
}
parent_slot_guard.drop_old_value()?;
// Phase 6: Release the InProgress on the parent shard
drop(parent_slot_guard);
Ok(child_shards)
}
}
#[derive(Debug, thiserror::Error)]
@@ -2209,8 +2366,6 @@ async fn remove_tenant_from_memory<V, F>(
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
use utils::completion;
let mut slot_guard =
tenant_map_acquire_slot_impl(&tenant_shard_id, tenants, TenantSlotAcquireMode::MustExist)?;

View File

@@ -217,6 +217,7 @@ use crate::metrics::{
};
use crate::task_mgr::shutdown_token;
use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id;
use crate::tenant::remote_timeline_client::download::download_retry;
use crate::tenant::storage_layer::AsLayerDesc;
use crate::tenant::upload_queue::Delete;
use crate::tenant::TIMELINES_SEGMENT_NAME;
@@ -262,6 +263,11 @@ pub(crate) const INITDB_PRESERVED_PATH: &str = "initdb-preserved.tar.zst";
/// Default buffer size when interfacing with [`tokio::fs::File`].
pub(crate) const BUFFER_SIZE: usize = 32 * 1024;
/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not
/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that.
pub(crate) const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120);
pub(crate) const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120);
pub enum MaybeDeletedIndexPart {
IndexPart(IndexPart),
Deleted(IndexPart),
@@ -325,11 +331,6 @@ pub struct RemoteTimelineClient {
cancel: CancellationToken,
}
/// This timeout is intended to deal with hangs in lower layers, e.g. stuck TCP flows. It is not
/// intended to be snappy enough for prompt shutdown, as we have a CancellationToken for that.
const UPLOAD_TIMEOUT: Duration = Duration::from_secs(120);
const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(120);
/// Wrapper for timeout_cancellable that flattens result and converts TimeoutCancellableError to anyhow.
///
/// This is a convenience for the various upload functions. In future
@@ -506,7 +507,7 @@ impl RemoteTimelineClient {
/// Download index file
pub async fn download_index_file(
&self,
cancel: CancellationToken,
cancel: &CancellationToken,
) -> Result<MaybeDeletedIndexPart, DownloadError> {
let _unfinished_gauge_guard = self.metrics.call_begin(
&RemoteOpFileKind::Index,
@@ -1147,22 +1148,17 @@ impl RemoteTimelineClient {
let cancel = shutdown_token();
let remaining = backoff::retry(
let remaining = download_retry(
|| async {
self.storage_impl
.list_files(Some(&timeline_storage_path))
.list_files(Some(&timeline_storage_path), None)
.await
},
|_e| false,
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
"list_prefixes",
"list remaining files",
&cancel,
)
.await
.ok_or_else(|| anyhow::anyhow!("Cancelled!"))
.and_then(|x| x)
.context("list prefixes")?;
.context("list files remaining files")?;
// We will delete the current index_part object last, since it acts as a deletion
// marker via its deleted_at attribute
@@ -1351,6 +1347,7 @@ impl RemoteTimelineClient {
/// queue.
///
async fn perform_upload_task(self: &Arc<Self>, task: Arc<UploadTask>) {
let cancel = shutdown_token();
// Loop to retry until it completes.
loop {
// If we're requested to shut down, close up shop and exit.
@@ -1362,7 +1359,7 @@ impl RemoteTimelineClient {
// the Future, but we're not 100% sure if the remote storage library
// is cancellation safe, so we don't dare to do that. Hopefully, the
// upload finishes or times out soon enough.
if task_mgr::is_shutdown_requested() {
if cancel.is_cancelled() {
info!("upload task cancelled by shutdown request");
match self.stop() {
Ok(()) => {}
@@ -1473,7 +1470,7 @@ impl RemoteTimelineClient {
retries,
DEFAULT_BASE_BACKOFF_SECONDS,
DEFAULT_MAX_BACKOFF_SECONDS,
&shutdown_token(),
&cancel,
)
.await;
}
@@ -1990,7 +1987,7 @@ mod tests {
// Download back the index.json, and check that the list of files is correct
let initial_index_part = match client
.download_index_file(CancellationToken::new())
.download_index_file(&CancellationToken::new())
.await
.unwrap()
{
@@ -2084,7 +2081,7 @@ mod tests {
// Download back the index.json, and check that the list of files is correct
let index_part = match client
.download_index_file(CancellationToken::new())
.download_index_file(&CancellationToken::new())
.await
.unwrap()
{
@@ -2286,7 +2283,7 @@ mod tests {
let client = test_state.build_client(get_generation);
let download_r = client
.download_index_file(CancellationToken::new())
.download_index_file(&CancellationToken::new())
.await
.expect("download should always succeed");
assert!(matches!(download_r, MaybeDeletedIndexPart::IndexPart(_)));

View File

@@ -216,16 +216,15 @@ pub async fn list_remote_timelines(
anyhow::bail!("storage-sync-list-remote-timelines");
});
let cancel_inner = cancel.clone();
let listing = download_retry_forever(
|| {
download_cancellable(
&cancel_inner,
storage.list(Some(&remote_path), ListingMode::WithDelimiter),
&cancel,
storage.list(Some(&remote_path), ListingMode::WithDelimiter, None),
)
},
&format!("list timelines for {tenant_shard_id}"),
cancel,
&cancel,
)
.await?;
@@ -258,19 +257,18 @@ async fn do_download_index_part(
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,
index_generation: Generation,
cancel: CancellationToken,
cancel: &CancellationToken,
) -> Result<IndexPart, DownloadError> {
use futures::stream::StreamExt;
let remote_path = remote_index_path(tenant_shard_id, timeline_id, index_generation);
let cancel_inner = cancel.clone();
let index_part_bytes = download_retry_forever(
|| async {
// Cancellation: it is safe to cancel this future because we're just downloading into
// a memory buffer, not touching local disk.
let index_part_download =
download_cancellable(&cancel_inner, storage.download(&remote_path)).await?;
download_cancellable(cancel, storage.download(&remote_path)).await?;
let mut index_part_bytes = Vec::new();
let mut stream = std::pin::pin!(index_part_download.download_stream);
@@ -288,7 +286,7 @@ async fn do_download_index_part(
.await?;
let index_part: IndexPart = serde_json::from_slice(&index_part_bytes)
.with_context(|| format!("download index part file at {remote_path:?}"))
.with_context(|| format!("deserialize index part file at {remote_path:?}"))
.map_err(DownloadError::Other)?;
Ok(index_part)
@@ -305,7 +303,7 @@ pub(super) async fn download_index_part(
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,
my_generation: Generation,
cancel: CancellationToken,
cancel: &CancellationToken,
) -> Result<IndexPart, DownloadError> {
debug_assert_current_span_has_tenant_and_timeline_id();
@@ -325,14 +323,8 @@ pub(super) async fn download_index_part(
// index in our generation.
//
// This is an optimization to avoid doing the listing for the general case below.
let res = do_download_index_part(
storage,
tenant_shard_id,
timeline_id,
my_generation,
cancel.clone(),
)
.await;
let res =
do_download_index_part(storage, tenant_shard_id, timeline_id, my_generation, cancel).await;
match res {
Ok(index_part) => {
tracing::debug!(
@@ -357,7 +349,7 @@ pub(super) async fn download_index_part(
tenant_shard_id,
timeline_id,
my_generation.previous(),
cancel.clone(),
cancel,
)
.await;
match res {
@@ -379,18 +371,13 @@ pub(super) async fn download_index_part(
// objects, and select the highest one with a generation <= my_generation. Constructing the prefix is equivalent
// to constructing a full index path with no generation, because the generation is a suffix.
let index_prefix = remote_index_path(tenant_shard_id, timeline_id, Generation::none());
let indices = backoff::retry(
|| async { storage.list_files(Some(&index_prefix)).await },
|_| false,
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
"listing index_part files",
&cancel,
let indices = download_retry(
|| async { storage.list_files(Some(&index_prefix), None).await },
"list index_part files",
cancel,
)
.await
.ok_or_else(|| anyhow::anyhow!("Cancelled"))
.and_then(|x| x)
.map_err(DownloadError::Other)?;
.await?;
// General case logic for which index to use: the latest index whose generation
// is <= our own. See "Finding the remote indices for timelines" in docs/rfcs/025-generation-numbers.md
@@ -447,8 +434,6 @@ pub(crate) async fn download_initdb_tar_zst(
"{INITDB_PATH}.download-{timeline_id}.{TEMP_FILE_SUFFIX}"
));
let cancel_inner = cancel.clone();
let file = download_retry(
|| async {
let file = OpenOptions::new()
@@ -461,13 +446,11 @@ pub(crate) async fn download_initdb_tar_zst(
.with_context(|| format!("tempfile creation {temp_path}"))
.map_err(DownloadError::Other)?;
let download = match download_cancellable(&cancel_inner, storage.download(&remote_path))
.await
let download = match download_cancellable(cancel, storage.download(&remote_path)).await
{
Ok(dl) => dl,
Err(DownloadError::NotFound) => {
download_cancellable(&cancel_inner, storage.download(&remote_preserved_path))
.await?
download_cancellable(cancel, storage.download(&remote_preserved_path)).await?
}
Err(other) => Err(other)?,
};
@@ -516,7 +499,7 @@ pub(crate) async fn download_initdb_tar_zst(
/// with backoff.
///
/// (See similar logic for uploads in `perform_upload_task`)
async fn download_retry<T, O, F>(
pub(super) async fn download_retry<T, O, F>(
op: O,
description: &str,
cancel: &CancellationToken,
@@ -527,7 +510,7 @@ where
{
backoff::retry(
op,
|e| matches!(e, DownloadError::BadInput(_) | DownloadError::NotFound),
DownloadError::is_permanent,
FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
description,
@@ -541,7 +524,7 @@ where
async fn download_retry_forever<T, O, F>(
op: O,
description: &str,
cancel: CancellationToken,
cancel: &CancellationToken,
) -> Result<T, DownloadError>
where
O: FnMut() -> F,
@@ -549,11 +532,11 @@ where
{
backoff::retry(
op,
|e| matches!(e, DownloadError::BadInput(_) | DownloadError::NotFound),
DownloadError::is_permanent,
FAILED_DOWNLOAD_WARN_THRESHOLD,
u32::MAX,
description,
&cancel,
cancel,
)
.await
.ok_or_else(|| DownloadError::Cancelled)

View File

@@ -27,7 +27,7 @@ use super::index::LayerFileMetadata;
use tracing::info;
/// Serializes and uploads the given index part data to the remote storage.
pub(super) async fn upload_index_part<'a>(
pub(crate) async fn upload_index_part<'a>(
storage: &'a GenericRemoteStorage,
tenant_shard_id: &TenantShardId,
timeline_id: &TimelineId,

View File

@@ -160,7 +160,7 @@ impl SecondaryTenant {
&self.tenant_shard_id
}
pub(crate) fn get_layers_for_eviction(self: &Arc<Self>) -> DiskUsageEvictionInfo {
pub(crate) fn get_layers_for_eviction(self: &Arc<Self>) -> (DiskUsageEvictionInfo, usize) {
self.detail.lock().unwrap().get_layers_for_eviction(self)
}

View File

@@ -146,14 +146,15 @@ impl SecondaryDetail {
}
}
/// Additionally returns the total number of layers, used for more stable relative access time
/// based eviction.
pub(super) fn get_layers_for_eviction(
&self,
parent: &Arc<SecondaryTenant>,
) -> DiskUsageEvictionInfo {
let mut result = DiskUsageEvictionInfo {
max_layer_size: None,
resident_layers: Vec::new(),
};
) -> (DiskUsageEvictionInfo, usize) {
let mut result = DiskUsageEvictionInfo::default();
let mut total_layers = 0;
for (timeline_id, timeline_detail) in &self.timelines {
result
.resident_layers
@@ -169,6 +170,10 @@ impl SecondaryDetail {
relative_last_activity: finite_f32::FiniteF32::ZERO,
}
}));
// total might be missing currently-downloading layers, but as a lower-than-actual
// value it is a good enough approximation.
total_layers += timeline_detail.on_disk_layers.len() + timeline_detail.evicted_at.len();
}
result.max_layer_size = result
.resident_layers
@@ -183,7 +188,7 @@ impl SecondaryDetail {
result.resident_layers.len()
);
result
(result, total_layers)
}
}
@@ -312,9 +317,7 @@ impl JobGenerator<PendingDownload, RunningDownload, CompleteDownload, DownloadCo
.tenant_manager
.get_secondary_tenant_shard(*tenant_shard_id);
let Some(tenant) = tenant else {
{
return Err(anyhow::anyhow!("Not found or not in Secondary mode"));
}
return Err(anyhow::anyhow!("Not found or not in Secondary mode"));
};
Ok(PendingDownload {
@@ -389,9 +392,9 @@ impl JobGenerator<PendingDownload, RunningDownload, CompleteDownload, DownloadCo
}
CompleteDownload {
secondary_state,
completed_at: Instant::now(),
}
secondary_state,
completed_at: Instant::now(),
}
}.instrument(info_span!(parent: None, "secondary_download", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))))
}
}
@@ -530,7 +533,7 @@ impl<'a> TenantDownloader<'a> {
.map_err(UpdateError::from)?;
let mut heatmap_bytes = Vec::new();
let mut body = tokio_util::io::StreamReader::new(download.download_stream);
let _size = tokio::io::copy(&mut body, &mut heatmap_bytes).await?;
let _size = tokio::io::copy_buf(&mut body, &mut heatmap_bytes).await?;
Ok(heatmap_bytes)
},
|e| matches!(e, UpdateError::NoData | UpdateError::Cancelled),

View File

@@ -300,8 +300,8 @@ impl Layer {
})
}
pub(crate) async fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
self.0.info(reset).await
pub(crate) fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
self.0.info(reset)
}
pub(crate) fn access_stats(&self) -> &LayerAccessStats {
@@ -612,10 +612,10 @@ impl LayerInner {
let mut rx = self.status.subscribe();
let strong = {
match self.inner.get_mut().await {
match self.inner.get() {
Some(mut either) => {
self.wanted_evicted.store(true, Ordering::Relaxed);
ResidentOrWantedEvicted::downgrade(&mut either)
either.downgrade()
}
None => return Err(EvictionError::NotFound),
}
@@ -641,7 +641,7 @@ impl LayerInner {
// use however late (compared to the initial expressing of wanted) as the
// "outcome" now
LAYER_IMPL_METRICS.inc_broadcast_lagged();
match self.inner.get_mut().await {
match self.inner.get() {
Some(_) => Err(EvictionError::Downloaded),
None => Ok(()),
}
@@ -759,7 +759,7 @@ impl LayerInner {
// use the already-held initialization permit because it is impossible to hit the
// paths below anymore, essentially limiting the max loop iterations to 2.
let (value, init_permit) = download(init_permit).await?;
let mut guard = self.inner.set(value, init_permit).await;
let mut guard = self.inner.set(value, init_permit);
let (strong, _upgraded) = guard
.get_and_upgrade()
.expect("init creates strong reference, we held the init permit");
@@ -767,7 +767,7 @@ impl LayerInner {
}
let (weak, permit) = {
let mut locked = self.inner.get_mut_or_init(download).await?;
let mut locked = self.inner.get_or_init(download).await?;
if let Some((strong, upgraded)) = locked.get_and_upgrade() {
if upgraded {
@@ -989,12 +989,12 @@ impl LayerInner {
}
}
async fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
let layer_file_name = self.desc.filename().file_name();
// this is not accurate: we could have the file locally but there was a cancellation
// and now we are not in sync, or we are currently downloading it.
let remote = self.inner.get_mut().await.is_none();
let remote = self.inner.get().is_none();
let access_stats = self.access_stats.as_api_model(reset);
@@ -1053,7 +1053,7 @@ impl LayerInner {
LAYER_IMPL_METRICS.inc_eviction_cancelled(EvictionCancelled::LayerGone);
return;
};
match tokio::runtime::Handle::current().block_on(this.evict_blocking(version)) {
match this.evict_blocking(version) {
Ok(()) => LAYER_IMPL_METRICS.inc_completed_evictions(),
Err(reason) => LAYER_IMPL_METRICS.inc_eviction_cancelled(reason),
}
@@ -1061,7 +1061,7 @@ impl LayerInner {
}
}
async fn evict_blocking(&self, only_version: usize) -> Result<(), EvictionCancelled> {
fn evict_blocking(&self, only_version: usize) -> Result<(), EvictionCancelled> {
// deleted or detached timeline, don't do anything.
let Some(timeline) = self.timeline.upgrade() else {
return Err(EvictionCancelled::TimelineGone);
@@ -1070,7 +1070,7 @@ impl LayerInner {
// to avoid starting a new download while we evict, keep holding on to the
// permit.
let _permit = {
let maybe_downloaded = self.inner.get_mut().await;
let maybe_downloaded = self.inner.get();
let (_weak, permit) = match maybe_downloaded {
Some(mut guard) => {

View File

@@ -1268,7 +1268,7 @@ impl Timeline {
let mut historic_layers = Vec::new();
for historic_layer in layer_map.iter_historic_layers() {
let historic_layer = guard.get_from_desc(&historic_layer);
historic_layers.push(historic_layer.info(reset).await);
historic_layers.push(historic_layer.info(reset));
}
LayerMapInfo {

View File

@@ -343,6 +343,23 @@ pub(super) async fn handle_walreceiver_connection(
modification.commit(&ctx).await?;
uncommitted_records = 0;
filtered_records = 0;
//
// We should check the checkpoint distance after appending every ingest_batch_size bytes, because
// otherwise the layer size can become much larger than `checkpoint_distance`.
// This can happen because the WAL sender sends WAL in 125kB chunks, and some WAL records can cause
// a large amount of data to be written to key-value storage. Performing this check only after
// processing all WAL records in the chunk can therefore produce huge L0 layer files.
//
timeline
.check_checkpoint_distance()
.await
.with_context(|| {
format!(
"Failed to check checkpoint distance for timeline {}",
timeline.timeline_id
)
})?;
}
}

View File

@@ -28,9 +28,10 @@ use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::time::Instant;
use utils::fs_ext;
mod io_engine;
pub use pageserver_api::models::virtual_file as api;
pub(crate) mod io_engine;
mod open_options;
pub use io_engine::IoEngineKind;
pub(crate) use io_engine::IoEngineKind;
pub(crate) use open_options::*;
///

View File

@@ -7,67 +7,100 @@
//!
//! Then use [`get`] and [`super::OpenOptions`].
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Hash,
strum_macros::EnumString,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
Debug,
)]
#[strum(serialize_all = "kebab-case")]
pub enum IoEngineKind {
pub(crate) use super::api::IoEngineKind;
#[derive(Clone, Copy)]
#[repr(u8)]
pub(crate) enum IoEngine {
NotSet,
StdFs,
#[cfg(target_os = "linux")]
TokioEpollUring,
}
static IO_ENGINE: once_cell::sync::OnceCell<IoEngineKind> = once_cell::sync::OnceCell::new();
#[cfg(not(test))]
pub(super) fn init(engine: IoEngineKind) {
if IO_ENGINE.set(engine).is_err() {
panic!("called twice");
impl From<IoEngineKind> for IoEngine {
fn from(value: IoEngineKind) -> Self {
match value {
IoEngineKind::StdFs => IoEngine::StdFs,
#[cfg(target_os = "linux")]
IoEngineKind::TokioEpollUring => IoEngine::TokioEpollUring,
}
}
crate::metrics::virtual_file_io_engine::KIND
.with_label_values(&[&format!("{engine}")])
.set(1);
}
pub(super) fn get() -> &'static IoEngineKind {
#[cfg(test)]
{
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE";
IO_ENGINE.get_or_init(|| match std::env::var(env_var_name) {
Ok(v) => match v.parse::<IoEngineKind>() {
Ok(engine_kind) => engine_kind,
Err(e) => {
panic!("invalid VirtualFile io engine for env var {env_var_name}: {e:#}: {v:?}")
}
},
Err(std::env::VarError::NotPresent) => {
crate::config::defaults::DEFAULT_VIRTUAL_FILE_IO_ENGINE
.parse()
.unwrap()
}
Err(std::env::VarError::NotUnicode(_)) => {
panic!("env var {env_var_name} is not unicode");
}
impl TryFrom<u8> for IoEngine {
type Error = u8;
fn try_from(value: u8) -> Result<Self, Self::Error> {
Ok(match value {
v if v == (IoEngine::NotSet as u8) => IoEngine::NotSet,
v if v == (IoEngine::StdFs as u8) => IoEngine::StdFs,
#[cfg(target_os = "linux")]
v if v == (IoEngine::TokioEpollUring as u8) => IoEngine::TokioEpollUring,
x => return Err(x),
})
}
#[cfg(not(test))]
IO_ENGINE.get().unwrap()
}
use std::os::unix::prelude::FileExt;
static IO_ENGINE: AtomicU8 = AtomicU8::new(IoEngine::NotSet as u8);
pub(crate) fn set(engine_kind: IoEngineKind) {
let engine: IoEngine = engine_kind.into();
IO_ENGINE.store(engine as u8, std::sync::atomic::Ordering::Relaxed);
#[cfg(not(test))]
{
let metric = &crate::metrics::virtual_file_io_engine::KIND;
metric.reset();
metric
.with_label_values(&[&format!("{engine_kind}")])
.set(1);
}
}
#[cfg(not(test))]
pub(super) fn init(engine_kind: IoEngineKind) {
set(engine_kind);
}
pub(super) fn get() -> IoEngine {
let cur = IoEngine::try_from(IO_ENGINE.load(Ordering::Relaxed)).unwrap();
if cfg!(test) {
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE";
match cur {
IoEngine::NotSet => {
let kind = match std::env::var(env_var_name) {
Ok(v) => match v.parse::<IoEngineKind>() {
Ok(engine_kind) => engine_kind,
Err(e) => {
panic!("invalid VirtualFile io engine for env var {env_var_name}: {e:#}: {v:?}")
}
},
Err(std::env::VarError::NotPresent) => {
crate::config::defaults::DEFAULT_VIRTUAL_FILE_IO_ENGINE
.parse()
.unwrap()
}
Err(std::env::VarError::NotUnicode(_)) => {
panic!("env var {env_var_name} is not unicode");
}
};
self::set(kind);
self::get()
}
x => x,
}
} else {
cur
}
}
use std::{
os::unix::prelude::FileExt,
sync::atomic::{AtomicU8, Ordering},
};
use super::FileGuard;
impl IoEngineKind {
impl IoEngine {
pub(super) async fn read_at<B>(
&self,
file_guard: FileGuard,
@@ -78,7 +111,8 @@ impl IoEngineKind {
B: tokio_epoll_uring::BoundedBufMut + Send,
{
match self {
IoEngineKind::StdFs => {
IoEngine::NotSet => panic!("not initialized"),
IoEngine::StdFs => {
// SAFETY: `dst` only lives at most as long as this match arm, during which buf remains valid memory.
let dst = unsafe {
std::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total())
@@ -96,7 +130,7 @@ impl IoEngineKind {
((file_guard, buf), res)
}
#[cfg(target_os = "linux")]
IoEngineKind::TokioEpollUring => {
IoEngine::TokioEpollUring => {
let system = tokio_epoll_uring::thread_local_system().await;
let (resources, res) = system.read(file_guard, offset, buf).await;
(
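The runtime switch above leans on a small standalone pattern: a #[repr(u8)] enum stored in an AtomicU8 and recovered through a fallible TryFrom. A self-contained sketch of that pattern, separate from the pageserver code, in case the intent is not obvious from the diff:

use std::sync::atomic::{AtomicU8, Ordering};

#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
enum Engine {
    NotSet,
    StdFs,
}

impl TryFrom<u8> for Engine {
    type Error = u8;
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Ok(match value {
            v if v == Engine::NotSet as u8 => Engine::NotSet,
            v if v == Engine::StdFs as u8 => Engine::StdFs,
            x => return Err(x),
        })
    }
}

// The whole process shares one engine selection; Relaxed ordering is enough
// because readers only need to eventually observe the latest store.
static CURRENT: AtomicU8 = AtomicU8::new(Engine::NotSet as u8);

fn set(engine: Engine) {
    CURRENT.store(engine as u8, Ordering::Relaxed);
}

fn get() -> Engine {
    Engine::try_from(CURRENT.load(Ordering::Relaxed)).expect("only set() writes this cell")
}

fn main() {
    assert_eq!(get(), Engine::NotSet);
    set(Engine::StdFs);
    assert_eq!(get(), Engine::StdFs);
}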

View File

@@ -1,6 +1,6 @@
//! Enum-dispatch to the `OpenOptions` type of the respective [`super::IoEngineKind`];
use super::IoEngineKind;
use super::io_engine::IoEngine;
use std::{os::fd::OwnedFd, path::Path};
#[derive(Debug, Clone)]
@@ -13,9 +13,10 @@ pub enum OpenOptions {
impl Default for OpenOptions {
fn default() -> Self {
match super::io_engine::get() {
IoEngineKind::StdFs => Self::StdFs(std::fs::OpenOptions::new()),
IoEngine::NotSet => panic!("io engine not set"),
IoEngine::StdFs => Self::StdFs(std::fs::OpenOptions::new()),
#[cfg(target_os = "linux")]
IoEngineKind::TokioEpollUring => {
IoEngine::TokioEpollUring => {
Self::TokioEpollUring(tokio_epoll_uring::ops::open_at::OpenOptions::new())
}
}

View File

@@ -314,6 +314,9 @@ lfc_change_limit_hook(int newval, void *extra)
lfc_ctl->used -= 1;
}
lfc_ctl->limit = new_size;
if (new_size == 0) {
lfc_ctl->generation += 1;
}
neon_log(DEBUG1, "set local file cache limit to %d", new_size);
LWLockRelease(lfc_lock);

View File

@@ -11,16 +11,23 @@
#include "postgres.h"
#include "fmgr.h"
#include "miscadmin.h"
#include "access/xact.h"
#include "access/xlog.h"
#include "storage/buf_internals.h"
#include "storage/bufmgr.h"
#include "catalog/pg_type.h"
#include "postmaster/bgworker.h"
#include "postmaster/interrupt.h"
#include "replication/slot.h"
#include "replication/walsender.h"
#include "storage/procsignal.h"
#include "tcop/tcopprot.h"
#include "funcapi.h"
#include "access/htup_details.h"
#include "utils/pg_lsn.h"
#include "utils/guc.h"
#include "utils/wait_event.h"
#include "neon.h"
#include "walproposer.h"
@@ -30,6 +37,130 @@
PG_MODULE_MAGIC;
void _PG_init(void);
static int logical_replication_max_time_lag = 3600;
static void
InitLogicalReplicationMonitor(void)
{
BackgroundWorker bgw;
DefineCustomIntVariable(
"neon.logical_replication_max_time_lag",
"Threshold for dropping unused logical replication slots",
NULL,
&logical_replication_max_time_lag,
3600, 0, INT_MAX,
PGC_SIGHUP,
GUC_UNIT_S,
NULL, NULL, NULL);
memset(&bgw, 0, sizeof(bgw));
bgw.bgw_flags = BGWORKER_SHMEM_ACCESS;
bgw.bgw_start_time = BgWorkerStart_RecoveryFinished;
snprintf(bgw.bgw_library_name, BGW_MAXLEN, "neon");
snprintf(bgw.bgw_function_name, BGW_MAXLEN, "LogicalSlotsMonitorMain");
snprintf(bgw.bgw_name, BGW_MAXLEN, "Logical replication monitor");
snprintf(bgw.bgw_type, BGW_MAXLEN, "Logical replication monitor");
bgw.bgw_restart_time = 5;
bgw.bgw_notify_pid = 0;
bgw.bgw_main_arg = (Datum) 0;
RegisterBackgroundWorker(&bgw);
}
typedef struct
{
NameData name;
bool dropped;
XLogRecPtr confirmed_flush_lsn;
TimestampTz last_updated;
} SlotStatus;
/*
* Unused logical replication slots pin WAL and prevent deletion of snapshots.
*/
PGDLLEXPORT void
LogicalSlotsMonitorMain(Datum main_arg)
{
SlotStatus* slots;
TimestampTz now, last_checked;
/* Establish signal handlers. */
pqsignal(SIGUSR1, procsignal_sigusr1_handler);
pqsignal(SIGHUP, SignalHandlerForConfigReload);
pqsignal(SIGTERM, die);
BackgroundWorkerUnblockSignals();
slots = (SlotStatus*)calloc(max_replication_slots, sizeof(SlotStatus));
last_checked = GetCurrentTimestamp();
for (;;)
{
(void) WaitLatch(MyLatch,
WL_LATCH_SET | WL_EXIT_ON_PM_DEATH | WL_TIMEOUT,
logical_replication_max_time_lag*1000/2,
PG_WAIT_EXTENSION);
ResetLatch(MyLatch);
CHECK_FOR_INTERRUPTS();
now = GetCurrentTimestamp();
if (now - last_checked > logical_replication_max_time_lag*USECS_PER_SEC)
{
int n_active_slots = 0;
last_checked = now;
LWLockAcquire(ReplicationSlotControlLock, LW_SHARED);
for (int i = 0; i < max_replication_slots; i++)
{
ReplicationSlot *s = &ReplicationSlotCtl->replication_slots[i];
/* Consider only logical replication slots */
if (!s->in_use || !SlotIsLogical(s))
continue;
if (s->active_pid != 0)
{
n_active_slots += 1;
continue;
}
/* Check if there was some activity with the slot since last check */
if (s->data.confirmed_flush != slots[i].confirmed_flush_lsn)
{
slots[i].confirmed_flush_lsn = s->data.confirmed_flush;
slots[i].last_updated = now;
}
else if (now - slots[i].last_updated > logical_replication_max_time_lag*USECS_PER_SEC)
{
slots[i].name = s->data.name;
slots[i].dropped = true;
}
}
LWLockRelease(ReplicationSlotControlLock);
/*
* If there are no active subscriptions, then no new snapshots are generated,
* so there is no need to force slot deletion.
*/
if (n_active_slots != 0)
{
for (int i = 0; i < max_replication_slots; i++)
{
if (slots[i].dropped)
{
elog(LOG, "Drop logical replication slot because it was not update more than %ld seconds",
(now - slots[i].last_updated)/USECS_PER_SEC);
ReplicationSlotDrop(slots[i].name.data, true);
slots[i].dropped = false;
}
}
}
}
}
}
void
_PG_init(void)
{
@@ -44,6 +175,8 @@ _PG_init(void)
pg_init_libpagestore();
pg_init_walproposer();
InitLogicalReplicationMonitor();
InitControlPlaneConnector();
pg_init_extension_server();

View File

@@ -19,6 +19,7 @@ chrono.workspace = true
clap.workspace = true
consumption_metrics.workspace = true
dashmap.workspace = true
env_logger.workspace = true
futures.workspace = true
git-version.workspace = true
hashbrown.workspace = true
@@ -59,6 +60,8 @@ scopeguard.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true
smol_str.workspace = true
smallvec.workspace = true
socket2.workspace = true
sync_wrapper.workspace = true
task-local-extensions.workspace = true
@@ -75,6 +78,7 @@ tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true
url.workspace = true
urlencoding.workspace = true
utils.workspace = true
uuid.workspace = true
webpki-roots.workspace = true
@@ -83,7 +87,6 @@ native-tls.workspace = true
postgres-native-tls.workspace = true
postgres-protocol.workspace = true
redis.workspace = true
smol_str.workspace = true
workspace_hack.workspace = true

View File

@@ -5,7 +5,8 @@ pub use backend::BackendType;
mod credentials;
pub use credentials::{
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint, IpPattern,
check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
ComputeUserInfoParseError, IpPattern,
};
mod password_hack;
@@ -14,8 +15,12 @@ use password_hack::PasswordHackPayload;
mod flow;
pub use flow::*;
use tokio::time::error::Elapsed;
use crate::{console, error::UserFacingError};
use crate::{
console,
error::{ReportableError, UserFacingError},
};
use std::io;
use thiserror::Error;
@@ -67,6 +72,9 @@ pub enum AuthErrorImpl {
#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,
#[error("Authentication timed out")]
UserTimeout(Elapsed),
}
#[derive(Debug, Error)]
@@ -93,6 +101,10 @@ impl AuthError {
pub fn is_auth_failed(&self) -> bool {
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
}
pub fn user_timeout(elapsed: Elapsed) -> Self {
AuthErrorImpl::UserTimeout(elapsed).into()
}
}
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
@@ -116,6 +128,27 @@ impl UserFacingError for AuthError {
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed => self.to_string(),
TooManyConnections => self.to_string(),
UserTimeout(_) => self.to_string(),
}
}
}
impl ReportableError for AuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.get_error_kind(),
GetAuthInfo(e) => e.get_error_kind(),
WakeCompute(e) => e.get_error_kind(),
Sasl(e) => e.get_error_kind(),
AuthFailed(_) => crate::error::ErrorKind::User,
BadAuthMethod(_) => crate::error::ErrorKind::User,
MalformedPassword(_) => crate::error::ErrorKind::User,
MissingEndpointName => crate::error::ErrorKind::User,
Io(_) => crate::error::ErrorKind::ClientDisconnect,
IpAddressNotAllowed => crate::error::ErrorKind::User,
TooManyConnections => crate::error::ErrorKind::RateLimit,
UserTimeout(_) => crate::error::ErrorKind::User,
}
}
}

View File

@@ -68,6 +68,7 @@ pub trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
}
impl std::fmt::Display for BackendType<'_, ()> {
@@ -358,6 +359,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
}
impl BackendType<'_, ComputeUserInfo> {
pub async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Link(_) => Ok(Cached::new_uncached(None)),
}
}
pub async fn get_allowed_ips_and_secret(
&self,
ctx: &mut RequestMonitoring,

View File

@@ -45,9 +45,9 @@ pub(super) async fn authenticate(
}
)
.await
.map_err(|error| {
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
auth::AuthError::user_timeout(e)
})??;
let client_key = match auth_outcome {

View File

@@ -2,7 +2,7 @@ use crate::{
auth, compute,
console::{self, provider::NodeInfo},
context::RequestMonitoring,
error::UserFacingError,
error::{ReportableError, UserFacingError},
stream::PqStream,
waiters,
};
@@ -14,10 +14,6 @@ use tracing::{info, info_span};
#[derive(Debug, Error)]
pub enum LinkAuthError {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),
#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),
@@ -30,10 +26,16 @@ pub enum LinkAuthError {
impl UserFacingError for LinkAuthError {
fn to_string_client(&self) -> String {
use LinkAuthError::*;
"Internal error".to_string()
}
}
impl ReportableError for LinkAuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
AuthFailed(_) => self.to_string(),
_ => "Internal error".to_string(),
LinkAuthError::WaiterRegister(_) => crate::error::ErrorKind::Service,
LinkAuthError::WaiterWait(_) => crate::error::ErrorKind::Service,
LinkAuthError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}

View File

@@ -1,8 +1,12 @@
//! User credentials used in authentication.
use crate::{
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, serverless::SERVERLESS_DRIVER_SNI,
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
};
use itertools::Itertools;
@@ -39,6 +43,12 @@ pub enum ComputeUserInfoParseError {
impl UserFacingError for ComputeUserInfoParseError {}
impl ReportableError for ComputeUserInfoParseError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]

View File

@@ -167,7 +167,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
}
pub(super) fn validate_password_and_exchange(
pub(crate) fn validate_password_and_exchange(
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {

View File

@@ -240,7 +240,9 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
?unexpected,
"unexpected startup packet, rejecting connection"
);
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
stream
.throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User)
.await?
}
}
}
@@ -272,5 +274,10 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;
let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::passthrough::proxy_pass(ctx, tls_stream, client, metrics_aux).await
// doesn't yet matter as pg-sni-router doesn't report analytics logs
ctx.set_success();
ctx.log();
proxy::proxy::passthrough::proxy_pass(tls_stream, client, metrics_aux).await
}

View File

@@ -88,6 +88,12 @@ struct ProxyCliArgs {
/// path to directory with TLS certificates for client postgres connections
#[clap(long)]
certs_dir: Option<String>,
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// timeout for the control plane requests
#[clap(long, default_value = "70s", value_parser = humantime::parse_duration)]
cplane_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
@@ -165,6 +171,10 @@ struct SqlOverHttpArgs {
#[clap(long, default_value_t = 20)]
sql_over_http_pool_max_conns_per_endpoint: usize,
/// How many connections to pool in total across all endpoints. Excess connections are discarded
#[clap(long, default_value_t = 20000)]
sql_over_http_pool_max_total_conns: usize,
/// How long pooled connections should remain idle for before closing
#[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
sql_over_http_idle_timeout: tokio::time::Duration,
@@ -361,7 +371,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
tokio::spawn(locks.garbage_collect_worker(epoch));
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
let endpoint = http::Endpoint::new(
url,
http::new_client(rate_limiter_config, args.cplane_timeout),
);
let api = console::provider::neon::Api::new(endpoint, caches, locks);
let api = console::provider::ConsoleBackend::Console(api);
@@ -387,6 +400,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
};
let authentication_config = AuthenticationConfig {
@@ -406,6 +420,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
endpoint_rps_limit,
handshake_timeout: args.handshake_timeout,
// TODO: add this argument
region: args.region.clone(),
}));
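For reference, a minimal sketch (not part of this diff) of how the humantime-style values accepted by the new --handshake-timeout and --cplane-timeout flags map to Durations; it only exercises humantime::parse_duration on the default values shown above ("15s", "70s", "5m"):

// Illustration only: parses the default values used by the new timeout flags.
fn main() {
    let handshake = humantime::parse_duration("15s").expect("valid duration");
    let cplane = humantime::parse_duration("70s").expect("valid duration");
    let idle = humantime::parse_duration("5m").expect("valid duration");
    assert_eq!(handshake.as_secs(), 15);
    assert_eq!(cplane.as_secs(), 70);
    assert_eq!(idle.as_secs(), 300);
    println!("handshake={handshake:?} cplane={cplane:?} idle={idle:?}");
}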

View File

@@ -1,24 +1,45 @@
use anyhow::Context;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use crate::error::ReportableError;
/// Enables serving `CancelRequest`s.
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
#[derive(Debug, Error)]
pub enum CancelError {
#[error("{0}")]
IO(#[from] std::io::Error),
#[error("{0}")]
Postgres(#[from] tokio_postgres::Error),
}
impl ReportableError for CancelError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
CancelError::IO(_) => crate::error::ErrorKind::Compute,
CancelError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
}
}
}
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
// NB: we should immediately release the lock after cloning the token.
let cancel_closure = self
.0
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("query cancellation key not found: {key}"))?;
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
return Ok(());
};
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
@@ -81,7 +102,7 @@ impl CancelClosure {
}
/// Cancels the query running on user's compute node.
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
async fn try_cancel_query(self) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;

View File

@@ -1,6 +1,10 @@
use crate::{
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::errors::WakeComputeError,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
@@ -58,6 +62,20 @@ impl UserFacingError for ConnectionError {
}
}
impl ReportableError for ConnectionError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
}
}
}
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;

View File

@@ -22,6 +22,7 @@ pub struct ProxyConfig {
pub disable_ip_check_for_http: bool,
pub endpoint_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
}
#[derive(Debug)]

View File

@@ -20,7 +20,7 @@ use tracing::info;
pub mod errors {
use crate::{
error::{io_error, UserFacingError},
error::{io_error, ReportableError, UserFacingError},
http,
proxy::retry::ShouldRetry,
};
@@ -81,6 +81,15 @@ pub mod errors {
}
}
impl ReportableError for ApiError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl ShouldRetry for ApiError {
fn could_retry(&self) -> bool {
match self {
@@ -150,6 +159,16 @@ pub mod errors {
}
}
}
impl ReportableError for GetAuthInfoError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
#[derive(Debug, Error)]
pub enum WakeComputeError {
#[error("Console responded with a malformed compute address: {0}")]
@@ -194,6 +213,16 @@ pub mod errors {
}
}
}
impl ReportableError for WakeComputeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_kind(),
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
}
}
}
}
/// Auth secret which is managed by the cloud.

View File

@@ -188,6 +188,7 @@ impl super::Api for Api {
ep,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);
}
// When we just got a secret, we don't need to invalidate it.
Ok(Cached::new_uncached(auth_info.secret))
@@ -221,6 +222,7 @@ impl super::Api for Api {
self.caches
.project_info
.insert_allowed_ips(&project_id, ep, allowed_ips.clone());
ctx.set_project_id(project_id);
}
Ok((
Cached::new_uncached(allowed_ips),

View File

@@ -8,8 +8,10 @@ use tokio::sync::mpsc;
use uuid::Uuid;
use crate::{
console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
EndpointId, ProjectId, RoleName,
console::messages::MetricsAuxInfo,
error::ErrorKind,
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
BranchId, EndpointId, ProjectId, RoleName,
};
pub mod parquet;
@@ -89,6 +91,10 @@ impl RequestMonitoring {
self.project = Some(x.project_id);
}
pub fn set_project_id(&mut self, project_id: ProjectId) {
self.project = Some(project_id);
}
pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
crate::metrics::CONNECTING_ENDPOINTS
.with_label_values(&[self.protocol])
@@ -104,6 +110,18 @@ impl RequestMonitoring {
self.user = Some(user);
}
pub fn set_error_kind(&mut self, kind: ErrorKind) {
ERROR_BY_KIND
.with_label_values(&[kind.to_metric_label()])
.inc();
if let Some(ep) = &self.endpoint_id {
ENDPOINT_ERRORS_BY_KIND
.with_label_values(&[kind.to_metric_label()])
.measure(ep);
}
self.error_kind = Some(kind);
}
pub fn set_success(&mut self) {
self.success = true;
}
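For context, a minimal sketch (not from this PR) of the counter half of set_error_kind: every classified error bumps a labelled counter. It uses the plain prometheus and once_cell crates for illustration; the proxy itself uses the workspace metrics wrappers and the HyperLogLog vector shown elsewhere in this diff, and the metric name below is made up (the real one is proxy_errors_total).

use once_cell::sync::Lazy;
use prometheus::{register_int_counter_vec, IntCounterVec};

// Hypothetical metric for illustration only.
static EXAMPLE_ERRORS_TOTAL: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "example_errors_total",
        "Number of errors by classification (illustration only)",
        &["type"]
    )
    .unwrap()
});

fn record_error(kind_label: &'static str) {
    // Each classified error increments the counter for its label.
    EXAMPLE_ERRORS_TOTAL.with_label_values(&[kind_label]).inc();
}

fn main() {
    record_error("user");
    record_error("controlplane");
    println!(
        "user errors recorded: {}",
        EXAMPLE_ERRORS_TOTAL.with_label_values(&["user"]).get()
    );
}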

View File

@@ -108,7 +108,7 @@ impl From<RequestMonitoring> for RequestData {
branch: value.branch.as_deref().map(String::from),
protocol: value.protocol,
region: value.region,
error: value.error_kind.as_ref().map(|e| e.to_str()),
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,
duration_us: SystemTime::from(value.first_packet)
.elapsed()

View File

@@ -17,7 +17,7 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: fmt::Display {
pub trait UserFacingError: ReportableError {
/// Format the error for client, stripping all sensitive info.
///
/// Although this might be a no-op for many types, it's highly
@@ -29,13 +29,13 @@ pub trait UserFacingError: fmt::Display {
}
}
#[derive(Clone)]
#[derive(Copy, Clone, Debug)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,
/// Network error between user and proxy. Not necessarily user error
Disconnect,
ClientDisconnect,
/// Proxy self-imposed rate limits
RateLimit,
@@ -46,6 +46,9 @@ pub enum ErrorKind {
/// Error communicating with control plane
ControlPlane,
/// Postgres error
Postgres,
/// Error communicating with compute
Compute,
}
@@ -54,11 +57,36 @@ impl ErrorKind {
pub fn to_str(&self) -> &'static str {
match self {
ErrorKind::User => "request failed due to user error",
ErrorKind::Disconnect => "client disconnected",
ErrorKind::ClientDisconnect => "client disconnected",
ErrorKind::RateLimit => "request cancelled due to rate limit",
ErrorKind::Service => "internal service error",
ErrorKind::ControlPlane => "non-retryable control plane error",
ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)",
ErrorKind::Postgres => "postgres error",
ErrorKind::Compute => {
"non-retryable compute connection error (or exhausted retry capacity)"
}
}
}
pub fn to_metric_label(&self) -> &'static str {
match self {
ErrorKind::User => "user",
ErrorKind::ClientDisconnect => "clientdisconnect",
ErrorKind::RateLimit => "ratelimit",
ErrorKind::Service => "service",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "postgres",
ErrorKind::Compute => "compute",
}
}
}
pub trait ReportableError: fmt::Display + Send + 'static {
fn get_error_kind(&self) -> ErrorKind;
}
impl ReportableError for tokio::time::error::Elapsed {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::RateLimit
}
}
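A minimal self-contained sketch (not part of this PR) of the classification pattern the ReportableError trait enables: an error type maps itself to an ErrorKind, and callers only need the trait object to pick a metric label. The TimedOut type and kind_label helper are hypothetical, and the enum is trimmed to three variants; the real definitions are the ones above.

use std::fmt;

#[derive(Copy, Clone, Debug)]
enum ErrorKind {
    User,
    RateLimit,
    Service,
}

trait ReportableError: fmt::Display + Send + 'static {
    fn get_error_kind(&self) -> ErrorKind;
}

// Hypothetical error type, used only for illustration.
#[derive(Debug)]
struct TimedOut;

impl fmt::Display for TimedOut {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "operation timed out")
    }
}

impl ReportableError for TimedOut {
    fn get_error_kind(&self) -> ErrorKind {
        ErrorKind::RateLimit
    }
}

// Callers classify any error through the trait without knowing its concrete type.
fn kind_label(e: &dyn ReportableError) -> &'static str {
    match e.get_error_kind() {
        ErrorKind::User => "user",
        ErrorKind::RateLimit => "ratelimit",
        ErrorKind::Service => "service",
    }
}

fn main() {
    let err = TimedOut;
    println!("error: {err}, kind: {}", kind_label(&err));
}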

View File

@@ -19,10 +19,14 @@ use reqwest_middleware::RequestBuilder;
/// This is the preferred way to create new http clients,
/// because it takes care of observability (OpenTelemetry).
/// We deliberately don't want to replace this with a public static.
pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware {
pub fn new_client(
rate_limiter_config: rate_limiter::RateLimiterConfig,
timeout: Duration,
) -> ClientWithMiddleware {
let client = reqwest::ClientBuilder::new()
.dns_resolver(Arc::new(GaiResolver::default()))
.connection_verbose(true)
.timeout(timeout)
.build()
.expect("Failed to create http client");

View File

@@ -1,8 +1,10 @@
use ::metrics::{
exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge_vec, Histogram,
HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGaugeVec,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge,
register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec,
IntCounterVec, IntGauge, IntGaugeVec,
};
use metrics::{register_int_counter_pair, IntCounterPair};
use once_cell::sync::Lazy;
use tokio::time;
@@ -112,6 +114,44 @@ pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
.unwrap()
});
pub static HTTP_CONTENT_LENGTH: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_conn_content_length_bytes",
"Time it took for proxy to establish a connection to the compute endpoint",
// largest bucket = 3^16 * 0.05ms = 2.15s
exponential_buckets(8.0, 2.0, 20).unwrap()
)
.unwrap()
});
pub static GC_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_http_pool_reclaimation_lag_seconds",
"Time it takes to reclaim unused connection pools",
// 1us -> 65ms
exponential_buckets(1e-6, 2.0, 16).unwrap(),
)
.unwrap()
});
pub static ENDPOINT_POOLS: Lazy<IntCounterPair> = Lazy::new(|| {
register_int_counter_pair!(
"proxy_http_pool_endpoints_registered_total",
"Number of endpoints we have registered pools for",
"proxy_http_pool_endpoints_unregistered_total",
"Number of endpoints we have unregistered pools for",
)
.unwrap()
});
pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"proxy_http_pool_opened_connections",
"Number of opened connections to a database.",
)
.unwrap()
});
#[derive(Clone)]
pub struct LatencyTimer {
// time since the stopwatch was started
@@ -234,3 +274,22 @@ pub static CONNECTING_ENDPOINTS: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
)
.unwrap()
});
pub static ERROR_BY_KIND: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_errors_total",
"Number of errors by a given classification",
&["type"],
)
.unwrap()
});
pub static ENDPOINT_ERRORS_BY_KIND: Lazy<HyperLogLogVec<32>> = Lazy::new(|| {
register_hll_vec!(
32,
"proxy_endpoints_affected_by_errors",
"Number of endpoints affected by errors of a given classification",
&["type"],
)
.unwrap()
});

View File

@@ -13,9 +13,10 @@ use crate::{
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
protocol2::WithClientIp,
proxy::{handshake::handshake, passthrough::proxy_pass},
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey,
@@ -28,14 +29,17 @@ use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use self::connect_compute::{connect_to_compute, TcpMechanism};
use self::{
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
@@ -98,14 +102,14 @@ pub async fn task_main(
bail!("missing required client IP");
}
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
socket
.inner
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
let res = handle_client(
config,
&mut ctx,
cancel_map,
@@ -113,7 +117,26 @@ pub async fn task_main(
ClientMode::Tcp,
endpoint_rate_limiter,
)
.await
.await;
match res {
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
ctx.log();
Err(e.into())
}
Ok(None) => {
ctx.set_success();
ctx.log();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log();
p.proxy_pass().await
}
}
}
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
@@ -169,6 +192,37 @@ impl ClientMode {
}
}
#[derive(Debug, Error)]
// Almost all errors should be reported to the user, but there are a few cases where we cannot:
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons.
// 2. Handshake: the handshake reports errors itself if it can; if it fails due to a protocol violation,
// we cannot be sure the client even understands our error message.
// 3. PrepareClient: the client disconnected, so we can't tell them anyway.
pub enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] handshake::HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
@@ -176,7 +230,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
info!(
protocol = ctx.protocol,
"handling interactive connection from client"
@@ -193,11 +247,17 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let tls = config.tls_config.as_ref();
let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
let do_handshake = handshake(stream, mode.handshake_tls(tls));
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancel_map
.cancel_session(cancel_key_data)
.await
.map(|()| None)?)
}
};
drop(pause);
let hostname = mode.hostname(stream.get_ref());
@@ -221,7 +281,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
.await?;
}
}
@@ -241,7 +301,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream.throw_error(e).instrument(params_span).await;
return stream.throw_error(e).instrument(params_span).await?;
}
};
@@ -267,7 +327,13 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(ctx, stream, node.stream, aux).await
Ok(Some(ProxyPassthrough {
client: stream,
compute: node,
aux,
req: _request_gauge,
conn: _client_gauge,
}))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
@@ -276,7 +342,7 @@ async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: &cancellation::Session,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
) -> Result<(), std::io::Error> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());

View File

@@ -34,21 +34,6 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg
node_info.invalidate().config
}
/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)]
async fn connect_to_compute_once(
ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, compute::ConnectionError> {
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(ctx, allow_self_signed_compute, timeout)
.await
}
#[async_trait]
pub trait ConnectMechanism {
type Connection;
@@ -75,13 +60,18 @@ impl ConnectMechanism for TcpMechanism<'_> {
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
async fn connect_once(
&self,
ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
connect_to_compute_once(ctx, node_info, timeout).await
let allow_self_signed_compute = node_info.allow_self_signed_compute;
node_info
.config
.connect(ctx, allow_self_signed_compute, timeout)
.await
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {

View File

@@ -1,15 +1,60 @@
use anyhow::{bail, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use crate::{
cancellation::CancelMap,
config::TlsConfig,
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
stream::{PqStream, Stream},
error::ReportableError,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
};
#[derive(Error, Debug)]
pub enum HandshakeError {
#[error("data is sent before server replied with EncryptionResponse")]
EarlyData,
#[error("protocol violation")]
ProtocolViolation,
#[error("missing certificate")]
MissingCertificate,
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
#[error("{0}")]
Io(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for HandshakeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
// This error should not happen, but can occur if we have no default certificate and
// the client sends no SNI extension.
// If the client provides SNI, we can be sure there is a matching certificate.
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
},
HandshakeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
HandshakeError::ReportedError(e) => e.get_error_kind(),
}
}
}
pub enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData),
}
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
@@ -18,8 +63,7 @@ use crate::{
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
) -> Result<HandshakeData<S>, HandshakeError> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -49,14 +93,14 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
return Err(HandshakeError::EarlyData);
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
.ok_or(HandshakeError::MissingCertificate)?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
@@ -64,7 +108,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
_ => return Err(HandshakeError::ProtocolViolation),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
@@ -73,23 +117,23 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
_ => return Err(HandshakeError::ProtocolViolation),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
return stream
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User)
.await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
break Ok(HandshakeData::Startup(stream, params));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
info!(session_type = "cancellation", "successful handshake");
break Ok(None);
break Ok(HandshakeData::Cancel(cancel_key_data));
}
}
}

View File

@@ -1,9 +1,11 @@
use crate::{
compute::PostgresConnection,
console::messages::MetricsAuxInfo,
context::RequestMonitoring,
metrics::NUM_BYTES_PROXIED_COUNTER,
stream::Stream,
usage_metrics::{Ids, USAGE_METRICS},
};
use metrics::IntCounterPairGuard;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
@@ -11,14 +13,10 @@ use utils::measured_stream::MeasuredStream;
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
@@ -51,3 +49,18 @@ pub async fn proxy_pass(
Ok(())
}
pub struct ProxyPassthrough<S> {
pub client: Stream<S>,
pub compute: PostgresConnection,
pub aux: MetricsAuxInfo,
pub req: IntCounterPairGuard,
pub conn: IntCounterPairGuard,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
pub async fn proxy_pass(self) -> anyhow::Result<()> {
proxy_pass(self.client, self.compute.stream, self.aux).await
}
}

View File

@@ -163,11 +163,11 @@ async fn dummy_proxy(
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let client = WithClientIp::new(client);
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
.await?
.context("handshake failed")?;
let mut stream = match handshake(client, tls.as_ref()).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
auth.authenticate(&mut stream).await?;
@@ -478,6 +478,9 @@ impl TestBackend for TestConnectMechanism {
{
unimplemented!("not used in tests")
}
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
}
fn helper_create_cached_node_info() -> CachedNodeInfo {

View File

@@ -35,12 +35,10 @@ async fn proxy_mitm(
tokio::spawn(async move {
// begin handshake with end_server
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
// process handshake with end_client
let (end_client, startup) =
handshake(client1, Some(&server_config1), &CancelMap::default())
.await
.unwrap()
.unwrap();
let (end_client, startup) = match handshake(client1, Some(&server_config1)).await.unwrap() {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
let (end_client, buf) = end_client.framed.into_inner();

View File

@@ -10,7 +10,7 @@ mod channel_binding;
mod messages;
mod stream;
use crate::error::UserFacingError;
use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;
@@ -48,6 +48,18 @@ impl UserFacingError for Error {
}
}
impl ReportableError for Error {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
Error::MissingBinding => crate::error::ErrorKind::Service,
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
/// A convenient result type for SASL exchange.
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -2,6 +2,7 @@
//!
//! Handles both SQL over HTTP and SQL over Websockets.
mod backend;
mod conn_pool;
mod json;
mod sql_over_http;
@@ -18,11 +19,11 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::{cancellation::CancelMap, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
@@ -54,12 +55,13 @@ pub async fn task_main(
info!("websocket server has shut down");
}
let conn_pool = conn_pool::GlobalConnPool::new(config);
let conn_pool2 = Arc::clone(&conn_pool);
tokio::spawn(async move {
conn_pool2.gc_worker(StdRng::from_entropy()).await;
});
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
{
let conn_pool = Arc::clone(&conn_pool);
tokio::spawn(async move {
conn_pool.gc_worker(StdRng::from_entropy()).await;
});
}
// shutdown the connection pool
tokio::spawn({
@@ -73,6 +75,11 @@ pub async fn task_main(
}
});
let backend = Arc::new(PoolingBackend {
pool: Arc::clone(&conn_pool),
config,
});
let tls_config = match config.tls_config.as_ref() {
Some(config) => config,
None => {
@@ -102,11 +109,10 @@ pub async fn task_main(
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
let (io, tls) = stream.get_ref();
let (io, _) = stream.get_ref();
let client_addr = io.client_addr();
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -118,8 +124,7 @@ pub async fn task_main(
};
Ok(MetricService::new(hyper::service::service_fn(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -130,12 +135,10 @@ pub async fn task_main(
request_handler(
req,
config,
tls_config,
conn_pool,
backend,
ws_connections,
cancel_map,
session_id,
sni_name,
peer_addr.ip(),
endpoint_rate_limiter,
)
@@ -200,12 +203,10 @@ where
async fn request_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
tls: &'static TlsConfig,
conn_pool: Arc<conn_pool::GlobalConnPool>,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
sni_hostname: Option<String>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Body>, ApiError> {
@@ -225,11 +226,11 @@ async fn request_handler(
ws_connections.spawn(
async move {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
let ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
if let Err(e) = websocket::serve_websocket(
config,
&mut ctx,
ctx,
websocket,
cancel_map,
host,
@@ -246,17 +247,9 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
sql_over_http::handle(
tls,
&config.http_config,
&mut ctx,
request,
sni_hostname,
conn_pool,
)
.await
sql_over_http::handle(config, ctx, request, backend).await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")

Some files were not shown because too many files have changed in this diff.