Compare commits

...

49 Commits

Author SHA1 Message Date
Vlad Lazar
950df2668c fixup: hakari 2025-07-30 16:10:41 +01:00
Vlad Lazar
d9774e2c67 fixup: revert lock file changes 2025-07-30 10:27:25 +01:00
Vlad Lazar
dc25f1de90 Merge remote-tracking branch 'origin/main' into vlad/port-storcon-persistence 2025-07-29 19:06:18 +01:00
Vlad Lazar
2f9cc9a11e tests: remove tests that depend on libpq 2025-07-29 17:35:47 +01:00
Erik Grinaker
65d1be6e90 pageserver: route gRPC requests to child shards (#12702)
## Problem

During shard splits, each parent shard is split and removed
incrementally. Only when all parent shards have split is the split
committed and the compute notified. This can take several minutes for
large tenants. In the meanwhile, the compute will be sending requests to
the (now-removed) parent shards.

This was (mostly) not a problem for the libpq protocol, because it does
shard routing on the server-side. The compute just sends requests to
some Pageserver, and the server will figure out which local shard should
serve it.

It is a problem for the gRPC protocol, where the client explicitly says
which shard it's talking to.

Touches [LKB-191](https://databricks.atlassian.net/browse/LKB-191).
Requires #12772.

## Summary of changes

* Add server-side routing of gRPC requests to any local child shards if
the parent does not exist.
* Add server-side splitting of GetPage batch requests straddling
multiple child shards.
* Move the `GetPageSplitter` into `pageserver_page_api`.

I really don't like this approach, but it avoids making changes to the
split protocol. I could be convinced we should change the split protocol
instead, e.g. to keep the parent shard alive until the split commits and
the compute has been notified, but we can also do that as a later change
without blocking the communicator on it.
2025-07-29 16:28:57 +00:00
Suhas Thalanki
16eb8dda3d some compute ctl changes from hadron (#12760)
Some compute ctl changes from hadron
2025-07-29 16:01:56 +00:00
Heikki Linnakangas
bb32f1b3d0 Move 'criterion' to a dev-dependency (#12762)
It is only used in micro-benchmarks.
2025-07-29 15:35:00 +00:00
a-masterov
5585c32cee Disable autovacuum while running pg_repack test (#12755)
## Problem
Sometimes, the regression test of `pg_repack` fails due to an extra line
in the output.
The most probable cause of this is autovacuum.  
https://databricks.atlassian.net/browse/LKB-2637
## Summary of changes
Autovacuum is disabled during the test.

Co-authored-by: Alexey Masterov <alexey.masterov@databricks.com>
2025-07-29 15:34:02 +00:00
Krzysztof Szafrański
0ffdc98e20 [proxy] Classify "database not found" errors as user errors (#12603)
## Problem

If a user provides a wrong database name in the connection string, it
should be logged as a user error, not postgres error.

I found 4 different places where we log such errors:
1. `proxy/src/stream.rs:193`, e.g.:
```
{"timestamp":"2025-07-15T11:33:35.660026Z","level":"INFO","message":"forwarding error to user","fields":{"kind":"postgres","msg":"database \"[redacted]\" does not exist"},"spans":{"connect_request#9":{"protocol":"tcp","session_id":"ce1f2c90-dfb5-44f7-b9e9-8b8535e8b9b8","conn_info":"[redacted]","ep":"[redacted]","role":"[redacted]"}},"thread_id":22,"task_id":"370407867","target":"proxy::stream","src":"proxy/src/stream.rs:193","extract":{"ep":"[redacted]","session_id":"ce1f2c90-dfb5-44f7-b9e9-8b8535e8b9b8"}}
```
2. `proxy/src/pglb/mod.rs:137`, e.g.:
```
{"timestamp":"2025-07-15T11:37:44.340497Z","level":"WARN","message":"per-client task finished with an error: Couldn't connect to compute node: db error: FATAL: database \"[redacted]\" does not exist","spans":{"connect_request#8":{"protocol":"tcp","session_id":"763baaac-d039-4f4d-9446-c149e32660eb","conn_info":"[redacted]","ep":"[redacted]","role":"[redacted]"}},"thread_id":14,"task_id":"866658139","target":"proxy::pglb","src":"proxy/src/pglb/mod.rs:137","extract":{"ep":"[redacted]","session_id":"763baaac-d039-4f4d-9446-c149e32660eb"}}
```
3. `proxy/src/serverless/mod.rs:451`, e.g. (note that the error is
repeated 4 times — retries?):
```
{"timestamp":"2025-07-15T11:37:54.515891Z","level":"WARN","message":"error in websocket connection: Couldn't connect to compute node: db error: FATAL: database \"[redacted]\" does not exist: Couldn't connect to compute node: db error: FATAL: database \"[redacted]\" does not exist: db error: FATAL: database \"[redacted]\" does not exist: FATAL: database \"[redacted]\" does not exist","spans":{"http_conn#8":{"conn_id":"ec7780db-a145-4f0e-90df-0ba35f41b828"},"connect_request#9":{"protocol":"ws","session_id":"1eaaeeec-b671-4153-b1f4-247839e4b1c7","conn_info":"[redacted]","ep":"[redacted]","role":"[redacted]"}},"thread_id":10,"task_id":"366331699","target":"proxy::serverless","src":"proxy/src/serverless/mod.rs:451","extract":{"conn_id":"ec7780db-a145-4f0e-90df-0ba35f41b828","ep":"[redacted]","session_id":"1eaaeeec-b671-4153-b1f4-247839e4b1c7"}}
```
4. `proxy/src/serverless/sql_over_http.rs:219`, e.g.
```
{"timestamp":"2025-07-15T10:32:34.866603Z","level":"INFO","message":"forwarding error to user","fields":{"kind":"postgres","error":"could not connect to postgres in compute","msg":"database \"[redacted]\" does not exist"},"spans":{"http_conn#19":{"conn_id":"7da08203-5dab-45e8-809f-503c9019ec6b"},"connect_request#5":{"protocol":"http","session_id":"68387f1c-cbc8-45b3-a7db-8bb1c55ca809","conn_info":"[redacted]","ep":"[redacted]","role":"[redacted]"}},"thread_id":17,"task_id":"16432250","target":"proxy::serverless::sql_over_http","src":"proxy/src/serverless/sql_over_http.rs:219","extract":{"conn_id":"7da08203-5dab-45e8-809f-503c9019ec6b","ep":"[redacted]","session_id":"68387f1c-cbc8-45b3-a7db-8bb1c55ca809"}}
```

This PR directly addresses 1 and 4. I _think_ it _should_ also help with
2 and 3, although in those places we don't seem to log `kind`, so I'm
not quite sure. I'm also confused why in 3 the error is repeated
multiple times.

## Summary of changes

Resolves https://github.com/neondatabase/neon/issues/9440
2025-07-29 15:25:22 +00:00
HaoyuHuang
62d844e657 Add changes in spec apply (#12759)
## Problem
All changes are no-op. 

## Summary of changes
2025-07-29 15:22:04 +00:00
Alex Chi Z.
1bb434ab74 fix(test): test_readonly_node_gc compute needs time to acquire lease (#12747)
## Problem

Part of LKB-2368. Compute fails to obtain LSN lease in this test case.
There're many assumptions around how compute obtains the leases, and in
this particular test case, as the LSN lease length is only 8s (which is
shorter than the amount of time where pageserver can restart and compute
can reconnect in terms of force stop), it sometimes cause issues.

## Summary of changes

Add more sleeps around the test case to ensure it's stable at least. We
need to find a more reliable way to test this in the future.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-29 14:23:42 +00:00
Alex Chi Z.
dbde37c53a fix(safekeeper): retry if open segment fail (#12757)
## Problem

Fix LKB-2632.

The safekeeper wal read path does not seem to retry at all. This would
cause client read errors on the customer side.

## Summary of changes

- Retry on `safekeeper::wal_backup::read_object`.
- Note that this only retries on S3 HTTP connection errors. Subsequent
reads could fail, and that needs more refactors to make the retry
mechanism work across the path.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-29 14:20:43 +00:00
Heikki Linnakangas
5e3cb2ab07 Refactor LFC stats functions (#12696)
Split the functions into two parts: an internal function in file_cache.c
which returns an array of structs representing the result set, and
another function in neon.c with the glue code to expose it as a SQL
function. This is in preparation for the new communicator, which needs
to implement the same SQL functions, but getting the information from a
different place.

In the glue code, use the more modern Postgres way of building a result
set using a tuplestore.
2025-07-29 13:12:44 +00:00
Erik Grinaker
61f267d8f9 pageserver: only retry WaitForActiveTimeout during shard resolution (#12772)
## Problem

In https://github.com/neondatabase/neon/pull/12467, timeouts and retries
were added to `Cache::get` tenant shard resolution to paper over an
issue with read unavailability during shard splits. However, this
retries _all_ errors, including irrecoverable errors like `NotFound`.

This causes problems with gRPC child shard routing in #12702, which
targets specific shards with `ShardSelector::Known` and relies on prompt
`NotFound` errors to reroute requests to child shards. These retries
introduce a 1s delay for all reads during child routing.

The broader problem of read unavailability during shard splits is left
as future work, see https://databricks.atlassian.net/browse/LKB-672.

Touches #12702.
Touches [LKB-191](https://databricks.atlassian.net/browse/LKB-191).

## Summary of changes

* Change `TenantManager` to always return a concrete
`GetActiveTimelineError`.
* Only retry `WaitForActiveTimeout` errors.
* Lots of code unindentation due to the simplified error handling.

Out of caution, we do not gate the retries on `ShardSelector`, since
this can trigger other races. Improvements here are left as future work.
2025-07-29 12:33:02 +00:00
JC Grünhage
e2411818ef Add SBOMs and provenance attestations to container images (#12768)
## Problem
Given a container image it is difficult to figure out dependencies and
doesn't work automatically.

## Summary of changes
- Build all rust binaries with `cargo auditable`, to allow sbom scanners
to find it's dependencies.
- Adjust `attests` for `docker/build-push-action`, so that buildkit
creates sbom and provenance attestations.
- Dropping `--locked` for `rustfilt`, because `rustfilt` can't build
with locked dependencies[^5]

## Further details
Building with `cargo auditable`[^1] embeds a dependency list into Linux,
Windows, MacOS and WebAssembly artifacts. A bunch of tools support
discovering dependencies from this, among them `syft`[^2], which is used
by the BuildKit Syft scanner[^3] plugin. This BuildKit plugin is the
default[^4] used in docker for generating sbom attestations, but we're
making that default explicit by referencing the container image.
[^1]: https://github.com/rust-secure-code/cargo-auditable
[^2]: https://github.com/anchore/syft
[^3]: https://github.com/docker/buildkit-syft-scanner
[^4]:
https://docs.docker.com/build/metadata/attestations/sbom/#sbom-generator
[^5]: https://github.com/luser/rustfilt/issues/23
2025-07-29 12:12:14 +00:00
Dmitrii Kovalkov
58327cbba8 storcon: wait for the migration from the drained node in the draining loop (#12754)
## Problem
We have seen some errors in staging when the shard migration was
triggered by optimizations, and it was ongoing during draining the node
it was migrating from. It happens because the node draining loop only
waits for the migrations started by the drain loop itself. The ongoing
migrations are ignored.

Closes: https://databricks.atlassian.net/browse/LKB-1625

## Summary of changes
- Wait for the shard reconciliation during the drain if it is being
migrated from the drained node.
2025-07-29 11:58:31 +00:00
Heikki Linnakangas
568927a8a0 Remove unnecessary dependency to 'log' crate (#12763)
We use 'tracing' everywhere.
2025-07-29 11:08:22 +00:00
a-masterov
1ed7252950 Add a workaround for the clickhouse 24.9+ problem causing an error (#12767)
## Problem
We used ClickHouse v. 24.8, which is outdated, for logical replication
testing. We could miss some problems.
## Summary of changes
The version was updated to 25.6, with a workaround using the environment
variable `PGSSLCERT`.

Co-authored-by: Alexey Masterov <alexey.masterov@databricks.com>
2025-07-29 10:19:10 +00:00
Alexander Bayandin
30b57334ef test_lsn_lease_storcon: ignore ShardSplit warning in debug builds (#12770)
## Problem

`test_lsn_lease_storcon` might fail in debug builds due to slow
ShardSplit

## Summary of changes
- Make `test_lsn_lease_storcon ` test to ignore `.*Exclusive lock by
ShardSplit was held.*` warning in debug builds

Ref: https://databricks.slack.com/archives/C09254R641L/p1753777051481029
2025-07-29 09:47:39 +00:00
Vlad Lazar
e92ed85e9a Merge remote-tracking branch 'origin/main' into vlad/port-storcon-persistence 2025-07-29 10:45:21 +01:00
Vlad Lazar
07396d3fc7 fixup: massage imports for tests 2025-07-29 10:09:50 +01:00
Vlad Lazar
d5cd2de3cf fixup: allow unused code and fix blocking thread count 2025-07-29 10:09:50 +01:00
Heikki Linnakangas
d487ba2b9b Replace 'memoffset' crate with core functionality (#12761)
The `std::mem::offset_of` macro was introduced in Rust 1.77.0.

In the passing, mark the function as `const`, as suggested in the
comment. Not sure which compiler version that requires, but it works
with what have currently.
2025-07-29 08:01:31 +00:00
Conrad Ludgate
e7a1d5de94 proxy: cache for password hashing (#12011)
## Problem

Password hashing for sql-over-http takes up a lot of CPU. Perhaps we can
get away with temporarily caching some steps so we only need fewer
rounds, which will save some CPU time.

## Summary of changes

The output of pbkdf2 is the XOR of the outputs of each iteration round,
eg `U1 ^ U2 ^ ... U15 ^ U16 ^ U17 ^ ... ^ Un`. We cache the suffix of
the expression `U16 ^ U17 ^ ... ^ Un`. To compute the result from the
cached suffix, we only need to compute the prefix `U1 ^ U2 ^ ... U15`.
The suffix by itself is useless, which prevent's its use in brute-force
attacks should this cached memory leak.

We are also caching the full 4096 round hash in memory, which can be
used for brute-force attacks, where this suffix could be used to speed
it up. My hope/expectation is that since these will be in different
allocations, it makes any such memory exploitation much much harder.
Since the full hash cache might be invalidated while the suffix is
cached, I'm storing the timestamp of the computation as a way to
identity the match.

I also added `zeroize()` to clear the sensitive state from the
stack/heap.

For the most security conscious customers, we hope to roll out OIDC
soon, so they can disable passwords entirely.

---

The numbers for the threadpool were pretty random, but according to our
busiest region for sql-over-http, we only see about 150 unique endpoints
every minute. So storing ~100 of the most common endpoints for that
minute should be the vast majority of requests.

1 minute was chosen so we don't keep data in memory for too long.
2025-07-29 06:48:14 +00:00
Ivan Efremov
6be572177c chore: Fix nightly lints (#12746)
- Remove some unused code
- Use `is_multiple_of()` instead of '%'
- Collapse consecuative "if let" statements
- Elided lifetime fixes

It is enough just to review the code of your team
2025-07-28 21:36:30 +00:00
Alex Chi Z.
fe7a4e1ab6 fix(test): wait compaction in timeline offload test (#12673)
## Problem

close LKB-753. `test_pageserver_metrics_removed_after_offload` is
unstable and it sometimes leave the metrics behind after tenant
offloading. It turns out that we triggered an image compaction before
the offload and the job was stopped after the offload request was
completed.

## Summary of changes

Wait all background tasks to finish before checking the metrics.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-28 16:27:55 +00:00
Heikki Linnakangas
40cae8cc36 Fix misc typos and some cosmetic code cleanup (#12695) 2025-07-28 16:21:35 +00:00
Heikki Linnakangas
02fc8b7c70 Add compatibility macros for MyProcNumber and PGIOAlignedBlock (#12715)
There were a few uses of these already, so collect them to the
compatibility header to avoid the repetition and scattered #ifdefs.

The definition of MyProcNumber is a little different from what was used
before, but the end result is the same. (PGPROC->pgprocno values were
just assigned sequentially to all PGPROC array members, see
InitProcGlobal(). That's a bit silly, which is why it was removed in
v17.)
2025-07-28 15:05:36 +00:00
John Spray
60feb168e2 pageserver: decrease MAX_SHARDS in utilization (#12668)
## Problem

When tenants have a lot of timelines, the number of tenants that a
pageserver can comfortably handle goes down. Branching is much more
widely used in practice now than it was when this code was written, and
we generally run pageservers with a few thousand tenants (where each
tenant has many timelines), rather than the 10k-20k we might have done
historically.

This should really be something configurable, or a more direct proxy for
resource utilization (such as non-archived timeline count), but this
change should be a low effort improvement.

## Summary of changes

* Change the target shard count (MAX_SHARDS) to 2500 from 5000 when
calculating pageserver utilization (i.e. a 200% overcommit now
corresponds to 5000 shards, not 10000 shards)

Co-authored-by: John Spray <john.spray@databricks.com>
2025-07-28 13:50:18 +00:00
a-masterov
da596a5162 Update the versions for ClickHouse and Debezium (#12741)
## Problem
The test for logical replication used the year-old versions of
ClickHouse and Debezium so that we may miss problems related to
up-to-date versions.
## Summary of changes
The ClickHouse version has been updated to 24.8.
The Debezium version has been updated to the latest stable one,
3.1.3Final.
Some problems with locally running the Debezium test have been fixed.

---------

Co-authored-by: Alexey Masterov <alexey.masterov@databricks.com>
Co-authored-by: Alexander Bayandin <alexander@neon.tech>
2025-07-28 13:26:33 +00:00
Conrad Ludgate
effd6bf829 [proxy] add metrics for caches (#12752)
Exposes metrics for caches. LKB-2594

This exposes a high level namespace, `cache`, that all cache metrics can
be added to - this makes it easier to make library panels for the caches
as I understand it.

To calculate the current cache fill ratio, you could use the following
query:

```
(
    cache_inserted_total{cache="node_info"}
  - sum (cache_evicted_total{cache="node_info"}) without (cause)
)
  / cache_capacity{cache="node_info"}
```

To calculate the cache hit ratio, you could use the following query:

```
  cache_request_total{cache="node_info", outcome="hit"}
/ sum (cache_request_total{cache="node_info"}) without (outcome)
```
2025-07-28 10:41:49 +00:00
Tristan Partin
a6e0baf31a [BRC-1405] Mount databricks pg_hba and pg_ident from configmap (#12733)
## Problem

For certificate auth, we need to configure pg_hba and pg_ident for it to
work.

HCC needs to mount this config map to all pg compute pod.

## Summary of changes

Create `databricks_pg_hba` and `databricks_pg_ident` to configure where
the files are located on the pod. These configs are pass down to
`compute_ctl`. Compute_ctl uses these config to update `pg_hba.conf` and
`pg_ident.conf` file.

We append `include_if_exists {databricks_pg_hba}` to `pg_hba.conf` and
similarly to `pg_ident.conf`. So that it will refer to databricks config
file without much change to existing pg default config file.

---------

Co-authored-by: Jarupat Jisarojito <jarupat.jisarojito@databricks.com>
Co-authored-by: William Huang <william.huang@databricks.com>
Co-authored-by: HaoyuHuang <haoyu.huang.68@gmail.com>
2025-07-25 20:50:03 +00:00
Christian Schwarz
19b74b8837 fix(page_service): getpage requests don't hold applied_gc_cutoff_lsn guard (#12743)
Before this PR, getpage requests wouldn't hold the
`applied_gc_cutoff_lsn` guard until they were done.

Theoretical impact: if we’re not holding the `RcuReadGuard`, gc can
theoretically concurrently delete reconstruct data that we need to
reconstruct the page.

I don't think this practically occurs in production because the odds of
it happening are quite low, especially for primary read_write computes.
But RO replicas / standby_horizon relies on correct
`applied_gc_cutofff_lsn`, so, I'm fixing this as part of the work ok
replacing standby_horizon propagation mechanism with leases (LKB-88).

The change is feature-gated with a feature flag, and evaluated once when
entering `handle_pagestream` to avoid performance impact.

For observability, we add a field to the `handle_pagestream` span, and a
slow-log to the place in `gc_loop` where it waits for the in-flight
RcuReadGuard's to drain.

refs
- fixes https://databricks.atlassian.net/browse/LKB-2572
- standby_horizon leases epic:
https://databricks.atlassian.net/browse/LKB-2572

---------

Co-authored-by: Christian Schwarz <Christian Schwarz>
2025-07-25 20:25:04 +00:00
Folke Behrens
25718e324a proxy: Define service_info metric showing the run state (#12749)
## Problem

Monitoring dashboards show aggregates of all proxy instances, including
terminating ones. This can skew the results or make graphs less
readable. Also, alerts must be tuned to ignore certain signals from
terminating proxies.

## Summary of changes

Add a `service_info` metric currently with one label, `state`, showing
if an instance is in state `init`, `running`, or `terminating`. The
metric can be joined with other metrics to filter the presented time
series.
2025-07-25 18:27:21 +00:00
Dmitrii Kovalkov
ac8f44c70e tests: stop ps immediately in test_ps_unavailable_after_delete (#12728)
## Problem
test_ps_unavailable_after_delete is flaky. All test failures I've looked
at are because of ERROR log messages in pageserver, which happen because
storage controller tries runs a reconciliations during the graceful
shutdown of the pageserver.

I wasn't able to reproduce it locally, but I think stopping PS
immediately instead of gracefully should help. If not, we might just
silence those errors.

- Closes: https://databricks.atlassian.net/browse/LKB-745
2025-07-25 18:09:34 +00:00
Vlad Lazar
bab09fffb6 storcon: add persistence configuration options 2025-07-25 19:00:28 +01:00
Vlad Lazar
7bd73eba72 storcon: cut over to the hadron impl for establishing connection 2025-07-25 19:00:28 +01:00
Vlad Lazar
a1924e72ad storcon: port over hadron persistence changes 2025-07-25 19:00:22 +01:00
Conrad Ludgate
d09664f039 [proxy] replace TimedLru with moka (#12726)
LKB-2536 TimedLru is hard to maintain. Let's use moka instead. Stacked
on top of #12710.
2025-07-25 17:39:48 +00:00
Mikhail
6689d6fd89 LFC prewarm perftest fixes: use existing staging project (#12651)
https://github.com/neondatabase/cloud/issues/19011

- Prewarm config changes are not publicly available.
  Correct the test by using a pre-filled 50 GB project on staging
- Create extension neon with schema neon to fix read performance tests
on staging, error example in
https://neon-github-public-dev.s3.amazonaws.com/reports/main/16483462789/index.html#suites/3d632da6dda4a70f5b4bd24904ab444c/919841e331089fc4/
- Don't create extra endpoint in LFC prewarm performance tests
2025-07-25 16:56:41 +00:00
Tristan Partin
33b400beae [BRC-1425] Plumb through and set the requisite GUCs when starting the compute instance (#12732)
## Problem

We need the set the following Postgres GUCs to the correct value before
starting Postgres in the compute instance:

```
databricks.workspace_url
databricks.enable_databricks_identity_login
databricks.enable_sql_restrictions
```

## Summary of changes

Plumbed through `workspace_url` and other GUC settings via
`DatabricksSettings` in `ComputeSpec`. The spec is sent to the compute
instance when it starts up and the GUCs are written to `postgresql.conf`
before the postgres process is launched.

---------

Co-authored-by: Jarupat Jisarojito <jarupat.jisarojito@databricks.com>
Co-authored-by: William Huang <william.huang@databricks.com>
2025-07-25 15:20:05 +00:00
Tristan Partin
ca07f7dba5 Copy pg server cert and key to pgdata with correct permission (#12731)
## Problem

Copy certificate and key from secret mount directory to `pgdata`
directory where `postgres` is the owner and we can set the key
permission to 0600.

## Summary of changes

- Added new pgparam `pg_compute_tls_settings` to specify where k8s
secret for certificate and key are mounted.
- Added a new field to `ComputeSpec` called `databricks_settings`. This
is a struct that will be used to store any other settings that needs to
be propagate to Compute but should not be persisted to `ComputeSpec` in
the database.
- Then when the compute container start up, as part of `prepare_pgdata`
function, it will copied `server.key` and `server.crt` from k8s mounted
directory to `pgdata` directory.

## How is this tested?

Add unit tests.
Manual test via KIND

Co-authored-by: Jarupat Jisarojito <jarupat.jisarojito@databricks.com>
2025-07-25 15:05:05 +00:00
Vlad Lazar
b0dfe0ffa6 storcon: attempt all non-essential location config calls during reconciliations (#12745)
## Problem

We saw the following in the field:

Context and observations:
* The storage controller keeps track of the latest generations and the
pageserver that issued the latest generation in the database
* When the storage controller needs to proxy a request (e.g. timeline
creation) to the pageservers, it will find use the pageserver that
issued the latest generation from the db (generation_pageserver).
* pageserver-2.cell-2 got into a bad state and wasn't able to apply
location_config (e.g. detach a shard)

What happened:
1. pageserver-2.cell-2 was a secondary for our shard since we were not
able to detach it
2. control plane asked to detach a tenant (presumably because it was
idle)
a. In response storcon clears the generation_pageserver from the db and
attempts to detach all locations
b. it tries to detach pageserver-2.cell-2 first, but fails, which fails
the entire reconciliation leaving the good attached location still there
c. return success to cplane

3. control plane asks to re-attach the tenant
a. In response storcon performs a reconciliation
b. it finds that the observed state matches the intent (remember we did
not detach the primary at step(2))
c. skips incrementing the genration and setting the
generation_pageserver column

Now any requests that need to be proxied to pageservers and rely on the
generation_pageserver db column fail because that's not set

## Summary of changes

1. We do all non-essential location config calls (setting up
secondaries,
detaches) at the end of the reconciliation. Previously, we bailed out
of the reconciliation on the first failure. With this patch we attempt
all of the RPCs.
This allows the observed state to update even if another RPC failed for
unrelated reasons.

2. If the overall reconciliation failed, we don't want to remove nodes
from the
observed state as a safe-guard. With the previous patch, we'll get a
deletion delta to process, which would be ignored. Ignoring it is not
the right thing to do since it's out of sync with the db state.
Hence, on reconciliation failures map deletion from the observed state
to the uncertain state. Future reconciliation will query the node to
refresh their observed state.

Closes LKB-204
2025-07-25 14:03:17 +00:00
Erik Grinaker
185ead8395 pageserver: verify gRPC GetPages on correct shard (#12722)
Verify that gRPC `GetPageRequest` has been sent to the shard that owns
the pages. This avoid spurious `NotFound` errors if a compute misroutes
a request, which can appear scarier (e.g. data loss).

Touches [LKB-191](https://databricks.atlassian.net/browse/LKB-191).
2025-07-25 13:43:04 +00:00
Erik Grinaker
37e322438b pageserver: document gRPC compute accessibility (#12724)
Document that the Pageserver gRPC port is accessible by computes, and
should not provide internal services.

Touches [LKB-191](https://databricks.atlassian.net/browse/LKB-191).
2025-07-25 13:35:44 +00:00
Gustavo Bazan
fca2c32e59 [ci/docker] task: Apply some quick wins for tools dockerfile (#12740)
## Problem

The Dockerfile for build tools has some small issues that are easy to
fix to make it follow some of docker best practices

## Summary of changes

Apply some small quick wins on the Dockerfile for build tools

- Usage of apt-get over apt
- usage of --no-cache-dir for pip install
2025-07-25 12:39:01 +00:00
Conrad Ludgate
d19aebcf12 [proxy] introduce moka for the project-info cache (#12710)
## Problem

LKB-2502 The garbage collection of the project info cache is garbage. 

What we observed: If we get unlucky, we might throw away a very hot
entry if the cache is full. The GC loop is dependent on getting a lucky
shard of the projects2ep table that clears a lot of cold entries. The GC
does not take into account active use, and the interval it runs at is
too sparse to do any good.

Can we switch to a proper cache implementation?

Complications:
1. We need to invalidate by project/account.
2. We need to expire based on `retry_delay_ms`.

## Summary of changes

1. Replace `retry_delay_ms: Duration` with `retry_at: Instant` when
deserializing.
2. Split the EndpointControls from the RoleControls into two different
caches.
3. Introduce an expiry policy based on error retry info.
4. Introduce `moka` as a dependency, replacing our `TimedLru`.

See the follow up PR for changing all TimedLru instances to use moka:
#12726.
2025-07-25 11:40:47 +00:00
Conrad Ludgate
a70a5bccff move subzero_core to proxy libs (#12742)
We have a dedicated libs folder for proxy related libraries. Let's move
the subzero_core stub there.
2025-07-25 10:44:28 +00:00
Conrad Ludgate
d9cedb4a95 [tokio-postgres] fix regression in buffer reuse (#12739)
Follow up to #12701, which introduced a new regression. When profiling
locally I noticed that writes have the tendency to always reallocate. On
investigation I found that even if the `Connection`'s write buffer is
empty, if it still shares the same data pointer as the `Client`'s write
buffer then the client cannot reclaim it.

The best way I found to fix this is to just drop the `Connection`'s
write buffer each time we fully flush it.

Additionally, I remembered that `BytesMut` has an `unsplit` method which
is allows even better sharing over the previous optimisation I had when
'encoding'.
2025-07-25 09:03:21 +00:00
122 changed files with 3759 additions and 1719 deletions

View File

@@ -31,7 +31,7 @@ config-variables:
- NEON_PROD_AWS_ACCOUNT_ID
- PGREGRESS_PG16_PROJECT_ID
- PGREGRESS_PG17_PROJECT_ID
- PREWARM_PGBENCH_SIZE
- PREWARM_PROJECT_ID
- REMOTE_STORAGE_AZURE_CONTAINER
- REMOTE_STORAGE_AZURE_REGION
- SLACK_CICD_CHANNEL_ID

View File

@@ -418,7 +418,7 @@ jobs:
statuses: write
id-token: write # aws-actions/configure-aws-credentials
env:
PGBENCH_SIZE: ${{ vars.PREWARM_PGBENCH_SIZE }}
PROJECT_ID: ${{ vars.PREWARM_PROJECT_ID }}
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 17
TEST_OUTPUT: /tmp/test_output

View File

@@ -146,7 +146,9 @@ jobs:
with:
file: build-tools/Dockerfile
context: .
provenance: false
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
push: true
pull: true
build-args: |

View File

@@ -634,7 +634,9 @@ jobs:
DEBIAN_VERSION=bookworm
secrets: |
SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }}
provenance: false
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
push: true
pull: true
file: Dockerfile
@@ -747,7 +749,9 @@ jobs:
PG_VERSION=${{ matrix.version.pg }}
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
DEBIAN_VERSION=${{ matrix.version.debian }}
provenance: false
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
push: true
pull: true
file: compute/compute-node.Dockerfile
@@ -766,7 +770,9 @@ jobs:
PG_VERSION=${{ matrix.version.pg }}
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
DEBIAN_VERSION=${{ matrix.version.debian }}
provenance: false
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
push: true
pull: true
file: compute/compute-node.Dockerfile

View File

@@ -48,8 +48,20 @@ jobs:
uses: ./.github/workflows/build-build-tools-image.yml
secrets: inherit
generate-ch-tmppw:
runs-on: ubuntu-22.04
outputs:
tmp_val: ${{ steps.pwgen.outputs.tmp_val }}
steps:
- name: Generate a random password
id: pwgen
run: |
set +x
p=$(dd if=/dev/random bs=14 count=1 2>/dev/null | base64)
echo tmp_val="${p//\//}" >> "${GITHUB_OUTPUT}"
test-logical-replication:
needs: [ build-build-tools-image ]
needs: [ build-build-tools-image, generate-ch-tmppw ]
runs-on: ubuntu-22.04
container:
@@ -60,16 +72,21 @@ jobs:
options: --init --user root
services:
clickhouse:
image: clickhouse/clickhouse-server:24.6.3.64
image: clickhouse/clickhouse-server:25.6
env:
CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }}
PGSSLCERT: /tmp/postgresql.crt
ports:
- 9000:9000
- 8123:8123
zookeeper:
image: quay.io/debezium/zookeeper:2.7
image: quay.io/debezium/zookeeper:3.1.3.Final
ports:
- 2181:2181
- 2888:2888
- 3888:3888
kafka:
image: quay.io/debezium/kafka:2.7
image: quay.io/debezium/kafka:3.1.3.Final
env:
ZOOKEEPER_CONNECT: "zookeeper:2181"
KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092
@@ -79,7 +96,7 @@ jobs:
ports:
- 9092:9092
debezium:
image: quay.io/debezium/connect:2.7
image: quay.io/debezium/connect:3.1.3.Final
env:
BOOTSTRAP_SERVERS: kafka:9092
GROUP_ID: 1
@@ -125,6 +142,7 @@ jobs:
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
BENCHMARK_CONNSTR: ${{ steps.create-neon-project.outputs.dsn }}
CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }}
- name: Delete Neon Project
if: always()

203
Cargo.lock generated
View File

@@ -211,11 +211,11 @@ dependencies = [
[[package]]
name = "async-lock"
version = "3.2.0"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c"
checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18"
dependencies = [
"event-listener 4.0.0",
"event-listener 5.4.0",
"event-listener-strategy",
"pin-project-lite",
]
@@ -1404,9 +1404,9 @@ dependencies = [
[[package]]
name = "concurrent-queue"
version = "2.3.0"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400"
checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
dependencies = [
"crossbeam-utils",
]
@@ -2232,9 +2232,9 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
[[package]]
name = "event-listener"
version = "4.0.0"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae"
checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae"
dependencies = [
"concurrent-queue",
"parking",
@@ -2243,11 +2243,11 @@ dependencies = [
[[package]]
name = "event-listener-strategy"
version = "0.4.0"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3"
checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93"
dependencies = [
"event-listener 4.0.0",
"event-listener 5.4.0",
"pin-project-lite",
]
@@ -2516,6 +2516,20 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "304de19db7028420975a296ab0fcbbc8e69438c4ed254a1e41e2a7f37d5f0e0a"
[[package]]
name = "generator"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d18470a76cb7f8ff746cf1f7470914f900252ec36bbc40b569d74b1258446827"
dependencies = [
"cc",
"cfg-if",
"libc",
"log",
"rustversion",
"windows 0.61.3",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@@ -2834,7 +2848,7 @@ checksum = "f9c7c7c8ac16c798734b8a24560c1362120597c40d5e1459f09498f8f6c8f2ba"
dependencies = [
"cfg-if",
"libc",
"windows",
"windows 0.52.0",
]
[[package]]
@@ -3105,7 +3119,7 @@ dependencies = [
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
"windows-core 0.52.0",
]
[[package]]
@@ -3656,6 +3670,19 @@ version = "0.4.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
[[package]]
name = "loom"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca"
dependencies = [
"cfg-if",
"generator",
"scoped-tls",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "lru"
version = "0.12.3"
@@ -3872,6 +3899,25 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "moka"
version = "0.12.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926"
dependencies = [
"crossbeam-channel",
"crossbeam-epoch",
"crossbeam-utils",
"loom",
"parking_lot 0.12.1",
"portable-atomic",
"rustc_version",
"smallvec",
"tagptr",
"thiserror 1.0.69",
"uuid",
]
[[package]]
name = "multimap"
version = "0.8.3"
@@ -5031,8 +5077,6 @@ dependencies = [
"crc32c",
"criterion",
"env_logger",
"log",
"memoffset 0.9.0",
"once_cell",
"postgres",
"postgres_ffi_types",
@@ -5385,7 +5429,6 @@ dependencies = [
"futures",
"gettid",
"hashbrown 0.14.5",
"hashlink",
"hex",
"hmac",
"hostname",
@@ -5407,6 +5450,7 @@ dependencies = [
"lasso",
"measured",
"metrics",
"moka",
"once_cell",
"opentelemetry",
"ouroboros",
@@ -5473,6 +5517,7 @@ dependencies = [
"workspace_hack",
"x509-cert",
"zerocopy 0.8.24",
"zeroize",
]
[[package]]
@@ -6420,6 +6465,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "scoped-tls"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
[[package]]
name = "scopeguard"
version = "1.1.0"
@@ -7269,6 +7320,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "tagptr"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "tar"
version = "0.4.40"
@@ -8638,10 +8695,32 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be"
dependencies = [
"windows-core",
"windows-core 0.52.0",
"windows-targets 0.52.6",
]
[[package]]
name = "windows"
version = "0.61.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893"
dependencies = [
"windows-collections",
"windows-core 0.61.2",
"windows-future",
"windows-link",
"windows-numerics",
]
[[package]]
name = "windows-collections"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8"
dependencies = [
"windows-core 0.61.2",
]
[[package]]
name = "windows-core"
version = "0.52.0"
@@ -8651,6 +8730,86 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-core"
version = "0.61.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3"
dependencies = [
"windows-implement",
"windows-interface",
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-future"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e"
dependencies = [
"windows-core 0.61.2",
"windows-link",
"windows-threading",
]
[[package]]
name = "windows-implement"
version = "0.60.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "windows-interface"
version = "0.59.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "windows-link"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a"
[[package]]
name = "windows-numerics"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1"
dependencies = [
"windows-core 0.61.2",
"windows-link",
]
[[package]]
name = "windows-result"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
@@ -8709,6 +8868,15 @@ dependencies = [
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows-threading"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6"
dependencies = [
"windows-link",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.0"
@@ -8845,6 +9013,8 @@ dependencies = [
"clap",
"clap_builder",
"const-oid",
"crossbeam-epoch",
"crossbeam-utils",
"crypto-bigint 0.5.5",
"der 0.7.8",
"deranged",
@@ -8890,6 +9060,7 @@ dependencies = [
"once_cell",
"p256 0.13.2",
"parquet",
"portable-atomic",
"prettyplease",
"proc-macro2",
"prost 0.13.5",

View File

@@ -46,10 +46,10 @@ members = [
"libs/proxy/json",
"libs/proxy/postgres-protocol2",
"libs/proxy/postgres-types2",
"libs/proxy/subzero_core",
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
"proxy/subzero_core",
]
[workspace.package]
@@ -135,7 +135,7 @@ lock_api = "0.4.13"
md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
memoffset = "0.9"
moka = { version = "0.12", features = ["sync"] }
nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] }
# Do not update to >= 7.0.0, at least. The update will have a significant impact
# on compute startup metrics (start_postgres_ms), >= 25% degradation.
@@ -233,9 +233,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
rustls-native-certs = "0.8"
whoami = "1.5.1"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
zerocopy = { version = "0.8", features = ["derive", "simd"] }
zeroize = "1.8"
## TODO replace this with tracing
env_logger = "0.11"

View File

@@ -103,7 +103,7 @@ RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_FEATURES="rest_broker"; \
fi \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo auditable build \
--features $CARGO_FEATURES \
--bin pg_sni_router \
--bin pageserver \

View File

@@ -39,13 +39,13 @@ COPY build-tools/patches/pgcopydbv017.patch /pgcopydbv017.patch
RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \
set -e && \
apt update && \
apt install -y --no-install-recommends \
apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates wget gpg && \
wget -qO - https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor -o /usr/share/keyrings/postgresql-keyring.gpg && \
echo "deb [signed-by=/usr/share/keyrings/postgresql-keyring.gpg] http://apt.postgresql.org/pub/repos/apt bookworm-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \
apt-get update && \
apt install -y --no-install-recommends \
apt-get install -y --no-install-recommends \
build-essential \
autotools-dev \
libedit-dev \
@@ -89,8 +89,7 @@ RUN useradd -ms /bin/bash nonroot -b /home
# Use strict mode for bash to catch errors early
SHELL ["/bin/bash", "-euo", "pipefail", "-c"]
RUN mkdir -p /pgcopydb/bin && \
mkdir -p /pgcopydb/lib && \
RUN mkdir -p /pgcopydb/{bin,lib} && \
chmod -R 755 /pgcopydb && \
chown -R nonroot:nonroot /pgcopydb
@@ -106,8 +105,8 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \
# 'gdb' is included so that we get backtraces of core dumps produced in
# regression tests
RUN set -e \
&& apt update \
&& apt install -y \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
autoconf \
automake \
bison \
@@ -183,22 +182,22 @@ RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/
ENV LLVM_VERSION=20
RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \
&& echo "deb http://apt.llvm.org/${DEBIAN_VERSION}/ llvm-toolchain-${DEBIAN_VERSION}-${LLVM_VERSION} main" > /etc/apt/sources.list.d/llvm.stable.list \
&& apt update \
&& apt install -y clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \
&& apt-get update \
&& apt-get install -y --no-install-recommends clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \
&& bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install node
ENV NODE_VERSION=24
RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \
&& apt install -y nodejs \
&& apt-get install -y --no-install-recommends nodejs \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install docker
RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \
&& apt update \
&& apt install -y docker-ce docker-ce-cli \
&& apt-get update \
&& apt-get install -y --no-install-recommends docker-ce docker-ce-cli \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Configure sudo & docker
@@ -215,12 +214,11 @@ RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-$(uname -m).zip" -o "aws
# Mold: A Modern Linker
ENV MOLD_VERSION=v2.37.1
RUN set -e \
&& git clone https://github.com/rui314/mold.git \
&& git clone -b "${MOLD_VERSION}" --depth 1 https://github.com/rui314/mold.git \
&& mkdir mold/build \
&& cd mold/build \
&& git checkout ${MOLD_VERSION} \
&& cd mold/build \
&& cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=clang++ .. \
&& cmake --build . -j $(nproc) \
&& cmake --build . -j "$(nproc)" \
&& cmake --install . \
&& cd .. \
&& rm -rf mold
@@ -254,7 +252,7 @@ ENV ICU_VERSION=67.1
ENV ICU_PREFIX=/usr/local/icu
# Download and build static ICU
RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \
RUN wget -O "/tmp/libicu-${ICU_VERSION}.tgz" https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \
echo "94a80cd6f251a53bd2a997f6f1b5ac6653fe791dfab66e1eb0227740fb86d5dc /tmp/libicu-${ICU_VERSION}.tgz" | sha256sum --check && \
mkdir /tmp/icu && \
pushd /tmp/icu && \
@@ -265,8 +263,7 @@ RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/re
make install && \
popd && \
rm -rf icu && \
rm -f /tmp/libicu-${ICU_VERSION}.tgz && \
popd
rm -f /tmp/libicu-${ICU_VERSION}.tgz
# Switch to nonroot user
USER nonroot:nonroot
@@ -279,19 +276,19 @@ ENV PYTHON_VERSION=3.11.12 \
PYENV_ROOT=/home/nonroot/.pyenv \
PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH
RUN set -e \
&& cd $HOME \
&& cd "$HOME" \
&& curl -sSO https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer \
&& chmod +x pyenv-installer \
&& ./pyenv-installer \
&& export PYENV_ROOT=/home/nonroot/.pyenv \
&& export PATH="$PYENV_ROOT/bin:$PATH" \
&& export PATH="$PYENV_ROOT/shims:$PATH" \
&& pyenv install ${PYTHON_VERSION} \
&& pyenv global ${PYTHON_VERSION} \
&& pyenv install "${PYTHON_VERSION}" \
&& pyenv global "${PYTHON_VERSION}" \
&& python --version \
&& pip install --upgrade pip \
&& pip install --no-cache-dir --upgrade pip \
&& pip --version \
&& pip install pipenv wheel poetry
&& pip install --no-cache-dir pipenv wheel poetry
# Switch to nonroot user (again)
USER nonroot:nonroot
@@ -302,6 +299,7 @@ WORKDIR /home/nonroot
ENV RUSTC_VERSION=1.88.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
ARG CARGO_AUDITABLE_VERSION=0.7.0
ARG RUSTFILT_VERSION=0.2.1
ARG CARGO_HAKARI_VERSION=0.9.36
ARG CARGO_DENY_VERSION=0.18.2
@@ -317,14 +315,16 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \
cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \
cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \
--features postgres-bundled --no-default-features && \
cargo install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" && \
cargo auditable install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" --force && \
cargo auditable install rustfilt --version "${RUSTFILT_VERSION}" && \
cargo auditable install cargo-hakari --locked --version "${CARGO_HAKARI_VERSION}" && \
cargo auditable install cargo-deny --locked --version "${CARGO_DENY_VERSION}" && \
cargo auditable install cargo-hack --locked --version "${CARGO_HACK_VERSION}" && \
cargo auditable install cargo-nextest --locked --version "${CARGO_NEXTEST_VERSION}" && \
cargo auditable install cargo-chef --locked --version "${CARGO_CHEF_VERSION}" && \
cargo auditable install diesel_cli --locked --version "${CARGO_DIESEL_CLI_VERSION}" \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git

View File

@@ -1,5 +1,11 @@
commit 5eb393810cf7c7bafa4e394dad2e349e2a8cb2cb
Author: Alexey Masterov <alexey.masterov@databricks.com>
Date: Mon Jul 28 18:11:02 2025 +0200
Patch for pg_repack
diff --git a/regress/Makefile b/regress/Makefile
index bf6edcb..89b4c7f 100644
index bf6edcb..110e734 100644
--- a/regress/Makefile
+++ b/regress/Makefile
@@ -17,7 +17,7 @@ INTVERSION := $(shell echo $$(($$(echo $(VERSION).0 | sed 's/\([[:digit:]]\{1,\}
@@ -7,18 +13,36 @@ index bf6edcb..89b4c7f 100644
#
-REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper tablespace get_order_by trigger
+REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger
+REGRESS := init-extension noautovacuum repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger autovacuum
USE_PGXS = 1 # use pgxs if not in contrib directory
PGXS := $(shell $(PG_CONFIG) --pgxs)
diff --git a/regress/expected/init-extension.out b/regress/expected/init-extension.out
index 9f2e171..f6e4f8d 100644
--- a/regress/expected/init-extension.out
+++ b/regress/expected/init-extension.out
@@ -1,3 +1,2 @@
SET client_min_messages = warning;
CREATE EXTENSION pg_repack;
-RESET client_min_messages;
diff --git a/regress/expected/autovacuum.out b/regress/expected/autovacuum.out
new file mode 100644
index 0000000..e7f2363
--- /dev/null
+++ b/regress/expected/autovacuum.out
@@ -0,0 +1,7 @@
+ALTER SYSTEM SET autovacuum='on';
+SELECT pg_reload_conf();
+ pg_reload_conf
+----------------
+ t
+(1 row)
+
diff --git a/regress/expected/noautovacuum.out b/regress/expected/noautovacuum.out
new file mode 100644
index 0000000..fc7978e
--- /dev/null
+++ b/regress/expected/noautovacuum.out
@@ -0,0 +1,7 @@
+ALTER SYSTEM SET autovacuum='off';
+SELECT pg_reload_conf();
+ pg_reload_conf
+----------------
+ t
+(1 row)
+
diff --git a/regress/expected/nosuper.out b/regress/expected/nosuper.out
index 8d0a94e..63b68bf 100644
--- a/regress/expected/nosuper.out
@@ -50,14 +74,22 @@ index 8d0a94e..63b68bf 100644
INFO: repacking table "public.tbl_cluster"
ERROR: query failed: ERROR: current transaction is aborted, commands ignored until end of transaction block
DETAIL: query was: RESET lock_timeout
diff --git a/regress/sql/init-extension.sql b/regress/sql/init-extension.sql
index 9f2e171..f6e4f8d 100644
--- a/regress/sql/init-extension.sql
+++ b/regress/sql/init-extension.sql
@@ -1,3 +1,2 @@
SET client_min_messages = warning;
CREATE EXTENSION pg_repack;
-RESET client_min_messages;
diff --git a/regress/sql/autovacuum.sql b/regress/sql/autovacuum.sql
new file mode 100644
index 0000000..a8eda63
--- /dev/null
+++ b/regress/sql/autovacuum.sql
@@ -0,0 +1,2 @@
+ALTER SYSTEM SET autovacuum='on';
+SELECT pg_reload_conf();
diff --git a/regress/sql/noautovacuum.sql b/regress/sql/noautovacuum.sql
new file mode 100644
index 0000000..13d4836
--- /dev/null
+++ b/regress/sql/noautovacuum.sql
@@ -0,0 +1,2 @@
+ALTER SYSTEM SET autovacuum='off';
+SELECT pg_reload_conf();
diff --git a/regress/sql/nosuper.sql b/regress/sql/nosuper.sql
index 072f0fa..dbe60f8 100644
--- a/regress/sql/nosuper.sql

View File

@@ -82,6 +82,15 @@ struct Cli {
#[arg(long, default_value_t = 3081)]
pub internal_http_port: u16,
/// Backwards-compatible --http-port for Hadron deployments. Functionally the
/// same as --external-http-port.
#[arg(
long,
conflicts_with = "external_http_port",
conflicts_with = "internal_http_port"
)]
pub http_port: Option<u16>,
#[arg(short = 'D', long, value_name = "DATADIR")]
pub pgdata: String,
@@ -181,6 +190,26 @@ impl Cli {
}
}
// Hadron helpers to get compatible compute_ctl http ports from Cli. The old `--http-port`
// arg is used and acts the same as `--external-http-port`. The internal http port is defined
// to be http_port + 1. Hadron runs in the dblet environment which uses the host network, so
// we need to be careful with the ports to choose.
fn get_external_http_port(cli: &Cli) -> u16 {
if cli.lakebase_mode {
return cli.http_port.unwrap_or(cli.external_http_port);
}
cli.external_http_port
}
fn get_internal_http_port(cli: &Cli) -> u16 {
if cli.lakebase_mode {
return cli
.http_port
.map(|p| p + 1)
.unwrap_or(cli.internal_http_port);
}
cli.internal_http_port
}
fn main() -> Result<()> {
let cli = Cli::parse();
@@ -205,13 +234,18 @@ fn main() -> Result<()> {
// enable core dumping for all child processes
setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
installed_extensions::initialize_metrics();
hadron_metrics::initialize_metrics();
if cli.lakebase_mode {
installed_extensions::initialize_metrics();
hadron_metrics::initialize_metrics();
}
let connstr = Url::parse(&cli.connstr).context("cannot parse connstr as a URL")?;
let config = get_config(&cli)?;
let external_http_port = get_external_http_port(&cli);
let internal_http_port = get_internal_http_port(&cli);
let compute_node = ComputeNode::new(
ComputeNodeParams {
compute_id: cli.compute_id,
@@ -220,8 +254,8 @@ fn main() -> Result<()> {
pgdata: cli.pgdata.clone(),
pgbin: cli.pgbin.clone(),
pgversion: get_pg_version_string(&cli.pgbin),
external_http_port: cli.external_http_port,
internal_http_port: cli.internal_http_port,
external_http_port,
internal_http_port,
remote_ext_base_url: cli.remote_ext_base_url.clone(),
resize_swap_on_bind: cli.resize_swap_on_bind,
set_disk_quota_for_fs: cli.set_disk_quota_for_fs,

View File

@@ -6,7 +6,8 @@ use compute_api::responses::{
LfcPrewarmState, PromoteState, TlsConfig,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverProtocol, PgIdent,
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, GenericOption,
PageserverProtocol, PgIdent, Role,
};
use futures::StreamExt;
use futures::future::join_all;
@@ -413,6 +414,130 @@ struct StartVmMonitorResult {
vm_monitor: Option<JoinHandle<Result<()>>>,
}
// BEGIN_HADRON
/// This function creates roles that are used by Databricks.
/// These roles are not needs to be botostrapped at PG Compute provisioning time.
/// The auth method for these roles are configured in databricks_pg_hba.conf in universe repository.
pub(crate) fn create_databricks_roles() -> Vec<String> {
let roles = vec![
// Role for prometheus_stats_exporter
Role {
name: "databricks_monitor".to_string(),
// This uses "local" connection and auth method for that is "trust", so no password is needed.
encrypted_password: None,
options: Some(vec![GenericOption {
name: "IN ROLE pg_monitor".to_string(),
value: None,
vartype: "string".to_string(),
}]),
},
// Role for brickstore control plane
Role {
name: "databricks_control_plane".to_string(),
// Certificate user does not need password.
encrypted_password: None,
options: Some(vec![GenericOption {
name: "SUPERUSER".to_string(),
value: None,
vartype: "string".to_string(),
}]),
},
// Role for brickstore httpgateway.
Role {
name: "databricks_gateway".to_string(),
// Certificate user does not need password.
encrypted_password: None,
options: None,
},
];
roles
.into_iter()
.map(|role| {
let query = format!(
r#"
DO $$
BEGIN
IF NOT EXISTS (
SELECT FROM pg_catalog.pg_roles WHERE rolname = '{}')
THEN
CREATE ROLE {} {};
END IF;
END
$$;"#,
role.name,
role.name.pg_quote(),
role.to_pg_options(),
);
query
})
.collect()
}
/// Databricks-specific environment variables to be passed to the `postgres` sub-process.
pub struct DatabricksEnvVars {
/// The Databricks "endpoint ID" of the compute instance. Used by `postgres` to check
/// the token scopes of internal auth tokens.
pub endpoint_id: String,
/// Hostname of the Databricks workspace URL this compute instance belongs to.
/// Used by postgres to verify Databricks PAT tokens.
pub workspace_host: String,
pub lakebase_mode: bool,
}
impl DatabricksEnvVars {
pub fn new(
compute_spec: &ComputeSpec,
compute_id: Option<&String>,
instance_id: Option<String>,
lakebase_mode: bool,
) -> Self {
let endpoint_id = if let Some(instance_id) = instance_id {
// Use instance_id as endpoint_id if it is set. This code path is for PuPr model.
instance_id
} else {
// Use compute_id as endpoint_id if instance_id is not set. The code path is for PrPr model.
// compute_id is a string format of "{endpoint_id}/{compute_idx}"
// endpoint_id is a uuid. We only need to pass down endpoint_id to postgres.
// Panics if compute_id is not set or not in the expected format.
compute_id.unwrap().split('/').next().unwrap().to_string()
};
let workspace_host = compute_spec
.databricks_settings
.as_ref()
.map(|s| s.databricks_workspace_host.clone())
.unwrap_or("".to_string());
Self {
endpoint_id,
workspace_host,
lakebase_mode,
}
}
/// Constants for the names of Databricks-specific postgres environment variables.
const DATABRICKS_ENDPOINT_ID_ENVVAR: &'static str = "DATABRICKS_ENDPOINT_ID";
const DATABRICKS_WORKSPACE_HOST_ENVVAR: &'static str = "DATABRICKS_WORKSPACE_HOST";
/// Convert DatabricksEnvVars to a list of string pairs that can be passed as env vars. Consumes `self`.
pub fn to_env_var_list(self) -> Vec<(String, String)> {
if !self.lakebase_mode {
// In neon env, we don't need to pass down the env vars to postgres.
return vec![];
}
vec![
(
Self::DATABRICKS_ENDPOINT_ID_ENVVAR.to_string(),
self.endpoint_id.clone(),
),
(
Self::DATABRICKS_WORKSPACE_HOST_ENVVAR.to_string(),
self.workspace_host.clone(),
),
]
}
}
impl ComputeNode {
pub fn new(params: ComputeNodeParams, config: ComputeConfig) -> Result<Self> {
let connstr = params.connstr.as_str();
@@ -449,7 +574,11 @@ impl ComputeNode {
let mut new_state = ComputeState::new();
if let Some(spec) = config.spec {
let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?;
new_state.pspec = Some(pspec);
if params.lakebase_mode {
ComputeNode::set_spec(&params, &mut new_state, pspec);
} else {
new_state.pspec = Some(pspec);
}
}
Ok(ComputeNode {
@@ -1047,7 +1176,14 @@ impl ComputeNode {
// If it is something different then create_dir() will error out anyway.
let pgdata = &self.params.pgdata;
let _ok = fs::remove_dir_all(pgdata);
fs::create_dir(pgdata)?;
if self.params.lakebase_mode {
// Ignore creation errors if the directory already exists (e.g. mounting it ahead of time).
// If it is something different then PG startup will error out anyway.
let _ok = fs::create_dir(pgdata);
} else {
fs::create_dir(pgdata)?;
}
fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
Ok(())
@@ -1411,6 +1547,8 @@ impl ComputeNode {
let pgdata_path = Path::new(&self.params.pgdata);
let tls_config = self.tls_config(&pspec.spec);
let databricks_settings = spec.databricks_settings.as_ref();
let postgres_port = self.params.connstr.port();
// Remove/create an empty pgdata directory and put configuration there.
self.create_pgdata()?;
@@ -1418,8 +1556,11 @@ impl ComputeNode {
pgdata_path,
&self.params,
&pspec.spec,
postgres_port,
self.params.internal_http_port,
tls_config,
databricks_settings,
self.params.lakebase_mode,
)?;
// Syncing safekeepers is only safe with primary nodes: if a primary
@@ -1459,8 +1600,28 @@ impl ComputeNode {
)
})?;
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path, None)?;
if let Some(settings) = databricks_settings {
copy_tls_certificates(
&settings.pg_compute_tls_settings.key_file,
&settings.pg_compute_tls_settings.cert_file,
pgdata_path,
)?;
// Update pg_hba.conf received with basebackup including additional databricks settings.
update_pg_hba(pgdata_path, Some(&settings.databricks_pg_hba))?;
update_pg_ident(pgdata_path, Some(&settings.databricks_pg_ident))?;
} else {
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path, None)?;
}
if let Some(databricks_settings) = spec.databricks_settings.as_ref() {
copy_tls_certificates(
&databricks_settings.pg_compute_tls_settings.key_file,
&databricks_settings.pg_compute_tls_settings.cert_file,
pgdata_path,
)?;
}
// Place pg_dynshmem under /dev/shm. This allows us to use
// 'dynamic_shared_memory_type = mmap' so that the files are placed in
@@ -1501,7 +1662,7 @@ impl ComputeNode {
// symlink doesn't affect anything.
//
// See https://github.com/neondatabase/autoscaling/issues/800
std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?;
std::fs::remove_dir_all(pgdata_path.join("pg_dynshmem"))?;
symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?;
match spec.mode {
@@ -1516,6 +1677,12 @@ impl ComputeNode {
/// Start and stop a postgres process to warm up the VM for startup.
pub fn prewarm_postgres_vm_memory(&self) -> Result<()> {
if self.params.lakebase_mode {
// We are running in Hadron mode. Disabling this prewarming step for now as it could run
// into dblet port conflicts and also doesn't add much value with our current infra.
info!("Skipping postgres prewarming in Hadron mode");
return Ok(());
}
info!("prewarming VM memory");
// Create pgdata
@@ -1573,14 +1740,36 @@ impl ComputeNode {
pub fn start_postgres(&self, storage_auth_token: Option<String>) -> Result<PostgresHandle> {
let pgdata_path = Path::new(&self.params.pgdata);
let env_vars: Vec<(String, String)> = if self.params.lakebase_mode {
let databricks_env_vars = {
let state = self.state.lock().unwrap();
let spec = &state.pspec.as_ref().unwrap().spec;
DatabricksEnvVars::new(
spec,
Some(&self.params.compute_id),
self.params.instance_id.clone(),
self.params.lakebase_mode,
)
};
info!(
"Starting Postgres for databricks endpoint id: {}",
&databricks_env_vars.endpoint_id
);
let mut env_vars = databricks_env_vars.to_env_var_list();
env_vars.extend(storage_auth_token.map(|t| ("NEON_AUTH_TOKEN".to_string(), t)));
env_vars
} else if let Some(storage_auth_token) = &storage_auth_token {
vec![("NEON_AUTH_TOKEN".to_owned(), storage_auth_token.to_owned())]
} else {
vec![]
};
// Run postgres as a child process.
let mut pg = maybe_cgexec(&self.params.pgbin)
.args(["-D", &self.params.pgdata])
.envs(if let Some(storage_auth_token) = &storage_auth_token {
vec![("NEON_AUTH_TOKEN", storage_auth_token)]
} else {
vec![]
})
.envs(env_vars)
.stderr(Stdio::piped())
.spawn()
.expect("cannot start postgres process");
@@ -1732,7 +1921,15 @@ impl ComputeNode {
/// Do initial configuration of the already started Postgres.
#[instrument(skip_all)]
pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
let conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config"));
let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config"));
if self.params.lakebase_mode {
// Set a 2-minute statement_timeout for the session applying config. The individual SQL statements
// used in apply_spec_sql() should not take long (they are just creating users and installing
// extensions). If any of them are stuck for an extended period of time it usually indicates a
// pageserver connectivity problem and we should bail out.
conf.options("-c statement_timeout=2min");
}
let conf = Arc::new(conf);
let spec = Arc::new(
@@ -1883,12 +2080,16 @@ impl ComputeNode {
// Write new config
let pgdata_path = Path::new(&self.params.pgdata);
let postgres_port = self.params.connstr.port();
config::write_postgres_conf(
pgdata_path,
&self.params,
&spec,
postgres_port,
self.params.internal_http_port,
tls_config,
spec.databricks_settings.as_ref(),
self.params.lakebase_mode,
)?;
self.pg_reload_conf()?;
@@ -2046,7 +2247,17 @@ impl ComputeNode {
pub fn check_for_core_dumps(&self) -> Result<()> {
let core_dump_dir = match std::env::consts::OS {
"macos" => Path::new("/cores/"),
_ => Path::new(&self.params.pgdata),
// BEGIN HADRON
// NB: Read core dump files from a fixed location outside of
// the data directory since `compute_ctl` wipes the data directory
// across container restarts.
_ => {
if self.params.lakebase_mode {
Path::new("/databricks/logs/brickstore")
} else {
Path::new(&self.params.pgdata)
}
} // END HADRON
};
// Collect core dump paths if any
@@ -2359,7 +2570,7 @@ LIMIT 100",
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
libs_vec = libs
.split(&[',', '\'', ' '])
.filter(|s| *s != "neon" && !s.is_empty())
.filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty())
.map(str::to_string)
.collect();
}
@@ -2378,7 +2589,7 @@ LIMIT 100",
if let Some(libs) = shared_preload_libraries_line.split("='").nth(1) {
preload_libs_vec = libs
.split(&[',', '\'', ' '])
.filter(|s| *s != "neon" && !s.is_empty())
.filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty())
.map(str::to_string)
.collect();
}

View File

@@ -7,11 +7,14 @@ use std::io::prelude::*;
use std::path::Path;
use compute_api::responses::TlsConfig;
use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
use compute_api::spec::{
ComputeAudit, ComputeMode, ComputeSpec, DatabricksSettings, GenericOption,
};
use crate::compute::ComputeNodeParams;
use crate::pg_helpers::{
GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
DatabricksSettingsExt as _, GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize,
escape_conf_value,
};
use crate::tls::{self, SERVER_CRT, SERVER_KEY};
@@ -40,12 +43,16 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
}
/// Create or completely rewrite configuration file specified by `path`
#[allow(clippy::too_many_arguments)]
pub fn write_postgres_conf(
pgdata_path: &Path,
params: &ComputeNodeParams,
spec: &ComputeSpec,
postgres_port: Option<u16>,
extension_server_port: u16,
tls_config: &Option<TlsConfig>,
databricks_settings: Option<&DatabricksSettings>,
lakebase_mode: bool,
) -> Result<()> {
let path = pgdata_path.join("postgresql.conf");
// File::create() destroys the file content if it exists.
@@ -285,6 +292,24 @@ pub fn write_postgres_conf(
writeln!(file, "log_destination='stderr,syslog'")?;
}
if lakebase_mode {
// Explicitly set the port based on the connstr, overriding any previous port setting.
// Note: It is important that we don't specify a different port again after this.
let port = postgres_port.expect("port must be present in connstr");
writeln!(file, "port = {port}")?;
// This is databricks specific settings.
// This should be at the end of the file but before `compute_ctl_temp_override.conf` below
// so that it can override any settings above.
// `compute_ctl_temp_override.conf` is intended to override any settings above during specific operations.
// To prevent potential breakage in the future, we keep it above `compute_ctl_temp_override.conf`.
writeln!(file, "# Databricks settings start")?;
if let Some(settings) = databricks_settings {
writeln!(file, "{}", settings.as_pg_settings())?;
}
writeln!(file, "# Databricks settings end")?;
}
// This is essential to keep this line at the end of the file,
// because it is intended to override any settings above.
writeln!(file, "include_if_exists = 'compute_ctl_temp_override.conf'")?;

View File

@@ -142,7 +142,7 @@ pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) ->
// Update pg_hba to contains databricks specfic settings before adding neon settings
// PG uses the first record that matches to perform authentication, so we need to have
// our rules before the default ones from neon.
// See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html
// See https://www.postgresql.org/docs/current/auth-pg-hba-conf.html
if let Some(databricks_pg_hba) = databricks_pg_hba {
if config::line_in_file(
&pghba_path,

View File

@@ -13,17 +13,19 @@ use tokio_postgres::Client;
use tokio_postgres::error::SqlState;
use tracing::{Instrument, debug, error, info, info_span, instrument, warn};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState, create_databricks_roles};
use crate::hadron_metrics::COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS;
use crate::pg_helpers::{
DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async,
get_existing_roles_async,
};
use crate::spec_apply::ApplySpecPhase::{
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension,
AddDatabricksGrants, AlterDatabricksRoles, CreateAndAlterDatabases, CreateAndAlterRoles,
CreateAvailabilityCheck, CreateDatabricksMisc, CreateDatabricksRoles, CreatePgauditExtension,
CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon,
DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions,
HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles,
RunInEachDatabase,
HandleDatabricksAuthExtension, HandleNeonExtension, HandleOtherExtensions,
RenameAndDeleteDatabases, RenameRoles, RunInEachDatabase,
};
use crate::spec_apply::PerDatabasePhase::{
ChangeSchemaPerms, DeleteDBRoleReferences, DropLogicalSubscriptions,
@@ -166,6 +168,7 @@ impl ComputeNode {
concurrency_token.clone(),
db,
[DropLogicalSubscriptions].to_vec(),
self.params.lakebase_mode,
);
Ok(tokio::spawn(fut))
@@ -186,15 +189,33 @@ impl ComputeNode {
};
}
for phase in [
CreatePrivilegedRole,
let phases = if self.params.lakebase_mode {
vec![
CreatePrivilegedRole,
// BEGIN_HADRON
CreateDatabricksRoles,
AlterDatabricksRoles,
// END_HADRON
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
RenameAndDeleteDatabases,
CreateAndAlterDatabases,
CreateSchemaNeon,
] {
]
} else {
vec![
CreatePrivilegedRole,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
RenameAndDeleteDatabases,
CreateAndAlterDatabases,
CreateSchemaNeon,
]
};
for phase in phases {
info!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
@@ -203,6 +224,7 @@ impl ComputeNode {
jwks_roles.clone(),
phase,
|| async { Ok(&client) },
self.params.lakebase_mode,
)
.await?;
}
@@ -254,6 +276,7 @@ impl ComputeNode {
concurrency_token.clone(),
db,
phases,
self.params.lakebase_mode,
);
Ok(tokio::spawn(fut))
@@ -265,12 +288,28 @@ impl ComputeNode {
handle.await??;
}
let mut phases = vec![
let mut phases = if self.params.lakebase_mode {
vec![
HandleOtherExtensions,
HandleNeonExtension, // This step depends on CreateSchemaNeon
// BEGIN_HADRON
HandleDatabricksAuthExtension,
// END_HADRON
CreateAvailabilityCheck,
DropRoles,
// BEGIN_HADRON
AddDatabricksGrants,
CreateDatabricksMisc,
// END_HADRON
]
} else {
vec![
HandleOtherExtensions,
HandleNeonExtension, // This step depends on CreateSchemaNeon
CreateAvailabilityCheck,
DropRoles,
];
]
};
// This step depends on CreateSchemaNeon
if spec.drop_subscriptions_before_start && !drop_subscriptions_done {
@@ -303,6 +342,7 @@ impl ComputeNode {
jwks_roles.clone(),
phase,
|| async { Ok(&client) },
self.params.lakebase_mode,
)
.await?;
}
@@ -328,6 +368,7 @@ impl ComputeNode {
concurrency_token: Arc<tokio::sync::Semaphore>,
db: DB,
subphases: Vec<PerDatabasePhase>,
lakebase_mode: bool,
) -> Result<()> {
let _permit = concurrency_token.acquire().await?;
@@ -355,6 +396,7 @@ impl ComputeNode {
let client = client_conn.as_ref().unwrap();
Ok(client)
},
lakebase_mode,
)
.await?;
}
@@ -477,6 +519,10 @@ pub enum PerDatabasePhase {
#[derive(Clone, Debug)]
pub enum ApplySpecPhase {
CreatePrivilegedRole,
// BEGIN_HADRON
CreateDatabricksRoles,
AlterDatabricksRoles,
// END_HADRON
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -489,7 +535,14 @@ pub enum ApplySpecPhase {
DisablePostgresDBPgAudit,
HandleOtherExtensions,
HandleNeonExtension,
// BEGIN_HADRON
HandleDatabricksAuthExtension,
// END_HADRON
CreateAvailabilityCheck,
// BEGIN_HADRON
AddDatabricksGrants,
CreateDatabricksMisc,
// END_HADRON
DropRoles,
FinalizeDropLogicalSubscriptions,
}
@@ -525,6 +578,7 @@ pub async fn apply_operations<'a, Fut, F>(
jwks_roles: Arc<HashSet<String>>,
apply_spec_phase: ApplySpecPhase,
client: F,
lakebase_mode: bool,
) -> Result<()>
where
F: FnOnce() -> Fut,
@@ -571,6 +625,23 @@ where
},
query
);
if !lakebase_mode {
return res;
}
// BEGIN HADRON
if let Err(e) = res.as_ref() {
if let Some(sql_state) = e.code() {
if sql_state.code() == "57014" {
// SQL State 57014 (ERRCODE_QUERY_CANCELED) is used for statement timeouts.
// Increment the counter whenever a statement timeout occurs. Timeouts on
// this configuration path can only occur due to PS connectivity problems that
// Postgres failed to recover from.
COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS.inc();
}
}
}
// END HADRON
res
}
.instrument(inspan)
@@ -612,6 +683,35 @@ async fn get_operations<'a>(
),
comment: None,
}))),
// BEGIN_HADRON
// New Hadron phase
ApplySpecPhase::CreateDatabricksRoles => {
let queries = create_databricks_roles();
let operations = queries.into_iter().map(|query| Operation {
query,
comment: None,
});
Ok(Box::new(operations))
}
// Backfill existing databricks_reader_* roles with statement timeout from GUC
ApplySpecPhase::AlterDatabricksRoles => {
let query = String::from(include_str!(
"sql/alter_databricks_reader_roles_timeout.sql"
));
let operations = once(Operation {
query,
comment: Some(
"Backfill existing databricks_reader_* roles with statement timeout"
.to_string(),
),
});
Ok(Box::new(operations))
}
// End of new Hadron Phase
// END_HADRON
ApplySpecPhase::DropInvalidDatabases => {
let mut ctx = ctx.write().await;
let databases = &mut ctx.dbs;
@@ -981,7 +1081,10 @@ async fn get_operations<'a>(
// N.B. this has to be properly dollar-escaped with `pg_quote_dollar()`
role_name = escaped_role,
outer_tag = outer_tag,
),
)
// HADRON change:
.replace("neon_superuser", &params.privileged_role_name),
// HADRON change end ,
comment: None,
},
// This now will only drop privileges of the role
@@ -1017,7 +1120,8 @@ async fn get_operations<'a>(
comment: None,
},
Operation {
query: String::from(include_str!("sql/default_grants.sql")),
query: String::from(include_str!("sql/default_grants.sql"))
.replace("neon_superuser", &params.privileged_role_name),
comment: None,
},
]
@@ -1086,6 +1190,28 @@ async fn get_operations<'a>(
Ok(Box::new(operations))
}
// BEGIN_HADRON
// Note: we may want to version the extension someday, but for now we just drop it and recreate it.
ApplySpecPhase::HandleDatabricksAuthExtension => {
let operations = vec![
Operation {
query: String::from("DROP EXTENSION IF EXISTS databricks_auth"),
comment: Some(String::from("dropping existing databricks_auth extension")),
},
Operation {
query: String::from("CREATE EXTENSION databricks_auth"),
comment: Some(String::from("creating databricks_auth extension")),
},
Operation {
query: String::from("GRANT SELECT ON databricks_auth_metrics TO pg_monitor"),
comment: Some(String::from("grant select on databricks auth counters")),
},
]
.into_iter();
Ok(Box::new(operations))
}
// END_HADRON
ApplySpecPhase::CreateAvailabilityCheck => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/add_availabilitycheck_tables.sql")),
comment: None,
@@ -1103,6 +1229,63 @@ async fn get_operations<'a>(
Ok(Box::new(operations))
}
// BEGIN_HADRON
// New Hadron phases
//
// Grants permissions to roles that are used by Databricks.
ApplySpecPhase::AddDatabricksGrants => {
let operations = vec![
Operation {
query: String::from("GRANT USAGE ON SCHEMA neon TO databricks_monitor"),
comment: Some(String::from(
"Permissions needed to execute neon.* functions (in the postgres database)",
)),
},
Operation {
query: String::from(
"GRANT SELECT, INSERT, UPDATE ON health_check TO databricks_monitor",
),
comment: Some(String::from("Permissions needed for read and write probes")),
},
Operation {
query: String::from(
"GRANT EXECUTE ON FUNCTION pg_ls_dir(text) TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to monitor .snap file counts",
)),
},
Operation {
query: String::from(
"GRANT SELECT ON neon.neon_perf_counters TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to access neon performance counters view",
)),
},
Operation {
query: String::from(
"GRANT EXECUTE ON FUNCTION neon.get_perf_counters() TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to execute the underlying performance counters function",
)),
},
]
.into_iter();
Ok(Box::new(operations))
}
// Creates minor objects that are used by Databricks.
ApplySpecPhase::CreateDatabricksMisc => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/create_databricks_misc.sql")),
comment: Some(String::from(
"The function databricks_monitor uses to convert exception to 0 or 1",
)),
}))),
// End of new Hadron phases
// END_HADRON
ApplySpecPhase::FinalizeDropLogicalSubscriptions => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/finalize_drop_subscriptions.sql")),
comment: None,

View File

@@ -0,0 +1,25 @@
DO $$
DECLARE
reader_role RECORD;
timeout_value TEXT;
BEGIN
-- Get the current GUC setting for reader statement timeout
SELECT current_setting('databricks.reader_statement_timeout', true) INTO timeout_value;
-- Only proceed if timeout_value is not null/empty and not '0' (disabled)
IF timeout_value IS NOT NULL AND timeout_value != '' AND timeout_value != '0' THEN
-- Find all databricks_reader_* roles and update their statement_timeout
FOR reader_role IN
SELECT r.rolname
FROM pg_roles r
WHERE r.rolname ~ '^databricks_reader_\d+$'
LOOP
-- Apply the timeout setting to the role (will overwrite existing setting)
EXECUTE format('ALTER ROLE %I SET statement_timeout = %L',
reader_role.rolname, timeout_value);
RAISE LOG 'Updated statement_timeout = % for role %', timeout_value, reader_role.rolname;
END LOOP;
END IF;
END
$$;

View File

@@ -0,0 +1,15 @@
ALTER ROLE databricks_monitor SET statement_timeout = '60s';
CREATE OR REPLACE FUNCTION health_check_write_succeeds()
RETURNS INTEGER AS $$
BEGIN
INSERT INTO health_check VALUES (1, now())
ON CONFLICT (id) DO UPDATE
SET updated_at = now();
RETURN 1;
EXCEPTION WHEN OTHERS THEN
RAISE EXCEPTION '[DATABRICKS_SMGR] health_check failed: [%] %', SQLSTATE, SQLERRM;
RETURN 0;
END;
$$ LANGUAGE plpgsql;

View File

@@ -793,6 +793,7 @@ impl Endpoint {
autoprewarm: args.autoprewarm,
offload_lfc_interval_seconds: args.offload_lfc_interval_seconds,
suspend_timeout_seconds: -1, // Only used in neon_local.
databricks_settings: None,
};
// this strange code is needed to support respec() in tests

View File

@@ -193,6 +193,9 @@ pub struct ComputeSpec {
///
/// We use this value to derive other values, such as the installed extensions metric.
pub suspend_timeout_seconds: i64,
// Databricks specific options for compute instance.
pub databricks_settings: Option<DatabricksSettings>,
}
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.

View File

@@ -558,11 +558,11 @@ async fn add_request_id_header_to_response(
mut res: Response<Body>,
req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
if let Some(request_id) = req_info.context::<RequestId>() {
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
if let Some(request_id) = req_info.context::<RequestId>()
&& let Ok(request_header_value) = HeaderValue::from_str(&request_id.0)
{
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
Ok(res)

View File

@@ -72,10 +72,10 @@ impl Server {
if err.is_incomplete_message() || err.is_closed() || err.is_timeout() {
return true;
}
if let Some(inner) = err.source() {
if let Some(io) = inner.downcast_ref::<std::io::Error>() {
return suppress_io_error(io);
}
if let Some(inner) = err.source()
&& let Some(io) = inner.downcast_ref::<std::io::Error>()
{
return suppress_io_error(io);
}
false
}

View File

@@ -129,6 +129,12 @@ impl<L: LabelGroup> InfoMetric<L> {
}
}
impl<L: LabelGroup + Default> Default for InfoMetric<L, GaugeState> {
fn default() -> Self {
InfoMetric::new(L::default())
}
}
impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
pub fn with_metric(label: L, metric: M) -> Self {
Self {

View File

@@ -363,7 +363,7 @@ where
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<'_, (K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;

View File

@@ -9,10 +9,7 @@ regex.workspace = true
bytes.workspace = true
anyhow.workspace = true
crc32c.workspace = true
criterion.workspace = true
once_cell.workspace = true
log.workspace = true
memoffset.workspace = true
pprof.workspace = true
thiserror.workspace = true
serde.workspace = true
@@ -22,6 +19,7 @@ tracing.workspace = true
postgres_versioninfo.workspace = true
[dev-dependencies]
criterion.workspace = true
env_logger.workspace = true
postgres.workspace = true

View File

@@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::<ControlFileData>();
impl ControlFileData {
/// Compute the offset of the `crc` field within the `ControlFileData` struct.
/// Equivalent to offsetof(ControlFileData, crc) in C.
// Someday this can be const when the right compiler features land.
fn pg_control_crc_offset() -> usize {
memoffset::offset_of!(ControlFileData, crc)
const fn pg_control_crc_offset() -> usize {
std::mem::offset_of!(ControlFileData, crc)
}
///

View File

@@ -4,12 +4,11 @@
use crate::pg_constants;
use crate::transaction_id_precedes;
use bytes::BytesMut;
use log::*;
use super::bindings::MultiXactId;
pub fn transaction_id_set_status(xid: u32, status: u8, page: &mut BytesMut) {
trace!(
tracing::trace!(
"handle_apply_request for RM_XACT_ID-{} (1-commit, 2-abort, 3-sub_commit)",
status
);

View File

@@ -14,7 +14,6 @@ use super::xlog_utils::*;
use crate::WAL_SEGMENT_SIZE;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crc32c::*;
use log::*;
use std::cmp::min;
use std::num::NonZeroU32;
use utils::lsn::Lsn;
@@ -236,7 +235,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder {
// XLOG_SWITCH records are special. If we see one, we need to skip
// to the next WAL segment.
let next_lsn = if xlogrec.is_xlog_switch_record() {
trace!("saw xlog switch record at {}", self.lsn);
tracing::trace!("saw xlog switch record at {}", self.lsn);
self.lsn + self.lsn.calc_padding(WAL_SEGMENT_SIZE as u64)
} else {
// Pad to an 8-byte boundary

View File

@@ -23,8 +23,6 @@ use crate::{WAL_SEGMENT_SIZE, XLOG_BLCKSZ};
use bytes::BytesMut;
use bytes::{Buf, Bytes};
use log::*;
use serde::Serialize;
use std::ffi::{CString, OsStr};
use std::fs::File;
@@ -235,7 +233,7 @@ pub fn find_end_of_wal(
let mut curr_lsn = start_lsn;
let mut buf = [0u8; XLOG_BLCKSZ];
let pg_version = MY_PGVERSION;
debug!("find_end_of_wal PG_VERSION: {}", pg_version);
tracing::debug!("find_end_of_wal PG_VERSION: {}", pg_version);
let mut decoder = WalStreamDecoder::new(start_lsn, pg_version);
@@ -247,7 +245,7 @@ pub fn find_end_of_wal(
match open_wal_segment(&seg_file_path)? {
None => {
// no more segments
debug!(
tracing::debug!(
"find_end_of_wal reached end at {:?}, segment {:?} doesn't exist",
result, seg_file_path
);
@@ -260,7 +258,7 @@ pub fn find_end_of_wal(
while curr_lsn.segment_number(wal_seg_size) == segno {
let bytes_read = segment.read(&mut buf)?;
if bytes_read == 0 {
debug!(
tracing::debug!(
"find_end_of_wal reached end at {:?}, EOF in segment {:?} at offset {}",
result,
seg_file_path,
@@ -276,7 +274,7 @@ pub fn find_end_of_wal(
match decoder.poll_decode() {
Ok(Some(record)) => result = record.0,
Err(e) => {
debug!(
tracing::debug!(
"find_end_of_wal reached end at {:?}, decode error: {:?}",
result, e
);

View File

@@ -185,6 +185,7 @@ impl Client {
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
write_buf: BytesMut,
) -> Client {
Client {
inner: InnerClient {
@@ -195,7 +196,7 @@ impl Client {
waiting: 0,
received: 0,
},
buffer: Default::default(),
buffer: write_buf,
},
cached_typeinfo: Default::default(),

View File

@@ -47,14 +47,7 @@ impl Encoder<BytesMut> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> io::Result<()> {
// When it comes to request/response workflows, we usually flush the entire write
// buffer in order to wait for the response before we send a new request.
// Therefore we can avoid the copy and just replace the buffer.
if dst.is_empty() {
*dst = item;
} else {
dst.extend_from_slice(&item);
}
dst.unsplit(item);
Ok(())
}
}

View File

@@ -77,6 +77,9 @@ where
connect_timeout,
};
let mut stream = stream.into_framed();
let write_buf = std::mem::take(stream.write_buffer_mut());
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
@@ -86,9 +89,9 @@ where
ssl_mode,
process_id,
secret_key,
write_buf,
);
let stream = stream.into_framed();
let connection = Connection::new(stream, conn_tx, conn_rx);
Ok((client, connection))

View File

@@ -229,8 +229,11 @@ where
Poll::Ready(()) => {
trace!("poll_flush: flushed");
// GC the write buffer if we managed to flush
gc_bytesmut(self.stream.write_buffer_mut());
// Since our codec prefers to share the buffer with the `Client`,
// if we don't release our share, then the `Client` would have to re-alloc
// the buffer when they next use it.
debug_assert!(self.stream.write_buffer().is_empty());
*self.stream.write_buffer_mut() = BytesMut::new();
Poll::Ready(Ok(()))
}

View File

@@ -9,7 +9,7 @@ use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody};
pub use self::sqlstate::*;
#[allow(clippy::unreadable_literal)]
mod sqlstate;
pub mod sqlstate;
/// The severity of a Postgres error or notice.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]

View File

@@ -49,7 +49,7 @@ impl PerfSpan {
}
}
pub fn enter(&self) -> PerfSpanEntered {
pub fn enter(&self) -> PerfSpanEntered<'_> {
if let Some(ref id) = self.inner.id() {
self.dispatch.enter(id);
}

View File

@@ -14,9 +14,9 @@ use utils::logging::warn_slow;
use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool};
use crate::retry::Retry;
use crate::split::GetPageSplitter;
use compute_api::spec::PageserverProtocol;
use pageserver_page_api as page_api;
use pageserver_page_api::GetPageSplitter;
use utils::id::{TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};

View File

@@ -1,6 +1,5 @@
mod client;
mod pool;
mod retry;
mod split;
pub use client::{PageserverClient, ShardSpec};

View File

@@ -19,7 +19,9 @@ pub mod proto {
}
mod client;
pub use client::Client;
mod model;
mod split;
pub use client::Client;
pub use model::*;
pub use split::GetPageSplitter;

View File

@@ -3,18 +3,18 @@ use std::collections::HashMap;
use anyhow::anyhow;
use bytes::Bytes;
use crate::model::*;
use pageserver_api::key::rel_block_to_key;
use pageserver_api::shard::key_to_shard_number;
use pageserver_page_api as page_api;
use utils::shard::{ShardCount, ShardIndex, ShardStripeSize};
/// Splits GetPageRequests that straddle shard boundaries and assembles the responses.
/// TODO: add tests for this.
pub struct GetPageSplitter {
/// Split requests by shard index.
requests: HashMap<ShardIndex, page_api::GetPageRequest>,
requests: HashMap<ShardIndex, GetPageRequest>,
/// The response being assembled. Preallocated with empty pages, to be filled in.
response: page_api::GetPageResponse,
response: GetPageResponse,
/// Maps the offset in `request.block_numbers` and `response.pages` to the owning shard. Used
/// to assemble the response pages in the same order as the original request.
block_shards: Vec<ShardIndex>,
@@ -24,7 +24,7 @@ impl GetPageSplitter {
/// Checks if the given request only touches a single shard, and returns the shard ID. This is
/// the common case, so we check first in order to avoid unnecessary allocations and overhead.
pub fn for_single_shard(
req: &page_api::GetPageRequest,
req: &GetPageRequest,
count: ShardCount,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Option<ShardIndex>> {
@@ -57,7 +57,7 @@ impl GetPageSplitter {
/// Splits the given request.
pub fn split(
req: page_api::GetPageRequest,
req: GetPageRequest,
count: ShardCount,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Self> {
@@ -84,7 +84,7 @@ impl GetPageSplitter {
requests
.entry(shard_id)
.or_insert_with(|| page_api::GetPageRequest {
.or_insert_with(|| GetPageRequest {
request_id: req.request_id,
request_class: req.request_class,
rel: req.rel,
@@ -98,16 +98,16 @@ impl GetPageSplitter {
// Construct a response to be populated by shard responses. Preallocate empty page slots
// with the expected block numbers.
let response = page_api::GetPageResponse {
let response = GetPageResponse {
request_id: req.request_id,
status_code: page_api::GetPageStatusCode::Ok,
status_code: GetPageStatusCode::Ok,
reason: None,
rel: req.rel,
pages: req
.block_numbers
.into_iter()
.map(|block_number| {
page_api::Page {
Page {
block_number,
image: Bytes::new(), // empty page slot to be filled in
}
@@ -123,9 +123,7 @@ impl GetPageSplitter {
}
/// Drains the per-shard requests, moving them out of the splitter to avoid extra allocations.
pub fn drain_requests(
&mut self,
) -> impl Iterator<Item = (ShardIndex, page_api::GetPageRequest)> {
pub fn drain_requests(&mut self) -> impl Iterator<Item = (ShardIndex, GetPageRequest)> {
self.requests.drain()
}
@@ -135,10 +133,10 @@ impl GetPageSplitter {
pub fn add_response(
&mut self,
shard_id: ShardIndex,
response: page_api::GetPageResponse,
response: GetPageResponse,
) -> anyhow::Result<()> {
// The caller should already have converted status codes into tonic::Status.
if response.status_code != page_api::GetPageStatusCode::Ok {
if response.status_code != GetPageStatusCode::Ok {
return Err(anyhow!(
"unexpected non-OK response for shard {shard_id}: {} {}",
response.status_code,
@@ -209,7 +207,7 @@ impl GetPageSplitter {
/// Fetches the final, assembled response.
#[allow(clippy::result_large_err)]
pub fn get_response(self) -> anyhow::Result<page_api::GetPageResponse> {
pub fn get_response(self) -> anyhow::Result<GetPageResponse> {
// Check that the response is complete.
for (i, page) in self.response.pages.iter().enumerate() {
if page.image.is_empty() {

View File

@@ -715,7 +715,7 @@ fn start_pageserver(
disk_usage_eviction_state,
deletion_queue.new_client(),
secondary_controller,
feature_resolver,
feature_resolver.clone(),
)
.context("Failed to initialize router state")?,
);
@@ -841,14 +841,14 @@ fn start_pageserver(
} else {
None
},
feature_resolver.clone(),
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
// each stream/request.
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for each request/stream.
// It uses a separate compute request Tokio runtime (COMPUTE_REQUEST_RUNTIME).
//
// TODO: this uses a separate Tokio runtime for the page service. If we want
// other gRPC services, they will need their own port and runtime. Is this
// necessary?
// NB: this port is exposed to computes. It should only provide services that we're okay with
// computes accessing. Internal services should use a separate port.
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(GrpcPageServiceHandler::spawn(

View File

@@ -2005,6 +2005,10 @@ async fn put_tenant_location_config_handler(
let state = get_state(&request);
let conf = state.conf;
fail::fail_point!("put-location-conf-handler", |_| {
Err(ApiError::ResourceUnavailable("failpoint".into()))
});
// The `Detached` state is special, it doesn't upsert a tenant, it removes
// its local disk content and drops it from memory.
if let LocationConfigMode::Detached = request_data.config.mode {

View File

@@ -16,7 +16,8 @@ use anyhow::{Context as _, bail};
use bytes::{Buf as _, BufMut as _, BytesMut};
use chrono::Utc;
use futures::future::BoxFuture;
use futures::{FutureExt, Stream};
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt as _};
use itertools::Itertools;
use jsonwebtoken::TokenData;
use once_cell::sync::OnceCell;
@@ -35,8 +36,8 @@ use pageserver_api::pagestream_api::{
};
use pageserver_api::reltag::SlruKind;
use pageserver_api::shard::TenantShardId;
use pageserver_page_api as page_api;
use pageserver_page_api::proto;
use pageserver_page_api::{self as page_api, GetPageSplitter};
use postgres_backend::{
AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error,
};
@@ -68,6 +69,7 @@ use crate::config::PageServerConf;
use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
};
use crate::feature_resolver::FeatureResolver;
use crate::metrics::{
self, COMPUTE_COMMANDS_COUNTERS, ComputeCommandKind, GetPageBatchBreakReason, LIVE_CONNECTIONS,
MISROUTED_PAGESTREAM_REQUESTS, PAGESTREAM_HANDLER_RESULTS_TOTAL, SmgrOpTimer, TimelineMetrics,
@@ -139,6 +141,7 @@ pub fn spawn(
perf_trace_dispatch: Option<Dispatch>,
tcp_listener: tokio::net::TcpListener,
tls_config: Option<Arc<rustls::ServerConfig>>,
feature_resolver: FeatureResolver,
) -> Listener {
let cancel = CancellationToken::new();
let libpq_ctx = RequestContext::todo_child(
@@ -160,6 +163,7 @@ pub fn spawn(
conf.pg_auth_type,
tls_config,
conf.page_service_pipelining.clone(),
feature_resolver,
libpq_ctx,
cancel.clone(),
)
@@ -218,6 +222,7 @@ pub async fn libpq_listener_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
feature_resolver: FeatureResolver,
listener_ctx: RequestContext,
listener_cancel: CancellationToken,
) -> Connections {
@@ -261,6 +266,7 @@ pub async fn libpq_listener_main(
auth_type,
tls_config.clone(),
pipelining_config.clone(),
feature_resolver.clone(),
connection_ctx,
connections_cancel.child_token(),
gate_guard,
@@ -303,6 +309,7 @@ async fn page_service_conn_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
feature_resolver: FeatureResolver,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate_guard: GateGuard,
@@ -370,6 +377,7 @@ async fn page_service_conn_main(
perf_span_fields,
connection_ctx,
cancel.clone(),
feature_resolver.clone(),
gate_guard,
);
let pgbackend =
@@ -421,6 +429,8 @@ struct PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
feature_resolver: FeatureResolver,
gate_guard: GateGuard,
}
@@ -457,13 +467,6 @@ impl TimelineHandles {
self.handles
.get(timeline_id, shard_selector, &self.wrapper)
.await
.map_err(|e| match e {
timeline::handle::GetError::TenantManager(e) => e,
timeline::handle::GetError::PerTimelineStateShutDown => {
trace!("per-timeline state shut down");
GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown)
}
})
}
fn tenant_id(&self) -> Option<TenantId> {
@@ -479,11 +482,9 @@ pub(crate) struct TenantManagerWrapper {
tenant_id: once_cell::sync::OnceCell<TenantId>,
}
#[derive(Debug)]
pub(crate) struct TenantManagerTypes;
impl timeline::handle::Types for TenantManagerTypes {
type TenantManagerError = GetActiveTimelineError;
type TenantManager = TenantManagerWrapper;
type Timeline = TenantManagerCacheItem;
}
@@ -587,6 +588,15 @@ impl timeline::handle::TenantManager<TenantManagerTypes> for TenantManagerWrappe
}
}
/// Whether to hold the applied GC cutoff guard when processing GetPage requests.
/// This is determined once at the start of pagestream subprotocol handling based on
/// feature flags, configuration, and test conditions.
#[derive(Debug, Clone, Copy)]
enum HoldAppliedGcCutoffGuard {
Yes,
No,
}
#[derive(thiserror::Error, Debug)]
enum PageStreamError {
/// We encountered an error that should prompt the client to reconnect:
@@ -730,6 +740,7 @@ enum BatchedFeMessage {
GetPage {
span: Span,
shard: WeakHandle<TenantManagerTypes>,
applied_gc_cutoff_guard: Option<RcuReadGuard<Lsn>>,
pages: SmallVec<[BatchedGetPageRequest; 1]>,
batch_break_reason: GetPageBatchBreakReason,
},
@@ -909,6 +920,7 @@ impl PageServerHandler {
perf_span_fields: ConnectionPerfSpanFields,
connection_ctx: RequestContext,
cancel: CancellationToken,
feature_resolver: FeatureResolver,
gate_guard: GateGuard,
) -> Self {
PageServerHandler {
@@ -920,6 +932,7 @@ impl PageServerHandler {
cancel,
pipelining_config,
get_vectored_concurrent_io,
feature_resolver,
gate_guard,
}
}
@@ -959,6 +972,7 @@ impl PageServerHandler {
ctx: &RequestContext,
protocol_version: PagestreamProtocolVersion,
parent_span: Span,
hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard,
) -> Result<Option<BatchedFeMessage>, QueryError>
where
IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
@@ -1196,19 +1210,27 @@ impl PageServerHandler {
})
.await?;
let applied_gc_cutoff_guard = shard.get_applied_gc_cutoff_lsn(); // hold guard
// We're holding the Handle
let effective_lsn = match Self::effective_request_lsn(
&shard,
shard.get_last_record_lsn(),
req.hdr.request_lsn,
req.hdr.not_modified_since,
&shard.get_applied_gc_cutoff_lsn(),
&applied_gc_cutoff_guard,
) {
Ok(lsn) => lsn,
Err(e) => {
return respond_error!(span, e);
}
};
let applied_gc_cutoff_guard = match hold_gc_cutoff_guard {
HoldAppliedGcCutoffGuard::Yes => Some(applied_gc_cutoff_guard),
HoldAppliedGcCutoffGuard::No => {
drop(applied_gc_cutoff_guard);
None
}
};
let batch_wait_ctx = if ctx.has_perf_span() {
Some(
@@ -1229,6 +1251,7 @@ impl PageServerHandler {
BatchedFeMessage::GetPage {
span,
shard: shard.downgrade(),
applied_gc_cutoff_guard,
pages: smallvec![BatchedGetPageRequest {
req,
timer,
@@ -1329,13 +1352,28 @@ impl PageServerHandler {
match (eligible_batch, this_msg) {
(
BatchedFeMessage::GetPage {
pages: accum_pages, ..
pages: accum_pages,
applied_gc_cutoff_guard: accum_applied_gc_cutoff_guard,
..
},
BatchedFeMessage::GetPage {
pages: this_pages, ..
pages: this_pages,
applied_gc_cutoff_guard: this_applied_gc_cutoff_guard,
..
},
) => {
accum_pages.extend(this_pages);
// the minimum of the two guards will keep data for both alive
match (&accum_applied_gc_cutoff_guard, this_applied_gc_cutoff_guard) {
(None, None) => (),
(None, Some(this)) => *accum_applied_gc_cutoff_guard = Some(this),
(Some(_), None) => (),
(Some(accum), Some(this)) => {
if **accum > *this {
*accum_applied_gc_cutoff_guard = Some(this);
}
}
};
Ok(())
}
#[cfg(feature = "testing")]
@@ -1650,6 +1688,7 @@ impl PageServerHandler {
BatchedFeMessage::GetPage {
span,
shard,
applied_gc_cutoff_guard,
pages,
batch_break_reason,
} => {
@@ -1669,6 +1708,7 @@ impl PageServerHandler {
.instrument(span.clone())
.await;
assert_eq!(res.len(), npages);
drop(applied_gc_cutoff_guard);
res
},
span,
@@ -1750,7 +1790,7 @@ impl PageServerHandler {
/// Coding discipline within this function: all interaction with the `pgb` connection
/// needs to be sensitive to connection shutdown, currently signalled via [`Self::cancel`].
/// This is so that we can shutdown page_service quickly.
#[instrument(skip_all)]
#[instrument(skip_all, fields(hold_gc_cutoff_guard))]
async fn handle_pagerequests<IO>(
&mut self,
pgb: &mut PostgresBackend<IO>,
@@ -1796,6 +1836,30 @@ impl PageServerHandler {
.take()
.expect("implementation error: timeline_handles should not be locked");
// Evaluate the expensive feature resolver check once per pagestream subprotocol handling
// instead of once per GetPage request. This is shared between pipelined and serial paths.
let hold_gc_cutoff_guard = if cfg!(test) || cfg!(feature = "testing") {
HoldAppliedGcCutoffGuard::Yes
} else {
// Use the global feature resolver with the tenant ID directly, avoiding the need
// to get a timeline/shard which might not be available on this pageserver node.
let empty_properties = std::collections::HashMap::new();
match self.feature_resolver.evaluate_boolean(
"page-service-getpage-hold-applied-gc-cutoff-guard",
tenant_id,
&empty_properties,
) {
Ok(()) => HoldAppliedGcCutoffGuard::Yes,
Err(_) => HoldAppliedGcCutoffGuard::No,
}
};
// record it in the span of handle_pagerequests so that both the request_span
// and the pipeline implementation spans contains the field.
Span::current().record(
"hold_gc_cutoff_guard",
tracing::field::debug(&hold_gc_cutoff_guard),
);
let request_span = info_span!("request");
let ((pgb_reader, timeline_handles), result) = match self.pipelining_config.clone() {
PageServicePipeliningConfig::Pipelined(pipelining_config) => {
@@ -1809,6 +1873,7 @@ impl PageServerHandler {
pipelining_config,
protocol_version,
io_concurrency,
hold_gc_cutoff_guard,
&ctx,
)
.await
@@ -1823,6 +1888,7 @@ impl PageServerHandler {
request_span,
protocol_version,
io_concurrency,
hold_gc_cutoff_guard,
&ctx,
)
.await
@@ -1851,6 +1917,7 @@ impl PageServerHandler {
request_span: Span,
protocol_version: PagestreamProtocolVersion,
io_concurrency: IoConcurrency,
hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard,
ctx: &RequestContext,
) -> (
(PostgresBackendReader<IO>, TimelineHandles),
@@ -1872,6 +1939,7 @@ impl PageServerHandler {
ctx,
protocol_version,
request_span.clone(),
hold_gc_cutoff_guard,
)
.await;
let msg = match msg {
@@ -1919,6 +1987,7 @@ impl PageServerHandler {
pipelining_config: PageServicePipeliningConfigPipelined,
protocol_version: PagestreamProtocolVersion,
io_concurrency: IoConcurrency,
hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard,
ctx: &RequestContext,
) -> (
(PostgresBackendReader<IO>, TimelineHandles),
@@ -2022,6 +2091,7 @@ impl PageServerHandler {
&ctx,
protocol_version,
request_span.clone(),
hold_gc_cutoff_guard,
)
.await;
let Some(read_res) = read_res.transpose() else {
@@ -2068,6 +2138,7 @@ impl PageServerHandler {
pages,
span: _,
shard: _,
applied_gc_cutoff_guard: _,
batch_break_reason: _,
} = &mut batch
{
@@ -3353,18 +3424,6 @@ impl GrpcPageServiceHandler {
Ok(CancellableTask { task, cancel })
}
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
/// relations and their sizes, as well as SLRU segments and similar data.
#[allow(clippy::result_large_err)]
fn ensure_shard_zero(timeline: &Handle<TenantManagerTypes>) -> Result<(), tonic::Status> {
match timeline.get_shard_index().shard_number.0 {
0 => Ok(()),
shard => Err(tonic::Status::invalid_argument(format!(
"request must execute on shard zero (is shard {shard})",
))),
}
}
/// Generates a PagestreamRequest header from a ReadLsn and request ID.
fn make_hdr(
read_lsn: page_api::ReadLsn,
@@ -3379,30 +3438,72 @@ impl GrpcPageServiceHandler {
}
}
/// Acquires a timeline handle for the given request.
/// Acquires a timeline handle for the given request. The shard index must match a local shard.
///
/// TODO: during shard splits, the compute may still be sending requests to the parent shard
/// until the entire split is committed and the compute is notified. Consider installing a
/// temporary shard router from the parent to the children while the split is in progress.
///
/// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage
/// the TimelineHandles lifecycle.
///
/// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid
/// the unnecessary overhead.
/// NB: this will fail during shard splits, see comment on [`Self::maybe_split_get_page`].
async fn get_request_timeline(
&self,
req: &tonic::Request<impl Any>,
) -> Result<Handle<TenantManagerTypes>, GetActiveTimelineError> {
let ttid = *extract::<TenantTimelineId>(req);
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(req);
let shard_index = *extract::<ShardIndex>(req);
let shard_selector = ShardSelector::Known(shard_index);
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
TimelineHandles::new(self.tenant_manager.clone())
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await
}
/// Acquires a timeline handle for the given request, which must be for shard zero. Most
/// metadata requests are only valid on shard zero.
///
/// NB: during an ongoing shard split, the compute will keep talking to the parent shard until
/// the split is committed, but the parent shard may have been removed in the meanwhile. In that
/// case, we reroute the request to the new child shard. See [`Self::maybe_split_get_page`].
///
/// TODO: revamp the split protocol to avoid this child routing.
async fn get_request_timeline_shard_zero(
&self,
req: &tonic::Request<impl Any>,
) -> Result<Handle<TenantManagerTypes>, tonic::Status> {
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(req);
let shard_index = *extract::<ShardIndex>(req);
if shard_index.shard_number.0 != 0 {
return Err(tonic::Status::invalid_argument(format!(
"request only valid on shard zero (requested shard {shard_index})",
)));
}
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
match handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await
{
Ok(timeline) => Ok(timeline),
Err(err) => {
// We may be in the middle of a shard split. Try to find a child shard 0.
if let Ok(timeline) = handles
.get(tenant_id, timeline_id, ShardSelector::Zero)
.await
&& timeline.get_shard_index().shard_count > shard_index.shard_count
{
return Ok(timeline);
}
Err(err.into())
}
}
}
/// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start.
/// Only errors if the timeline is shutting down.
///
@@ -3429,32 +3530,37 @@ impl GrpcPageServiceHandler {
/// NB: errors returned from here are intercepted in get_pages(), and may be converted to a
/// GetPageResponse with an appropriate status code to avoid terminating the stream.
///
/// TODO: verify that the requested pages belong to this shard.
///
/// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send
/// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or
/// split them up in the client or server.
#[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))]
#[instrument(skip_all, fields(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
))]
async fn get_page(
ctx: &RequestContext,
timeline: &WeakHandle<TenantManagerTypes>,
req: proto::GetPageRequest,
timeline: Handle<TenantManagerTypes>,
req: page_api::GetPageRequest,
io_concurrency: IoConcurrency,
) -> Result<proto::GetPageResponse, tonic::Status> {
let received_at = Instant::now();
let timeline = timeline.upgrade()?;
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
let ctx = ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
let req = page_api::GetPageRequest::try_from(req)?;
span_record!(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
);
for &blkno in &req.block_numbers {
let shard = timeline.get_shard_identity();
let key = rel_block_to_key(req.rel, blkno);
if !shard.is_key_local(&key) {
return Err(tonic::Status::invalid_argument(format!(
"block {blkno} of relation {} requested on wrong shard {} (is on {})",
req.rel,
timeline.get_shard_index(),
ShardIndex::new(shard.get_shard_number(&key), shard.count),
)));
}
}
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard
let effective_lsn = PageServerHandler::effective_request_lsn(
@@ -3530,7 +3636,95 @@ impl GrpcPageServiceHandler {
};
}
Ok(resp.into())
Ok(resp)
}
/// Processes a GetPage request when there is a potential shard split in progress. We have to
/// reroute the request to any local child shards, and split batch requests that straddle
/// multiple child shards.
///
/// Parent shards are split and removed incrementally (there may be many parent shards when
/// splitting an already-sharded tenant), but the compute is only notified once the overall
/// split commits, which can take several minutes. In the meanwhile, the compute will be sending
/// requests to the parent shards.
///
/// TODO: add test infrastructure to provoke this situation frequently and for long periods of
/// time, to properly exercise it.
///
/// TODO: revamp the split protocol to avoid this, e.g.:
/// * Keep the parent shard until the split commits and the compute is notified.
/// * Notify the compute about each subsplit.
/// * Return an error that updates the compute's shard map.
#[instrument(skip_all)]
#[allow(clippy::too_many_arguments)]
async fn maybe_split_get_page(
ctx: &RequestContext,
handles: &mut TimelineHandles,
tenant_id: TenantId,
timeline_id: TimelineId,
parent: ShardIndex,
req: page_api::GetPageRequest,
io_concurrency: IoConcurrency,
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
// Check the first page to see if we have any child shards at all. Otherwise, the compute is
// just talking to the wrong Pageserver. If the parent has been split, the shard now owning
// the page must have a higher shard count.
let timeline = handles
.get(
tenant_id,
timeline_id,
ShardSelector::Page(rel_block_to_key(req.rel, req.block_numbers[0])),
)
.await?;
let shard_id = timeline.get_shard_identity();
if shard_id.count <= parent.shard_count {
return Err(HandleUpgradeError::ShutDown.into()); // emulate original error
}
// Fast path: the request fits in a single shard.
if let Some(shard_index) =
GetPageSplitter::for_single_shard(&req, shard_id.count, Some(shard_id.stripe_size))
.map_err(|err| tonic::Status::internal(err.to_string()))?
{
// We got the shard ID from the first page, so these must be equal.
assert_eq!(shard_index.shard_number, shard_id.number);
assert_eq!(shard_index.shard_count, shard_id.count);
return Self::get_page(ctx, timeline, req, io_concurrency, received_at).await;
}
// The request spans multiple shards; split it and dispatch parallel requests. All pages
// were originally in the parent shard, and during a split all children are local, so we
// expect to find local shards for all pages.
let mut splitter = GetPageSplitter::split(req, shard_id.count, Some(shard_id.stripe_size))
.map_err(|err| tonic::Status::internal(err.to_string()))?;
let mut shard_requests = FuturesUnordered::new();
for (shard_index, shard_req) in splitter.drain_requests() {
let timeline = handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await?;
let future = Self::get_page(
ctx,
timeline,
shard_req,
io_concurrency.clone(),
received_at,
)
.map(move |result| result.map(|resp| (shard_index, resp)));
shard_requests.push(future);
}
while let Some((shard_index, shard_response)) = shard_requests.next().await.transpose()? {
splitter
.add_response(shard_index, shard_response)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
}
splitter
.get_response()
.map_err(|err| tonic::Status::internal(err.to_string()))
}
}
@@ -3559,11 +3753,10 @@ impl proto::PageService for GrpcPageServiceHandler {
// to be the sweet spot where throughput is saturated.
const CHUNK_SIZE: usize = 256 * 1024;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
// Validate the request and decorate the span.
Self::ensure_shard_zero(&timeline)?;
if timeline.is_archived() == Some(true) {
return Err(tonic::Status::failed_precondition("timeline is archived"));
}
@@ -3679,11 +3872,10 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetDbSizeRequest>,
) -> Result<tonic::Response<proto::GetDbSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?;
span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn);
@@ -3712,14 +3904,29 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
) -> Result<tonic::Response<Self::GetPagesStream>, tonic::Status> {
// Extract the timeline from the request and check that it exists.
let ttid = *extract::<TenantTimelineId>(&req);
//
// NB: during shard splits, the compute may still send requests to the parent shard. We'll
// reroute requests to the child shards below, but we also detect the common cases here
// where either the shard exists or no shards exist at all. If we have a child shard, we
// can't acquire a weak handle because we don't know which child shard to use yet.
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(&req);
let shard_index = *extract::<ShardIndex>(&req);
let shard_selector = ShardSelector::Known(shard_index);
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?;
let timeline = match handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await
{
// The timeline shard exists. Keep a weak handle to reuse for each request.
Ok(timeline) => Some(timeline.downgrade()),
// The shard doesn't exist, but a child shard does. We'll reroute requests later.
Err(_) if self.tenant_manager.has_child_shard(tenant_id, shard_index) => None,
// Failed to fetch the timeline, and no child shard exists. Error out.
Err(err) => return Err(err.into()),
};
// Spawn an IoConcurrency sidecar, if enabled.
let gate_guard = self
@@ -3736,11 +3943,9 @@ impl proto::PageService for GrpcPageServiceHandler {
let mut reqs = req.into_inner();
let resps = async_stream::try_stream! {
let timeline = handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?
.downgrade();
loop {
// Wait for the next client request.
//
// NB: Tonic considers the entire stream to be an in-flight request and will wait
// for it to complete before shutting down. React to cancellation between requests.
let req = tokio::select! {
@@ -3753,16 +3958,44 @@ impl proto::PageService for GrpcPageServiceHandler {
Err(err) => Err(err),
},
}?;
let received_at = Instant::now();
let req_id = req.request_id.map(page_api::RequestID::from).unwrap_or_default();
let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone())
// Process the request, using a closure to capture errors.
let process_request = async || {
let req = page_api::GetPageRequest::try_from(req)?;
// Fast path: use the pre-acquired timeline handle.
if let Some(Ok(timeline)) = timeline.as_ref().map(|t| t.upgrade()) {
return Self::get_page(&ctx, timeline, req, io_concurrency.clone(), received_at)
.instrument(span.clone()) // propagate request span
.await
}
// The timeline handle is stale. During shard splits, the compute may still be
// sending requests to the parent shard. Try to re-route requests to the child
// shards, and split any batch requests that straddle multiple child shards.
Self::maybe_split_get_page(
&ctx,
&mut handles,
tenant_id,
timeline_id,
shard_index,
req,
io_concurrency.clone(),
received_at,
)
.instrument(span.clone()) // propagate request span
.await;
yield match result {
Ok(resp) => resp,
// Convert per-request errors to GetPageResponses as appropriate, or terminate
// the stream with a tonic::Status. Log the error regardless, since
// ObservabilityLayer can't automatically log stream errors.
.await
};
// Return the response. Convert per-request errors to GetPageResponses if
// appropriate, or terminate the stream with a tonic::Status.
yield match process_request().await {
Ok(resp) => resp.into(),
Err(status) => {
// Log the error, since ObservabilityLayer won't see stream errors.
// TODO: it would be nice if we could propagate the get_page() fields here.
span.in_scope(|| {
warn!("request failed with {:?}: {}", status.code(), status.message());
@@ -3782,11 +4015,10 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetRelSizeRequest>,
) -> Result<tonic::Response<proto::GetRelSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?;
let allow_missing = req.allow_missing;
@@ -3819,11 +4051,10 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetSlruSegmentRequest>,
) -> Result<tonic::Response<proto::GetSlruSegmentResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?;
span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn);
@@ -3853,6 +4084,10 @@ impl proto::PageService for GrpcPageServiceHandler {
&self,
req: tonic::Request<proto::LeaseLsnRequest>,
) -> Result<tonic::Response<proto::LeaseLsnResponse>, tonic::Status> {
// TODO: this won't work during shard splits, as the request is directed at a specific shard
// but the parent shard is removed before the split commits and the compute is notified
// (which can take several minutes for large tenants). That's also the case for the libpq
// implementation, so we keep the behavior for now.
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);

View File

@@ -826,6 +826,18 @@ impl TenantManager {
peek_slot.is_some()
}
/// Returns whether a local shard exists that's a child of the given tenant shard. Note that
/// this just checks for any shard with a larger shard count, and it may not be a direct child
/// of the given shard (their keyspace may not overlap).
pub(crate) fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool {
match &*self.tenants.read().unwrap() {
TenantsMap::Initializing => false,
TenantsMap::Open(slots) | TenantsMap::ShuttingDown(slots) => slots
.range(TenantShardId::tenant_range(tenant_id))
.any(|(tsid, _)| tsid.shard_count > shard_index.shard_count),
}
}
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))]
pub(crate) async fn upsert_location(
&self,
@@ -1522,6 +1534,13 @@ impl TenantManager {
self.resources.deletion_queue_client.flush_advisory();
// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
//
// TODO: keeping the parent as InProgress while spawning the children causes read
// unavailability, as we can't acquire a new timeline handle for it (existing handles appear
// to still work though, even downgraded ones). The parent should be available for reads
// until the children are ready -- potentially until *all* subsplits across all parent
// shards are complete and the compute has been notified. See:
// <https://databricks.atlassian.net/browse/LKB-672>.
drop(tenant);
let mut parent_slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;

View File

@@ -70,7 +70,7 @@ use tracing::*;
use utils::generation::Generation;
use utils::guard_arc_swap::GuardArcSwap;
use utils::id::TimelineId;
use utils::logging::{MonitorSlowFutureCallback, monitor_slow_future};
use utils::logging::{MonitorSlowFutureCallback, log_slow, monitor_slow_future};
use utils::lsn::{AtomicLsn, Lsn, RecordLsn};
use utils::postgres_client::PostgresClientProtocol;
use utils::rate_limit::RateLimit;
@@ -6898,7 +6898,13 @@ impl Timeline {
write_guard.store_and_unlock(new_gc_cutoff)
};
waitlist.wait().await;
let waitlist_wait_fut = std::pin::pin!(waitlist.wait());
log_slow(
"applied_gc_cutoff waitlist wait",
Duration::from_secs(30),
waitlist_wait_fut,
)
.await;
info!("GC starting");

View File

@@ -224,11 +224,11 @@ use tracing::{instrument, trace};
use utils::id::TimelineId;
use utils::shard::{ShardIndex, ShardNumber};
use crate::tenant::mgr::ShardSelector;
use crate::page_service::GetActiveTimelineError;
use crate::tenant::GetTimelineError;
use crate::tenant::mgr::{GetActiveTenantError, ShardSelector};
/// The requirement for Debug is so that #[derive(Debug)] works in some places.
pub(crate) trait Types: Sized + std::fmt::Debug {
type TenantManagerError: Sized + std::fmt::Debug;
pub(crate) trait Types: Sized {
type TenantManager: TenantManager<Self> + Sized;
type Timeline: Timeline<Self> + Sized;
}
@@ -307,12 +307,11 @@ impl<T: Types> Default for PerTimelineState<T> {
/// Abstract view of [`crate::tenant::mgr`], for testability.
pub(crate) trait TenantManager<T: Types> {
/// Invoked by [`Cache::get`] to resolve a [`ShardTimelineId`] to a [`Types::Timeline`].
/// Errors are returned as [`GetError::TenantManager`].
async fn resolve(
&self,
timeline_id: TimelineId,
shard_selector: ShardSelector,
) -> Result<T::Timeline, T::TenantManagerError>;
) -> Result<T::Timeline, GetActiveTimelineError>;
}
/// Abstract view of an [`Arc<Timeline>`], for testability.
@@ -322,13 +321,6 @@ pub(crate) trait Timeline<T: Types> {
fn per_timeline_state(&self) -> &PerTimelineState<T>;
}
/// Errors returned by [`Cache::get`].
#[derive(Debug)]
pub(crate) enum GetError<T: Types> {
TenantManager(T::TenantManagerError),
PerTimelineStateShutDown,
}
/// Internal type used in [`Cache::get`].
enum RoutingResult<T: Types> {
FastPath(Handle<T>),
@@ -345,7 +337,7 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetError<T>> {
) -> Result<Handle<T>, GetActiveTimelineError> {
const GET_MAX_RETRIES: usize = 10;
const RETRY_BACKOFF: Duration = Duration::from_millis(100);
let mut attempt = 0;
@@ -356,7 +348,11 @@ impl<T: Types> Cache<T> {
.await
{
Ok(handle) => return Ok(handle),
Err(e) => {
Err(
e @ GetActiveTimelineError::Tenant(GetActiveTenantError::WaitForActiveTimeout {
..
}),
) => {
// Retry on tenant manager error to handle tenant split more gracefully
if attempt < GET_MAX_RETRIES {
tokio::time::sleep(RETRY_BACKOFF).await;
@@ -370,6 +366,7 @@ impl<T: Types> Cache<T> {
return Err(e);
}
}
Err(err) => return Err(err),
}
}
}
@@ -388,7 +385,7 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetError<T>> {
) -> Result<Handle<T>, GetActiveTimelineError> {
// terminates because when every iteration we remove an element from the map
let miss: ShardSelector = loop {
let routing_state = self.shard_routing(timeline_id, shard_selector);
@@ -468,60 +465,50 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetError<T>> {
match tenant_manager.resolve(timeline_id, shard_selector).await {
Ok(timeline) => {
let key = timeline.shard_timeline_id();
match &shard_selector {
ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)),
ShardSelector::Page(_) => (), // gotta trust tenant_manager
ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index),
}
trace!("creating new HandleInner");
let timeline = Arc::new(timeline);
let handle_inner_arc =
Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline))));
let handle_weak = WeakHandle {
inner: Arc::downgrade(&handle_inner_arc),
};
let handle = handle_weak
.upgrade()
.ok()
.expect("we just created it and it's not linked anywhere yet");
{
let mut lock_guard = timeline
.per_timeline_state()
.handles
.lock()
.expect("mutex poisoned");
match &mut *lock_guard {
Some(per_timeline_state) => {
let replaced =
per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc));
assert!(replaced.is_none(), "some earlier code left a stale handle");
match self.map.entry(key) {
hash_map::Entry::Occupied(_o) => {
// This cannot not happen because
// 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and
// 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle
// while we were waiting for the tenant manager.
unreachable!()
}
hash_map::Entry::Vacant(v) => {
v.insert(handle_weak);
}
}
}
None => {
return Err(GetError::PerTimelineStateShutDown);
}
}
}
Ok(handle)
}
Err(e) => Err(GetError::TenantManager(e)),
) -> Result<Handle<T>, GetActiveTimelineError> {
let timeline = tenant_manager.resolve(timeline_id, shard_selector).await?;
let key = timeline.shard_timeline_id();
match &shard_selector {
ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)),
ShardSelector::Page(_) => (), // gotta trust tenant_manager
ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index),
}
trace!("creating new HandleInner");
let timeline = Arc::new(timeline);
let handle_inner_arc = Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline))));
let handle_weak = WeakHandle {
inner: Arc::downgrade(&handle_inner_arc),
};
let handle = handle_weak
.upgrade()
.ok()
.expect("we just created it and it's not linked anywhere yet");
let mut lock_guard = timeline
.per_timeline_state()
.handles
.lock()
.expect("mutex poisoned");
let Some(per_timeline_state) = &mut *lock_guard else {
return Err(GetActiveTimelineError::Timeline(
GetTimelineError::ShuttingDown,
));
};
let replaced = per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc));
assert!(replaced.is_none(), "some earlier code left a stale handle");
match self.map.entry(key) {
hash_map::Entry::Occupied(_o) => {
// This cannot not happen because
// 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and
// 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle
// while we were waiting for the tenant manager.
unreachable!()
}
hash_map::Entry::Vacant(v) => {
v.insert(handle_weak);
}
}
Ok(handle)
}
}
@@ -655,7 +642,8 @@ mod tests {
use pageserver_api::models::ShardParameters;
use pageserver_api::reltag::RelTag;
use pageserver_api::shard::DEFAULT_STRIPE_SIZE;
use utils::shard::ShardCount;
use utils::id::TenantId;
use utils::shard::{ShardCount, TenantShardId};
use utils::sync::gate::GateGuard;
use super::*;
@@ -665,7 +653,6 @@ mod tests {
#[derive(Debug)]
struct TestTypes;
impl Types for TestTypes {
type TenantManagerError = anyhow::Error;
type TenantManager = StubManager;
type Timeline = Entered;
}
@@ -716,40 +703,48 @@ mod tests {
&self,
timeline_id: TimelineId,
shard_selector: ShardSelector,
) -> anyhow::Result<Entered> {
) -> Result<Entered, GetActiveTimelineError> {
fn enter_gate(
timeline: &StubTimeline,
) -> Result<Arc<GateGuard>, GetActiveTimelineError> {
Ok(Arc::new(timeline.gate.enter().map_err(|_| {
GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown)
})?))
}
for timeline in &self.shards {
if timeline.id == timeline_id {
let enter_gate = || {
let gate_guard = timeline.gate.enter()?;
let gate_guard = Arc::new(gate_guard);
anyhow::Ok(gate_guard)
};
match &shard_selector {
ShardSelector::Zero if timeline.shard.is_shard_zero() => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate()?,
gate_guard: enter_gate(timeline)?,
});
}
ShardSelector::Zero => continue,
ShardSelector::Page(key) if timeline.shard.is_key_local(key) => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate()?,
gate_guard: enter_gate(timeline)?,
});
}
ShardSelector::Page(_) => continue,
ShardSelector::Known(idx) if idx == &timeline.shard.shard_index() => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate()?,
gate_guard: enter_gate(timeline)?,
});
}
ShardSelector::Known(_) => continue,
}
}
}
anyhow::bail!("not found")
Err(GetActiveTimelineError::Timeline(
GetTimelineError::NotFound {
tenant_id: TenantShardId::unsharded(TenantId::from([0; 16])),
timeline_id,
},
))
}
}

View File

@@ -52,7 +52,7 @@ pub(crate) fn regenerate(
};
// Express a static value for how many shards we may schedule on one node
const MAX_SHARDS: u32 = 5000;
const MAX_SHARDS: u32 = 2500;
let mut doc = PageserverUtilization {
disk_usage_bytes: used,

View File

@@ -79,10 +79,6 @@
#include "access/xlogrecovery.h"
#endif
#if PG_VERSION_NUM < 160000
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
#define NEON_PANIC_CONNECTION_STATE(shard_no, elvl, message, ...) \
neon_shard_log(shard_no, elvl, "Broken connection state: " message, \
##__VA_ARGS__)

View File

@@ -635,6 +635,11 @@ lfc_init(void)
NULL);
}
/*
* Dump a list of pages that are currently in the LFC
*
* This is used to get a snapshot that can be used to prewarm the LFC later.
*/
FileCacheState*
lfc_get_state(size_t max_entries)
{
@@ -1827,125 +1832,46 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
LWLockRelease(lfc_lock);
}
typedef struct
/*
* Return metrics about the LFC.
*
* The return format is a palloc'd array of LfcStatsEntrys. The size
* of the returned array is returned in *num_entries.
*/
LfcStatsEntry *
lfc_get_stats(size_t *num_entries)
{
TupleDesc tupdesc;
} NeonGetStatsCtx;
LfcStatsEntry *entries;
size_t n = 0;
#define NUM_NEON_GET_STATS_COLS 2
#define MAX_ENTRIES 10
entries = palloc(sizeof(LfcStatsEntry) * MAX_ENTRIES);
PG_FUNCTION_INFO_V1(neon_get_lfc_stats);
Datum
neon_get_lfc_stats(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
NeonGetStatsCtx *fctx;
MemoryContext oldcontext;
TupleDesc tupledesc;
Datum result;
HeapTuple tuple;
char const *key;
uint64 value = 0;
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];
entries[n++] = (LfcStatsEntry) {"file_cache_chunk_size_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_blocks_per_chunk : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_misses", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->misses : 0};
entries[n++] = (LfcStatsEntry) {"file_cache_hits", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->hits : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_used", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->used : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_writes", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->writes : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_size", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->size : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_used_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->used_pages : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_evicted_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->evicted_pages : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_limit", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->limit : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_chunks_pinned", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->pinned : 0 };
Assert(n <= MAX_ENTRIES);
#undef MAX_ENTRIES
if (SRF_IS_FIRSTCALL())
{
funcctx = SRF_FIRSTCALL_INIT();
/* Switch context when allocating stuff to be used in later calls */
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
/* Create a user function context for cross-call persistence */
fctx = (NeonGetStatsCtx *) palloc(sizeof(NeonGetStatsCtx));
/* Construct a tuple descriptor for the result rows. */
tupledesc = CreateTemplateTupleDesc(NUM_NEON_GET_STATS_COLS);
TupleDescInitEntry(tupledesc, (AttrNumber) 1, "lfc_key",
TEXTOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "lfc_value",
INT8OID, -1, 0);
fctx->tupdesc = BlessTupleDesc(tupledesc);
funcctx->user_fctx = fctx;
/* Return to original context when allocating transient memory */
MemoryContextSwitchTo(oldcontext);
}
funcctx = SRF_PERCALL_SETUP();
/* Get the saved state */
fctx = (NeonGetStatsCtx *) funcctx->user_fctx;
switch (funcctx->call_cntr)
{
case 0:
key = "file_cache_misses";
if (lfc_ctl)
value = lfc_ctl->misses;
break;
case 1:
key = "file_cache_hits";
if (lfc_ctl)
value = lfc_ctl->hits;
break;
case 2:
key = "file_cache_used";
if (lfc_ctl)
value = lfc_ctl->used;
break;
case 3:
key = "file_cache_writes";
if (lfc_ctl)
value = lfc_ctl->writes;
break;
case 4:
key = "file_cache_size";
if (lfc_ctl)
value = lfc_ctl->size;
break;
case 5:
key = "file_cache_used_pages";
if (lfc_ctl)
value = lfc_ctl->used_pages;
break;
case 6:
key = "file_cache_evicted_pages";
if (lfc_ctl)
value = lfc_ctl->evicted_pages;
break;
case 7:
key = "file_cache_limit";
if (lfc_ctl)
value = lfc_ctl->limit;
break;
case 8:
key = "file_cache_chunk_size_pages";
value = lfc_blocks_per_chunk;
break;
case 9:
key = "file_cache_chunks_pinned";
if (lfc_ctl)
value = lfc_ctl->pinned;
break;
default:
SRF_RETURN_DONE(funcctx);
}
values[0] = PointerGetDatum(cstring_to_text(key));
nulls[0] = false;
if (lfc_ctl)
{
nulls[1] = false;
values[1] = Int64GetDatum(value);
}
else
nulls[1] = true;
tuple = heap_form_tuple(fctx->tupdesc, values, nulls);
result = HeapTupleGetDatum(tuple);
SRF_RETURN_NEXT(funcctx, result);
*num_entries = n;
return entries;
}
@@ -1953,193 +1879,86 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
* Function returning data from the local file cache
* relation node/tablespace/database/blocknum and access_counter
*/
PG_FUNCTION_INFO_V1(local_cache_pages);
/*
* Record structure holding the to be exposed cache data.
*/
typedef struct
LocalCachePagesRec *
lfc_local_cache_pages(size_t *num_entries)
{
uint32 pageoffs;
Oid relfilenode;
Oid reltablespace;
Oid reldatabase;
ForkNumber forknum;
BlockNumber blocknum;
uint16 accesscount;
} LocalCachePagesRec;
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
size_t n_pages;
size_t n;
LocalCachePagesRec *result;
/*
* Function context for data persisting over repeated calls.
*/
typedef struct
{
TupleDesc tupdesc;
LocalCachePagesRec *record;
} LocalCachePagesContext;
#define NUM_LOCALCACHE_PAGES_ELEM 7
Datum
local_cache_pages(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
Datum result;
MemoryContext oldcontext;
LocalCachePagesContext *fctx; /* User function context. */
TupleDesc tupledesc;
TupleDesc expected_tupledesc;
HeapTuple tuple;
if (SRF_IS_FIRSTCALL())
if (!lfc_ctl)
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
uint32 n_pages = 0;
*num_entries = 0;
return NULL;
}
funcctx = SRF_FIRSTCALL_INIT();
LWLockAcquire(lfc_lock, LW_SHARED);
if (!LFC_ENABLED())
{
LWLockRelease(lfc_lock);
*num_entries = 0;
return NULL;
}
/* Switch context when allocating stuff to be used in later calls */
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
/* Create a user function context for cross-call persistence */
fctx = (LocalCachePagesContext *) palloc(sizeof(LocalCachePagesContext));
/*
* To smoothly support upgrades from version 1.0 of this extension
* transparently handle the (non-)existence of the pinning_backends
* column. We unfortunately have to get the result type for that... -
* we can't use the result type determined by the function definition
* without potentially crashing when somebody uses the old (or even
* wrong) function definition though.
*/
if (get_call_result_type(fcinfo, NULL, &expected_tupledesc) != TYPEFUNC_COMPOSITE)
neon_log(ERROR, "return type must be a row type");
if (expected_tupledesc->natts != NUM_LOCALCACHE_PAGES_ELEM)
neon_log(ERROR, "incorrect number of output arguments");
/* Construct a tuple descriptor for the result rows. */
tupledesc = CreateTemplateTupleDesc(expected_tupledesc->natts);
TupleDescInitEntry(tupledesc, (AttrNumber) 1, "pageoffs",
INT8OID, -1, 0);
#if PG_MAJORVERSION_NUM < 16
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenode",
OIDOID, -1, 0);
#else
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenumber",
OIDOID, -1, 0);
#endif
TupleDescInitEntry(tupledesc, (AttrNumber) 3, "reltablespace",
OIDOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 4, "reldatabase",
OIDOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 5, "relforknumber",
INT2OID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 6, "relblocknumber",
INT8OID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 7, "accesscount",
INT4OID, -1, 0);
fctx->tupdesc = BlessTupleDesc(tupledesc);
if (lfc_ctl)
/* Count the pages first */
n_pages = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
LWLockAcquire(lfc_lock, LW_SHARED);
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
if (LFC_ENABLED())
if (n_pages == 0)
{
LWLockRelease(lfc_lock);
*num_entries = 0;
return NULL;
}
result = (LocalCachePagesRec *)
MemoryContextAllocHuge(CurrentMemoryContext,
sizeof(LocalCachePagesRec) * n_pages);
/*
* Scan through all the cache entries, saving the relevant fields
* in the result structure.
*/
n = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
if (GET_STATE(entry, i) == AVAILABLE)
{
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
result[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
result[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
result[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
result[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
result[n].forknum = entry->key.forkNum;
result[n].blocknum = entry->key.blockNum + i;
result[n].accesscount = entry->access_count;
n += 1;
}
}
}
fctx->record = (LocalCachePagesRec *)
MemoryContextAllocHuge(CurrentMemoryContext,
sizeof(LocalCachePagesRec) * n_pages);
/* Set max calls and remember the user function context. */
funcctx->max_calls = n_pages;
funcctx->user_fctx = fctx;
/* Return to original context when allocating transient memory */
MemoryContextSwitchTo(oldcontext);
if (n_pages != 0)
{
/*
* Scan through all the cache entries, saving the relevant fields
* in the fctx->record structure.
*/
uint32 n = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}
}
Assert(n_pages == n);
}
if (lfc_ctl)
LWLockRelease(lfc_lock);
}
Assert(n_pages == n);
LWLockRelease(lfc_lock);
funcctx = SRF_PERCALL_SETUP();
/* Get the saved state */
fctx = funcctx->user_fctx;
if (funcctx->call_cntr < funcctx->max_calls)
{
uint32 i = funcctx->call_cntr;
Datum values[NUM_LOCALCACHE_PAGES_ELEM];
bool nulls[NUM_LOCALCACHE_PAGES_ELEM] = {
false, false, false, false, false, false, false
};
values[0] = Int64GetDatum((int64) fctx->record[i].pageoffs);
values[1] = ObjectIdGetDatum(fctx->record[i].relfilenode);
values[2] = ObjectIdGetDatum(fctx->record[i].reltablespace);
values[3] = ObjectIdGetDatum(fctx->record[i].reldatabase);
values[4] = ObjectIdGetDatum(fctx->record[i].forknum);
values[5] = Int64GetDatum((int64) fctx->record[i].blocknum);
values[6] = Int32GetDatum(fctx->record[i].accesscount);
/* Build and return the tuple. */
tuple = heap_form_tuple(fctx->tupdesc, values, nulls);
result = HeapTupleGetDatum(tuple);
SRF_RETURN_NEXT(funcctx, result);
}
else
SRF_RETURN_DONE(funcctx);
*num_entries = n_pages;
return result;
}
/*
* Internal implementation of the approximate_working_set_size_seconds()
* function.
@@ -2267,4 +2086,3 @@ get_prewarm_info(PG_FUNCTION_ARGS)
PG_RETURN_DATUM(HeapTupleGetDatum(heap_form_tuple(tupdesc, values, nulls)));
}

View File

@@ -47,6 +47,26 @@ extern bool lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blk
extern FileCacheState* lfc_get_state(size_t max_entries);
extern void lfc_prewarm(FileCacheState* fcs, uint32 n_workers);
typedef struct LfcStatsEntry
{
const char *metric_name;
bool isnull;
uint64 value;
} LfcStatsEntry;
extern LfcStatsEntry *lfc_get_stats(size_t *num_entries);
typedef struct
{
uint32 pageoffs;
Oid relfilenode;
Oid reltablespace;
Oid reldatabase;
ForkNumber forknum;
BlockNumber blocknum;
uint16 accesscount;
} LocalCachePagesRec;
extern LocalCachePagesRec *lfc_local_cache_pages(size_t *num_entries);
extern int32 lfc_approximate_working_set_size_seconds(time_t duration, bool reset);

View File

@@ -1,7 +1,7 @@
/*-------------------------------------------------------------------------
*
* neon.c
* Main entry point into the neon exension
* Main entry point into the neon extension
*
*-------------------------------------------------------------------------
*/
@@ -508,7 +508,7 @@ _PG_init(void)
DefineCustomBoolVariable(
"neon.disable_logical_replication_subscribers",
"Disables incomming logical replication",
"Disable incoming logical replication",
NULL,
&disable_logical_replication_subscribers,
false,
@@ -567,7 +567,7 @@ _PG_init(void)
DefineCustomEnumVariable(
"neon.debug_compare_local",
"Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk",
"Debug mode for comparing content of pages in prefetch ring/LFC/PS and local disk",
NULL,
&debug_compare_local,
DEBUG_COMPARE_LOCAL_NONE,
@@ -625,11 +625,15 @@ _PG_init(void)
ExecutorEnd_hook = neon_ExecutorEnd;
}
/* Various functions exposed at SQL level */
PG_FUNCTION_INFO_V1(pg_cluster_size);
PG_FUNCTION_INFO_V1(backpressure_lsns);
PG_FUNCTION_INFO_V1(backpressure_throttling_time);
PG_FUNCTION_INFO_V1(approximate_working_set_size_seconds);
PG_FUNCTION_INFO_V1(approximate_working_set_size);
PG_FUNCTION_INFO_V1(neon_get_lfc_stats);
PG_FUNCTION_INFO_V1(local_cache_pages);
Datum
pg_cluster_size(PG_FUNCTION_ARGS)
@@ -704,6 +708,76 @@ approximate_working_set_size(PG_FUNCTION_ARGS)
PG_RETURN_INT32(dc);
}
Datum
neon_get_lfc_stats(PG_FUNCTION_ARGS)
{
#define NUM_NEON_GET_STATS_COLS 2
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
LfcStatsEntry *entries;
size_t num_entries;
InitMaterializedSRF(fcinfo, 0);
/* lfc_get_stats() does all the heavy lifting */
entries = lfc_get_stats(&num_entries);
/* Convert the LfcStatsEntrys to a result set */
for (size_t i = 0; i < num_entries; i++)
{
LfcStatsEntry *entry = &entries[i];
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];
values[0] = CStringGetTextDatum(entry->metric_name);
nulls[0] = false;
values[1] = Int64GetDatum(entry->isnull ? 0 : entry->value);
nulls[1] = entry->isnull;
tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls);
}
PG_RETURN_VOID();
#undef NUM_NEON_GET_STATS_COLS
}
Datum
local_cache_pages(PG_FUNCTION_ARGS)
{
#define NUM_LOCALCACHE_PAGES_COLS 7
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
LocalCachePagesRec *entries;
size_t num_entries;
InitMaterializedSRF(fcinfo, 0);
/* lfc_local_cache_pages() does all the heavy lifting */
entries = lfc_local_cache_pages(&num_entries);
/* Convert the LocalCachePagesRec structs to a result set */
for (size_t i = 0; i < num_entries; i++)
{
LocalCachePagesRec *entry = &entries[i];
Datum values[NUM_LOCALCACHE_PAGES_COLS];
bool nulls[NUM_LOCALCACHE_PAGES_COLS] = {
false, false, false, false, false, false, false
};
values[0] = Int64GetDatum((int64) entry->pageoffs);
values[1] = ObjectIdGetDatum(entry->relfilenode);
values[2] = ObjectIdGetDatum(entry->reltablespace);
values[3] = ObjectIdGetDatum(entry->reldatabase);
values[4] = ObjectIdGetDatum(entry->forknum);
values[5] = Int64GetDatum((int64) entry->blocknum);
values[6] = Int32GetDatum(entry->accesscount);
/* Build and return the tuple. */
tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls);
}
PG_RETURN_VOID();
#undef NUM_LOCALCACHE_PAGES_COLS
}
/*
* Initialization stage 2: make requests for the amount of shared memory we
* will need.
@@ -735,7 +809,6 @@ neon_shmem_request_hook(void)
static void
neon_shmem_startup_hook(void)
{
/* Initialize */
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();

View File

@@ -167,11 +167,7 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared;
*/
#define NUM_NEON_PERF_COUNTER_SLOTS (MaxBackends + NUM_AUXILIARY_PROCS)
#if PG_VERSION_NUM >= 170000
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber])
#else
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProc->pgprocno])
#endif
extern void inc_getpage_wait(uint64 latency);
extern void inc_page_cache_read_wait(uint64 latency);

View File

@@ -9,6 +9,10 @@
#include "fmgr.h"
#include "storage/buf_internals.h"
#if PG_MAJORVERSION_NUM < 16
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
#if PG_MAJORVERSION_NUM < 17
#define NRelFileInfoBackendIsTemp(rinfo) (rinfo.backend != InvalidBackendId)
#else
@@ -158,6 +162,10 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode,
#define AmAutoVacuumWorkerProcess() (IsAutoVacuumWorkerProcess())
#endif
#if PG_MAJORVERSION_NUM < 17
#define MyProcNumber (MyProc - &ProcGlobal->allProcs[0])
#endif
#if PG_MAJORVERSION_NUM < 15
extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags);
extern TimeLineID GetWALInsertionTimeLine(void);

View File

@@ -72,10 +72,6 @@
#include "access/xlogrecovery.h"
#endif
#if PG_VERSION_NUM < 160000
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
#include "access/nbtree.h"
#include "storage/bufpage.h"
#include "access/xlog_internal.h"

View File

@@ -13,6 +13,7 @@
#include "neon.h"
#include "neon_pgversioncompat.h"
#include "miscadmin.h"
#include "pagestore_client.h"
#include RELFILEINFO_HDR
#include "storage/smgr.h"
@@ -23,10 +24,6 @@
#include "utils/dynahash.h"
#include "utils/guc.h"
#if PG_VERSION_NUM >= 150000
#include "miscadmin.h"
#endif
typedef struct
{
NRelFileInfo rinfo;

View File

@@ -33,7 +33,6 @@ env_logger.workspace = true
framed-websockets.workspace = true
futures.workspace = true
hashbrown.workspace = true
hashlink.workspace = true
hex.workspace = true
hmac.workspace = true
hostname.workspace = true
@@ -54,6 +53,7 @@ json = { path = "../libs/proxy/json" }
lasso = { workspace = true, features = ["multi-threaded"] }
measured = { workspace = true, features = ["lasso"] }
metrics.workspace = true
moka.workspace = true
once_cell.workspace = true
opentelemetry = { workspace = true, features = ["trace"] }
papaya = "0.2.0"
@@ -107,10 +107,11 @@ uuid.workspace = true
x509-cert.workspace = true
redis.workspace = true
zerocopy.workspace = true
zeroize.workspace = true
# uncomment this to use the real subzero-core crate
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
# this is a stub for the subzero-core crate
subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true}
subzero-core = { path = "../libs/proxy/subzero_core", features = ["postgresql"], optional = true}
ouroboros = { version = "0.18", optional = true }
# jwt stuff

View File

@@ -8,11 +8,12 @@ use tracing::{info, info_span};
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::compute::AuthInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::control_plane::{self, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;

View File

@@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl;
use crate::stream::{self, Stream};
@@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_flow = AuthFlow::new(
client,
auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
role,
pool: config.scram_thread_pool.clone(),
},
);
let auth_outcome = {

View File

@@ -16,16 +16,16 @@ use tracing::{debug, info};
use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
RoleAccessControl,
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
@@ -273,9 +273,11 @@ async fn authenticate_with_secret(
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
.await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
@@ -433,11 +435,12 @@ mod tests {
use super::auth_quirks;
use super::jwt::JwkCache;
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::cache::node_info::CachedNodeInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
self, AccessBlockerFlags, EndpointAccessControl, RoleAccessControl,
};
use crate::proxy::NeonOptions;
use crate::rate_limiter::EndpointRateLimiter;
@@ -498,7 +501,7 @@ mod tests {
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
@@ -46,6 +46,7 @@ pub(crate) struct PasswordHack;
pub(crate) struct CleartextPassword {
pub(crate) pool: Arc<ThreadPool>,
pub(crate) endpoint: EndpointIdInt,
pub(crate) role: RoleNameInt,
pub(crate) secret: AuthSecret,
}
@@ -111,6 +112,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
self.state.role,
password,
self.state.secret,
)
@@ -165,13 +167,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let outcome =
crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -29,7 +29,7 @@ use crate::config::{
};
use crate::control_plane::locks::ApiLocks;
use crate::http::health_server::AppMetrics;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::metrics::{Metrics, ServiceInfo};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
@@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
// TODO: refactor these to use labels
debug!("Version: {GIT_VERSION}");
debug!("Build_tag: {BUILD_TAG}");
@@ -207,6 +205,11 @@ pub async fn run() -> anyhow::Result<()> {
endpoint_rate_limiter,
);
Metrics::get()
.service
.info
.set_label(ServiceInfo::running());
match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
// exit immediately on maintenance task completion
Either::Left((Some(res), _)) => match crate::error::flatten_err(res)? {},
@@ -279,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -26,7 +26,7 @@ use utils::project_git_version;
use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::metrics::{Metrics, ServiceInfo};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
@@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
let args = cli().get_matches();
let destination: String = args
.get_one::<String>("dest")
@@ -135,6 +133,12 @@ pub async fn run() -> anyhow::Result<()> {
cancellation_token.clone(),
))
.map(crate::error::flatten_err);
Metrics::get()
.service
.info
.set_label(ServiceInfo::running());
let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));
// the signal task cant ever succeed.

View File

@@ -40,7 +40,7 @@ use crate::config::{
};
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
use crate::metrics::{Metrics, ServiceInfo};
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::redis::kv_ops::RedisKVClient;
@@ -535,12 +535,7 @@ pub async fn run() -> anyhow::Result<()> {
// add a task to flush the db_schema cache every 10 minutes
#[cfg(feature = "rest_broker")]
if let Some(db_schema_cache) = &config.rest_config.db_schema_cache {
maintenance_tasks.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(600)).await;
db_schema_cache.flush();
}
});
maintenance_tasks.spawn(db_schema_cache.maintain());
}
if let Some(metrics_config) = &config.metric_collection {
@@ -590,6 +585,11 @@ pub async fn run() -> anyhow::Result<()> {
}
}
Metrics::get()
.service
.info
.set_label(ServiceInfo::running());
let maintenance = loop {
// get one complete task
match futures::future::select(
@@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> {
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
Metrics::get()
.proxy
.scram_pool
.0
.set(thread_pool.metrics.clone())
.ok();
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
@@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_thread_pool: thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,
@@ -711,12 +716,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
info!("Using DbSchemaCache with options={db_schema_cache_config:?}");
let db_schema_cache = if args.is_rest_broker {
Some(DbSchemaCache::new(
"db_schema_cache",
db_schema_cache_config.size,
db_schema_cache_config.ttl,
true,
))
Some(DbSchemaCache::new(db_schema_cache_config))
} else {
None
};

View File

@@ -1,4 +1,16 @@
use std::ops::{Deref, DerefMut};
use std::time::{Duration, Instant};
use moka::Expiry;
use moka::notification::RemovalCause;
use crate::control_plane::messages::ControlPlaneErrorMessage;
use crate::metrics::{
CacheEviction, CacheKind, CacheOutcome, CacheOutcomeGroup, CacheRemovalCause, Metrics,
};
/// Default TTL used when caching errors from control plane.
pub const DEFAULT_ERROR_TTL: Duration = Duration::from_secs(30);
/// A generic trait which exposes types of cache's key and value,
/// as well as the notion of cache entry invalidation.
@@ -10,20 +22,16 @@ pub(crate) trait Cache {
/// Entry's value.
type Value;
/// Used for entry invalidation.
type LookupInfo<Key>;
/// Invalidate an entry using a lookup info.
/// We don't have an empty default impl because it's error-prone.
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
fn invalidate(&self, _: &Self::Key);
}
impl<C: Cache> Cache for &C {
type Key = C::Key;
type Value = C::Value;
type LookupInfo<Key> = C::LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
fn invalidate(&self, info: &Self::Key) {
C::invalidate(self, info);
}
}
@@ -31,7 +39,7 @@ impl<C: Cache> Cache for &C {
/// Wrapper for convenient entry invalidation.
pub(crate) struct Cached<C: Cache, V = <C as Cache>::Value> {
/// Cache + lookup info.
pub(crate) token: Option<(C, C::LookupInfo<C::Key>)>,
pub(crate) token: Option<(C, C::Key)>,
/// The value itself.
pub(crate) value: V,
@@ -43,23 +51,6 @@ impl<C: Cache, V> Cached<C, V> {
Self { token: None, value }
}
pub(crate) fn take_value(self) -> (Cached<C, ()>, V) {
(
Cached {
token: self.token,
value: (),
},
self.value,
)
}
pub(crate) fn map<U>(self, f: impl FnOnce(V) -> U) -> Cached<C, U> {
Cached {
token: self.token,
value: f(self.value),
}
}
/// Drop this entry from a cache if it's still there.
pub(crate) fn invalidate(self) -> V {
if let Some((cache, info)) = &self.token {
@@ -87,3 +78,91 @@ impl<C: Cache, V> DerefMut for Cached<C, V> {
&mut self.value
}
}
pub type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
#[derive(Clone, Copy)]
pub struct CplaneExpiry {
pub error: Duration,
}
impl Default for CplaneExpiry {
fn default() -> Self {
Self {
error: DEFAULT_ERROR_TTL,
}
}
}
impl CplaneExpiry {
pub fn expire_early<V>(
&self,
value: &ControlPlaneResult<V>,
updated: Instant,
) -> Option<Duration> {
match value {
Ok(_) => None,
Err(err) => Some(self.expire_err_early(err, updated)),
}
}
pub fn expire_err_early(&self, err: &ControlPlaneErrorMessage, updated: Instant) -> Duration {
err.status
.as_ref()
.and_then(|s| s.details.retry_info.as_ref())
.map_or(self.error, |r| r.retry_at.into_std() - updated)
}
}
impl<K, V> Expiry<K, ControlPlaneResult<V>> for CplaneExpiry {
fn expire_after_create(
&self,
_key: &K,
value: &ControlPlaneResult<V>,
created_at: Instant,
) -> Option<Duration> {
self.expire_early(value, created_at)
}
fn expire_after_update(
&self,
_key: &K,
value: &ControlPlaneResult<V>,
updated_at: Instant,
_duration_until_expiry: Option<Duration>,
) -> Option<Duration> {
self.expire_early(value, updated_at)
}
}
pub fn eviction_listener(kind: CacheKind, cause: RemovalCause) {
let cause = match cause {
RemovalCause::Expired => CacheRemovalCause::Expired,
RemovalCause::Explicit => CacheRemovalCause::Explicit,
RemovalCause::Replaced => CacheRemovalCause::Replaced,
RemovalCause::Size => CacheRemovalCause::Size,
};
Metrics::get()
.cache
.evicted_total
.inc(CacheEviction { cache: kind, cause });
}
#[inline]
pub fn count_cache_outcome<T>(kind: CacheKind, cache_result: Option<T>) -> Option<T> {
let outcome = if cache_result.is_some() {
CacheOutcome::Hit
} else {
CacheOutcome::Miss
};
Metrics::get().cache.request_total.inc(CacheOutcomeGroup {
cache: kind,
outcome,
});
cache_result
}
#[inline]
pub fn count_cache_insert(kind: CacheKind) {
Metrics::get().cache.inserted_total.inc(kind);
}

View File

@@ -1,6 +1,5 @@
pub(crate) mod common;
pub(crate) mod node_info;
pub(crate) mod project_info;
mod timed_lru;
pub(crate) use common::{Cache, Cached};
pub(crate) use timed_lru::TimedLru;
pub(crate) use common::{Cached, ControlPlaneResult, CplaneExpiry};

60
proxy/src/cache/node_info.rs vendored Normal file
View File

@@ -0,0 +1,60 @@
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
use crate::cache::{Cached, ControlPlaneResult, CplaneExpiry};
use crate::config::CacheOptions;
use crate::control_plane::NodeInfo;
use crate::metrics::{CacheKind, Metrics};
use crate::types::EndpointCacheKey;
pub(crate) struct NodeInfoCache(moka::sync::Cache<EndpointCacheKey, ControlPlaneResult<NodeInfo>>);
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
impl Cache for NodeInfoCache {
type Key = EndpointCacheKey;
type Value = ControlPlaneResult<NodeInfo>;
fn invalidate(&self, info: &EndpointCacheKey) {
self.0.invalidate(info);
}
}
impl NodeInfoCache {
pub fn new(config: CacheOptions) -> Self {
let builder = moka::sync::Cache::builder()
.name("node_info")
.expire_after(CplaneExpiry::default());
let builder = config.moka(builder);
if let Some(size) = config.size {
Metrics::get()
.cache
.capacity
.set(CacheKind::NodeInfo, size as i64);
}
let builder = builder
.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::NodeInfo, cause));
Self(builder.build())
}
pub fn insert(&self, key: EndpointCacheKey, value: ControlPlaneResult<NodeInfo>) {
count_cache_insert(CacheKind::NodeInfo);
self.0.insert(key, value);
}
pub fn get(&self, key: &EndpointCacheKey) -> Option<ControlPlaneResult<NodeInfo>> {
count_cache_outcome(CacheKind::NodeInfo, self.0.get(key))
}
pub fn get_entry(
&'static self,
key: &EndpointCacheKey,
) -> Option<ControlPlaneResult<CachedNodeInfo>> {
self.get(key).map(|res| {
res.map(|value| Cached {
token: Some((self, key.clone())),
value,
})
})
}
}

View File

@@ -1,84 +1,20 @@
use std::collections::{HashMap, HashSet, hash_map};
use std::collections::HashSet;
use std::convert::Infallible;
use std::time::Duration;
use async_trait::async_trait;
use clashmap::ClashMap;
use clashmap::mapref::one::Ref;
use rand::Rng;
use tokio::time::Instant;
use moka::sync::Cache;
use tracing::{debug, info};
use crate::cache::common::{
ControlPlaneResult, CplaneExpiry, count_cache_insert, count_cache_outcome, eviction_listener,
};
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{CacheKind, Metrics};
use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
}
struct Entry<T> {
expires_at: Instant,
value: T,
}
impl<T> Entry<T> {
pub(crate) fn new(value: T, ttl: Duration) -> Self {
Self {
expires_at: Instant::now() + ttl,
value,
}
}
pub(crate) fn get(&self) -> Option<&T> {
(!self.is_expired()).then_some(&self.value)
}
fn is_expired(&self) -> bool {
self.expires_at <= Instant::now()
}
}
struct EndpointInfo {
role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
}
type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
impl EndpointInfo {
pub(crate) fn get_role_secret_with_ttl(
&self,
role_name: RoleNameInt,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
let entry = self.role_controls.get(&role_name)?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
}
pub(crate) fn get_controls_with_ttl(
&self,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let entry = self.controls.as_ref()?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
}
pub(crate) fn invalidate_endpoint(&mut self) {
self.controls = None;
}
pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
self.role_controls.remove(&role_name);
}
}
/// Cache for project info.
/// This is used to cache auth data for endpoints.
/// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
@@ -86,8 +22,9 @@ impl EndpointInfo {
/// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
pub struct ProjectInfoCacheImpl {
cache: ClashMap<EndpointIdInt, EndpointInfo>,
pub struct ProjectInfoCache {
role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult<RoleAccessControl>>,
ep_controls: Cache<EndpointIdInt, ControlPlaneResult<EndpointAccessControl>>,
project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
// FIXME(stefan): we need a way to GC the account2ep map.
@@ -96,16 +33,13 @@ pub struct ProjectInfoCacheImpl {
config: ProjectInfoCacheOptions,
}
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
impl ProjectInfoCache {
pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
info!("invalidating endpoint access for `{endpoint_id}`");
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
self.ep_controls.invalidate(&endpoint_id);
}
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating endpoint access for project `{project_id}`");
let endpoints = self
.project2ep
@@ -113,13 +47,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
self.ep_controls.invalidate(&endpoint_id);
}
}
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
info!("invalidating endpoint access for org `{account_id}`");
let endpoints = self
.account2ep
@@ -127,13 +59,15 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
self.ep_controls.invalidate(&endpoint_id);
}
}
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
pub fn invalidate_role_secret_for_project(
&self,
project_id: ProjectIdInt,
role_name: RoleNameInt,
) {
info!(
"invalidating role secret for project_id `{}` and role_name `{}`",
project_id, role_name,
@@ -144,47 +78,73 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_role_secret(role_name);
}
self.role_controls.invalidate(&(endpoint_id, role_name));
}
}
}
impl ProjectInfoCacheImpl {
impl ProjectInfoCache {
pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
Metrics::get().cache.capacity.set(
CacheKind::ProjectInfoRoles,
(config.size * config.max_roles) as i64,
);
Metrics::get()
.cache
.capacity
.set(CacheKind::ProjectInfoEndpoints, config.size as i64);
// we cache errors for 30 seconds, unless retry_at is set.
let expiry = CplaneExpiry::default();
Self {
cache: ClashMap::new(),
role_controls: Cache::builder()
.name("project_info_roles")
.eviction_listener(|_k, _v, cause| {
eviction_listener(CacheKind::ProjectInfoRoles, cause);
})
.max_capacity(config.size * config.max_roles)
.time_to_live(config.ttl)
.expire_after(expiry)
.build(),
ep_controls: Cache::builder()
.name("project_info_endpoints")
.eviction_listener(|_k, _v, cause| {
eviction_listener(CacheKind::ProjectInfoEndpoints, cause);
})
.max_capacity(config.size)
.time_to_live(config.ttl)
.expire_after(expiry)
.build(),
project2ep: ClashMap::new(),
account2ep: ClashMap::new(),
config,
}
}
fn get_endpoint_cache(
&self,
endpoint_id: &EndpointId,
) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
self.cache.get(&endpoint_id)
}
pub(crate) fn get_role_secret_with_ttl(
pub(crate) fn get_role_secret(
&self,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
) -> Option<ControlPlaneResult<RoleAccessControl>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let role_name = RoleNameInt::get(role_name)?;
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret_with_ttl(role_name)
count_cache_outcome(
CacheKind::ProjectInfoRoles,
self.role_controls.get(&(endpoint_id, role_name)),
)
}
pub(crate) fn get_endpoint_access_with_ttl(
pub(crate) fn get_endpoint_access(
&self,
endpoint_id: &EndpointId,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls_with_ttl()
) -> Option<ControlPlaneResult<EndpointAccessControl>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
count_cache_outcome(
CacheKind::ProjectInfoEndpoints,
self.ep_controls.get(&endpoint_id),
)
}
pub(crate) fn insert_endpoint_access(
@@ -203,34 +163,17 @@ impl ProjectInfoCacheImpl {
self.insert_project2endpoint(project_id, endpoint_id);
}
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for endpoint access"
);
let controls = Some(Entry::new(Ok(controls), self.config.ttl));
let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
count_cache_insert(CacheKind::ProjectInfoEndpoints);
count_cache_insert(CacheKind::ProjectInfoRoles);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
ep.controls = controls;
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
self.ep_controls.insert(endpoint_id, Ok(controls));
self.role_controls
.insert((endpoint_id, role_name), Ok(role_controls));
}
pub(crate) fn insert_endpoint_access_err(
@@ -238,55 +181,34 @@ impl ProjectInfoCacheImpl {
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
msg: Box<ControlPlaneErrorMessage>,
ttl: Option<Duration>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for an endpoint access error"
);
let ttl = ttl.unwrap_or(self.config.ttl);
let controls = if msg.get_reason() == Reason::RoleProtected {
// RoleProtected is the only role-specific error that control plane can give us.
// If a given role name does not exist, it still returns a successful response,
// just with an empty secret.
None
} else {
// We can cache all the other errors in EndpointInfo.controls,
// because they don't depend on what role name we pass to control plane.
Some(Entry::new(Err(msg.clone()), ttl))
};
let role_controls = Entry::new(Err(msg), ttl);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
// RoleProtected is the only role-specific error that control plane can give us.
// If a given role name does not exist, it still returns a successful response,
// just with an empty secret.
if msg.get_reason() != Reason::RoleProtected {
// We can cache all the other errors in ep_controls because they don't
// depend on what role name we pass to control plane.
self.ep_controls
.entry(endpoint_id)
.and_compute_with(|entry| match entry {
// leave the entry alone if it's already Ok
Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop,
// replace the entry
_ => {
count_cache_insert(CacheKind::ProjectInfoEndpoints);
moka::ops::compute::Op::Put(Err(msg.clone()))
}
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
if let Some(entry) = &ep.controls
&& !entry.is_expired()
&& entry.value.is_ok()
{
// If we have cached non-expired, non-error controls, keep them.
} else {
ep.controls = controls;
}
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
count_cache_insert(CacheKind::ProjectInfoRoles);
self.role_controls
.insert((endpoint_id, role_name), Err(msg));
}
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
@@ -307,73 +229,35 @@ impl ProjectInfoCacheImpl {
}
}
pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
return;
};
let Some(role_name) = RoleNameInt::get(role_name) else {
return;
};
let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
return;
};
let entry = endpoint_info.role_controls.entry(role_name);
let hash_map::Entry::Occupied(role_controls) = entry else {
return;
};
if role_controls.get().is_expired() {
role_controls.remove();
}
pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) {
// TODO: Expire the value early if the key is idle.
// Currently not an issue as we would just use the TTL to decide, which is what already happens.
}
pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
let mut interval =
tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
let mut interval = tokio::time::interval(self.config.gc_interval);
loop {
interval.tick().await;
if self.cache.len() < self.config.size {
// If there are not too many entries, wait until the next gc cycle.
continue;
}
self.gc();
self.ep_controls.run_pending_tasks();
self.role_controls.run_pending_tasks();
}
}
fn gc(&self) {
let shard = rand::rng().random_range(0..self.project2ep.shards().len());
debug!(shard, "project_info_cache: performing epoch reclamation");
// acquire a random shard lock
let mut removed = 0;
let shard = self.project2ep.shards()[shard].write();
for (_, endpoints) in shard.iter() {
for endpoint in endpoints {
self.cache.remove(endpoint);
removed += 1;
}
}
// We can drop this shard only after making sure that all endpoints are removed.
drop(shard);
info!("project_info_cache: removed {removed} endpoints");
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::time::Duration;
use super::*;
use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use std::sync::Arc;
#[tokio::test]
async fn test_project_info_cache_settings() {
tokio::time::pause();
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
size: 1,
max_roles: 2,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
@@ -423,22 +307,17 @@ mod tests {
},
);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user1)
.unwrap();
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert_eq!(cached.unwrap().secret, secret1);
assert_eq!(ttl, cache.config.ttl);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.unwrap();
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert_eq!(cached.unwrap().secret, secret2);
assert_eq!(ttl, cache.config.ttl);
// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
cache.role_controls.run_pending_tasks();
cache.insert_endpoint_access(
account_id,
project_id,
@@ -455,31 +334,18 @@ mod tests {
},
);
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user3)
.is_none()
);
cache.role_controls.run_pending_tasks();
assert_eq!(cache.role_controls.entry_count(), 2);
let cached = cache
.get_endpoint_access_with_ttl(&endpoint_id)
.unwrap()
.0
.unwrap();
assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::sleep(Duration::from_secs(2)).await;
tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
assert!(cached.is_none());
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
assert!(cached.is_none());
let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
assert!(cached.is_none());
cache.role_controls.run_pending_tasks();
assert_eq!(cache.role_controls.entry_count(), 0);
}
#[tokio::test]
async fn test_caching_project_info_errors() {
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
size: 10,
max_roles: 10,
ttl: Duration::from_secs(1),
@@ -519,34 +385,23 @@ mod tests {
status: None,
});
let get_role_secret = |endpoint_id, role_name| {
cache
.get_role_secret_with_ttl(endpoint_id, role_name)
.unwrap()
.0
};
let get_endpoint_access =
|endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
let get_role_secret =
|endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap();
let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap();
// stores role-specific errors only for get_role_secret
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
role_msg.clone(),
None,
);
cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone());
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
role_msg.error
);
assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
assert!(cache.get_endpoint_access(&endpoint_id).is_none());
// stores non-role specific errors for both get_role_secret and get_endpoint_access
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
generic_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
@@ -558,11 +413,7 @@ mod tests {
);
// error isn't returned for other roles in the same endpoint
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.is_none()
);
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
// success for a role does not overwrite errors for other roles
cache.insert_endpoint_access(
@@ -590,7 +441,6 @@ mod tests {
(&endpoint_id).into(),
(&user2).into(),
generic_msg.clone(),
None,
);
assert!(get_role_secret(&endpoint_id, &user2).is_err());
assert!(get_endpoint_access(&endpoint_id).is_ok());

View File

@@ -1,262 +0,0 @@
use std::borrow::Borrow;
use std::hash::Hash;
use std::time::{Duration, Instant};
// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{LruCache, linked_hash_map::RawEntryMut};
use tracing::debug;
use super::Cache;
use super::common::Cached;
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired entry to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub(crate) struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
update_ttl_on_retrieval: bool,
}
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
type Key = K;
type Value = V;
type LookupInfo<Key> = Key;
fn invalidate(&self, info: &Self::LookupInfo<K>) {
self.invalidate_raw(info);
}
}
struct Entry<T> {
created_at: Instant,
expires_at: Instant,
ttl: Duration,
update_ttl_on_retrieval: bool,
value: T,
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub(crate) fn new(
name: &'static str,
capacity: usize,
ttl: Duration,
update_ttl_on_retrieval: bool,
) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
update_ttl_on_retrieval,
}
}
/// Drop an entry from the cache if it's outdated.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, key: &K) {
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x.remove(),
};
drop(cache); // drop lock before logging
let Entry {
created_at,
expires_at,
..
} = entry;
debug!(
?created_at,
?expires_at,
"processed a cache entry invalidation event"
);
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immeditely drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(raw_entry.key(), entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
let deadline = now.checked_add(raw_entry.get().ttl).expect("time overflow");
if raw_entry.get().update_ttl_on_retrieval {
raw_entry.get_mut().expires_at = deadline;
}
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
old_expires_at = format_args!("{expires_at:?}"),
new_expires_at = format_args!("{deadline:?}"),
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw_ttl(
&self,
key: K,
value: V,
ttl: Duration,
update: bool,
) -> (Instant, Option<V>) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
ttl,
update_ttl_on_retrieval: update,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(key, entry)
.map(|entry| entry.value);
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
replaced = old.is_some(),
"created a cache entry"
);
(created_at, old)
}
}
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
pub(crate) fn insert_ttl(&self, key: K, value: V, ttl: Duration) {
self.insert_raw_ttl(key, value, ttl, false);
}
#[cfg(feature = "rest_broker")]
pub(crate) fn insert(&self, key: K, value: V) {
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval);
}
pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
let (_, old) = self.insert_raw(key.clone(), value);
let cached = Cached {
token: Some((self, key)),
value: (),
};
(old, cached)
}
#[cfg(feature = "rest_broker")]
pub(crate) fn flush(&self) {
let now = Instant::now();
let mut cache = self.cache.lock();
// Collect keys of expired entries first
let expired_keys: Vec<_> = cache
.iter()
.filter_map(|(key, entry)| {
if entry.expires_at <= now {
Some(key.clone())
} else {
None
}
})
.collect();
// Remove expired entries
for key in expired_keys {
cache.remove(&key);
}
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
/// Retrieve a cached entry in convenient wrapper, alongside timing information.
pub(crate) fn get_with_created_at<Q>(
&self,
key: &Q,
) -> Option<Cached<&Self, (<Self as Cache>::Value, Instant)>>
where
K: Borrow<Q> + Clone,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| Cached {
token: Some((self, key.clone())),
value: (entry.value.clone(), entry.created_at),
})
}
}

View File

@@ -8,6 +8,7 @@ use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
use postgres_client::connect_raw::StartupStream;
use postgres_client::error::SqlState;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use thiserror::Error;
@@ -22,7 +23,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::connect_compute::TlsNegotiation;
@@ -65,12 +66,13 @@ impl UserFacingError for PostgresError {
}
impl ReportableError for PostgresError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
fn get_error_kind(&self) -> ErrorKind {
match self {
PostgresError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
PostgresError::Postgres(_) => crate::error::ErrorKind::Compute,
PostgresError::Postgres(err) => match err.as_db_error() {
Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User,
Some(_) => ErrorKind::Postgres,
None => ErrorKind::Compute,
},
}
}
}
@@ -110,9 +112,9 @@ impl UserFacingError for ConnectionError {
}
impl ReportableError for ConnectionError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
fn get_error_kind(&self) -> ErrorKind {
match self {
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
#[cfg(test)]

View File

@@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::scram;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
@@ -75,7 +75,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_thread_pool: Arc<scram::threadpool::ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub ip_allowlist_check_enabled: bool,
pub is_vpc_acccess_proxy: bool,
@@ -107,20 +107,23 @@ pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig>
#[derive(Debug)]
pub struct CacheOptions {
/// Max number of entries.
pub size: usize,
pub size: Option<u64>,
/// Entry's time-to-live.
pub ttl: Duration,
pub absolute_ttl: Option<Duration>,
/// Entry's time-to-idle.
pub idle_ttl: Option<Duration>,
}
impl CacheOptions {
/// Default options for [`crate::control_plane::NodeInfoCache`].
pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
/// Default options for [`crate::cache::node_info::NodeInfoCache`].
pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,idle_ttl=4m";
/// Parse cache options passed via cmdline.
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
fn parse(options: &str) -> anyhow::Result<Self> {
let mut size = None;
let mut ttl = None;
let mut absolute_ttl = None;
let mut idle_ttl = None;
for option in options.split(',') {
let (key, value) = option
@@ -129,21 +132,34 @@ impl CacheOptions {
match key {
"size" => size = Some(value.parse()?),
"ttl" => ttl = Some(humantime::parse_duration(value)?),
"absolute_ttl" | "ttl" => absolute_ttl = Some(humantime::parse_duration(value)?),
"idle_ttl" | "tti" => idle_ttl = Some(humantime::parse_duration(value)?),
unknown => bail!("unknown key: {unknown}"),
}
}
// TTL doesn't matter if cache is always empty.
if let Some(0) = size {
ttl.get_or_insert(Duration::default());
}
Ok(Self {
size: size.context("missing `size`")?,
ttl: ttl.context("missing `ttl`")?,
size,
absolute_ttl,
idle_ttl,
})
}
pub fn moka<K, V, C>(
&self,
mut builder: moka::sync::CacheBuilder<K, V, C>,
) -> moka::sync::CacheBuilder<K, V, C> {
if let Some(size) = self.size {
builder = builder.max_capacity(size);
}
if let Some(ttl) = self.absolute_ttl {
builder = builder.time_to_live(ttl);
}
if let Some(tti) = self.idle_ttl {
builder = builder.time_to_idle(tti);
}
builder
}
}
impl FromStr for CacheOptions {
@@ -159,17 +175,17 @@ impl FromStr for CacheOptions {
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
/// Max number of entries.
pub size: usize,
pub size: u64,
/// Entry's time-to-live.
pub ttl: Duration,
/// Max number of roles per endpoint.
pub max_roles: usize,
pub max_roles: u64,
/// Gc interval.
pub gc_interval: Duration,
}
impl ProjectInfoCacheOptions {
/// Default options for [`crate::control_plane::NodeInfoCache`].
/// Default options for [`crate::cache::project_info::ProjectInfoCache`].
pub const CACHE_DEFAULT_OPTIONS: &'static str =
"size=10000,ttl=4m,max_roles=10,gc_interval=60m";
@@ -496,21 +512,37 @@ mod tests {
#[test]
fn test_parse_cache_options() -> anyhow::Result<()> {
let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
assert_eq!(size, 4096);
assert_eq!(ttl, Duration::from_secs(5 * 60));
let CacheOptions {
size,
absolute_ttl,
idle_ttl: _,
} = "size=4096,ttl=5min".parse()?;
assert_eq!(size, Some(4096));
assert_eq!(absolute_ttl, Some(Duration::from_secs(5 * 60)));
let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
assert_eq!(size, 2);
assert_eq!(ttl, Duration::from_secs(4 * 60));
let CacheOptions {
size,
absolute_ttl,
idle_ttl: _,
} = "ttl=4m,size=2".parse()?;
assert_eq!(size, Some(2));
assert_eq!(absolute_ttl, Some(Duration::from_secs(4 * 60)));
let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
assert_eq!(size, 0);
assert_eq!(ttl, Duration::from_secs(1));
let CacheOptions {
size,
absolute_ttl,
idle_ttl: _,
} = "size=0,ttl=1s".parse()?;
assert_eq!(size, Some(0));
assert_eq!(absolute_ttl, Some(Duration::from_secs(1)));
let CacheOptions { size, ttl } = "size=0".parse()?;
assert_eq!(size, 0);
assert_eq!(ttl, Duration::default());
let CacheOptions {
size,
absolute_ttl,
idle_ttl: _,
} = "size=0".parse()?;
assert_eq!(size, Some(0));
assert_eq!(absolute_ttl, None);
Ok(())
}

View File

@@ -3,7 +3,6 @@
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use ::http::HeaderName;
use ::http::header::AUTHORIZATION;
@@ -17,6 +16,8 @@ use tracing::{Instrument, debug, info, info_span, warn};
use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::context::RequestContext;
use crate::control_plane::caches::ApiCaches;
use crate::control_plane::errors::{
@@ -25,8 +26,7 @@ use crate::control_plane::errors::{
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse};
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl,
};
use crate::metrics::Metrics;
use crate::proxy::retry::CouldRetry;
@@ -118,7 +118,6 @@ impl NeonControlPlaneClient {
cache_key.into(),
role.into(),
msg.clone(),
retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)),
);
Err(err)
@@ -347,18 +346,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
) -> Result<RoleAccessControl, GetAuthInfoError> {
let key = endpoint.normalize();
if let Some((role_control, ttl)) = self
.caches
.project_info
.get_role_secret_with_ttl(&key, role)
{
if let Some(role_control) = self.caches.project_info.get_role_secret(&key, role) {
return match role_control {
Err(mut msg) => {
Err(msg) => {
info!(key = &*key, "found cached get_role_access_control error");
// if retry_delay_ms is set change it to the remaining TTL
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
}
Ok(role_control) => {
@@ -383,17 +375,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
) -> Result<EndpointAccessControl, GetAuthInfoError> {
let key = endpoint.normalize();
if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) {
if let Some(control) = self.caches.project_info.get_endpoint_access(&key) {
return match control {
Err(mut msg) => {
Err(msg) => {
info!(
key = &*key,
"found cached get_endpoint_access_control error"
);
// if retry_delay_ms is set change it to the remaining TTL
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
}
Ok(control) => {
@@ -426,17 +415,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
macro_rules! check_cache {
() => {
if let Some(cached) = self.caches.node_info.get_with_created_at(&key) {
let (cached, (info, created_at)) = cached.take_value();
if let Some(info) = self.caches.node_info.get_entry(&key) {
return match info {
Err(mut msg) => {
Err(msg) => {
info!(key = &*key, "found cached wake_compute error");
// if retry_delay_ms is set, reduce it by the amount of time it spent in cache
replace_retry_delay_ms(&mut msg, |delay| {
delay.saturating_sub(created_at.elapsed().as_millis() as u64)
});
Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
msg,
)))
@@ -444,7 +427,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
Ok(info) => {
debug!(key = &*key, "found cached compute node info");
ctx.set_project(info.aux.clone());
Ok(cached.map(|()| info))
Ok(info)
}
};
}
@@ -483,10 +466,12 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
let mut stored_node = node.clone();
// store the cached node as 'warm_cached'
stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
self.caches.node_info.insert(key.clone(), Ok(stored_node));
let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
Ok(cached.map(|()| node))
Ok(Cached {
token: Some((&self.caches.node_info, key)),
value: node,
})
}
Err(err) => match err {
WakeComputeError::ControlPlane(ControlPlaneError::Message(ref msg)) => {
@@ -503,11 +488,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
"created a cache entry for the wake compute error"
);
let ttl = retry_info.map_or(Duration::from_secs(30), |r| {
Duration::from_millis(r.retry_delay_ms)
});
self.caches.node_info.insert_ttl(key, Err(msg.clone()), ttl);
self.caches.node_info.insert(key, Err(msg.clone()));
Err(err)
}
@@ -517,14 +498,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
}
}
fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) {
if let Some(status) = &mut msg.status
&& let Some(retry_info) = &mut status.details.retry_info
{
retry_info.retry_delay_ms = f(retry_info.retry_delay_ms);
}
}
/// Parse http response body, taking status code into account.
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
status: StatusCode,

View File

@@ -15,6 +15,7 @@ use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::compute::ConnectInfo;
use crate::context::RequestContext;
use crate::control_plane::errors::{
@@ -22,8 +23,7 @@ use crate::control_plane::errors::{
};
use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl,
};
use crate::intern::RoleNameInt;
use crate::scram;

View File

@@ -13,10 +13,11 @@ use tracing::{debug, info};
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache};
use crate::cache::project_info::ProjectInfoCache;
use crate::config::{CacheOptions, ProjectInfoCacheOptions};
use crate::context::RequestContext;
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
use crate::control_plane::{ControlPlaneApi, errors};
use crate::error::ReportableError;
use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
@@ -119,7 +120,7 @@ pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub(crate) node_info: NodeInfoCache,
/// Cache which stores project_id -> endpoint_ids mapping.
pub project_info: Arc<ProjectInfoCacheImpl>,
pub project_info: Arc<ProjectInfoCache>,
}
impl ApiCaches {
@@ -128,13 +129,8 @@ impl ApiCaches {
project_info_cache_config: ProjectInfoCacheOptions,
) -> Self {
Self {
node_info: NodeInfoCache::new(
"node_info_cache",
wake_compute_cache_config.size,
wake_compute_cache_config.ttl,
true,
),
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
node_info: NodeInfoCache::new(wake_compute_cache_config),
project_info: Arc::new(ProjectInfoCache::new(project_info_cache_config)),
}
}
}

View File

@@ -1,8 +1,10 @@
use std::fmt::{self, Display};
use std::time::Duration;
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use tokio::time::Instant;
use crate::auth::IpPattern;
use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
@@ -231,7 +233,13 @@ impl Reason {
#[derive(Copy, Clone, Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct RetryInfo {
pub(crate) retry_delay_ms: u64,
#[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")]
pub(crate) retry_at: Instant,
}
fn milliseconds_from_now<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Instant, D::Error> {
let millis = u64::deserialize(d)?;
Ok(Instant::now() + Duration::from_millis(millis))
}
#[derive(Debug, Deserialize, Clone)]

View File

@@ -16,13 +16,13 @@ use messages::EndpointRateLimitConfig;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::cache::node_info::CachedNodeInfo;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::types::{EndpointId, RoleName};
use crate::{compute, scram};
/// Various cache-related types.
@@ -77,10 +77,6 @@ pub(crate) struct AccessBlockerFlags {
pub vpc_access_blocked: bool,
}
pub(crate) type NodeInfoCache =
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
#[derive(Clone, Debug)]
pub struct RoleAccessControl {
pub secret: Option<AuthSecret>,

View File

@@ -2,59 +2,60 @@ use std::sync::{Arc, OnceLock};
use lasso::ThreadedRodeo;
use measured::label::{
FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet,
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
StaticLabelSet,
};
use measured::metric::group::Encoding;
use measured::metric::histogram::Thresholds;
use measured::metric::name::MetricName;
use measured::{
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
LabelGroup, MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec, InfoMetric};
use tokio::time::{self, Instant};
use crate::control_plane::messages::ColdStartInfo;
use crate::error::ErrorKind;
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new())]
pub struct Metrics {
#[metric(namespace = "proxy")]
#[metric(init = ProxyMetrics::new(thread_pool))]
#[metric(init = ProxyMetrics::new())]
pub proxy: ProxyMetrics,
#[metric(namespace = "wake_compute_lock")]
pub wake_compute_lock: ApiLockMetrics,
#[metric(namespace = "service")]
pub service: ServiceMetrics,
#[metric(namespace = "cache")]
pub cache: CacheMetrics,
}
static SELF: OnceLock<Metrics> = OnceLock::new();
impl Metrics {
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
let mut metrics = Metrics::new(thread_pool);
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
SELF.set(metrics)
.ok()
.expect("proxy metrics must not be installed more than once");
}
#[track_caller]
pub fn get() -> &'static Self {
#[cfg(test)]
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
static SELF: OnceLock<Metrics> = OnceLock::new();
#[cfg(not(test))]
SELF.get()
.expect("proxy metrics must be installed by the main() function")
SELF.get_or_init(|| {
let mut metrics = Metrics::new();
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
metrics
})
}
}
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new())]
pub struct ProxyMetrics {
#[metric(flatten)]
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
@@ -127,6 +128,9 @@ pub struct ProxyMetrics {
/// Number of TLS handshake failures
pub tls_handshake_failures: Counter,
/// Number of SHA 256 rounds executed.
pub sha_rounds: Counter,
/// HLL approximate cardinality of endpoints that are connecting
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
@@ -144,8 +148,25 @@ pub struct ProxyMetrics {
pub connect_compute_lock: ApiLockMetrics,
#[metric(namespace = "scram_pool")]
#[metric(init = thread_pool)]
pub scram_pool: Arc<ThreadPoolMetrics>,
pub scram_pool: OnceLockWrapper<Arc<ThreadPoolMetrics>>,
}
/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`].
pub struct OnceLockWrapper<T>(pub OnceLock<T>);
impl<T> Default for OnceLockWrapper<T> {
fn default() -> Self {
Self(OnceLock::new())
}
}
impl<Enc: Encoding, T: MetricGroup<Enc>> MetricGroup<Enc> for OnceLockWrapper<T> {
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
if let Some(inner) = self.0.get() {
inner.collect_group_into(enc)?;
}
Ok(())
}
}
#[derive(MetricGroup)]
@@ -215,13 +236,6 @@ pub enum Bool {
False,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum CacheOutcome {
Hit,
Miss,
}
#[derive(LabelGroup)]
#[label(set = ConsoleRequestSet)]
pub struct ConsoleRequest<'a> {
@@ -553,14 +567,6 @@ impl From<bool> for Bool {
}
}
#[derive(LabelGroup)]
#[label(set = InvalidEndpointsSet)]
pub struct InvalidEndpointsGroup {
pub protocol: Protocol,
pub rejected: Bool,
pub outcome: ConnectOutcome,
}
#[derive(LabelGroup)]
#[label(set = RetriesMetricSet)]
pub struct RetriesMetricGroup {
@@ -660,3 +666,100 @@ pub struct ThreadPoolMetrics {
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
}
#[derive(MetricGroup, Default)]
pub struct ServiceMetrics {
pub info: InfoMetric<ServiceInfo>,
}
#[derive(Default)]
pub struct ServiceInfo {
pub state: ServiceState,
}
impl ServiceInfo {
pub const fn running() -> Self {
ServiceInfo {
state: ServiceState::Running,
}
}
pub const fn terminating() -> Self {
ServiceInfo {
state: ServiceState::Terminating,
}
}
}
impl LabelGroup for ServiceInfo {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const STATE: &LabelName = LabelName::from_str("state");
v.write_value(STATE, &self.state);
}
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug, Default)]
#[label(singleton = "state")]
pub enum ServiceState {
#[default]
Init,
Running,
Terminating,
}
#[derive(MetricGroup)]
#[metric(new())]
pub struct CacheMetrics {
/// The capacity of the cache
pub capacity: GaugeVec<StaticLabelSet<CacheKind>>,
/// The total number of entries inserted into the cache
pub inserted_total: CounterVec<StaticLabelSet<CacheKind>>,
/// The total number of entries removed from the cache
pub evicted_total: CounterVec<CacheEvictionSet>,
/// The total number of cache requests
pub request_total: CounterVec<CacheOutcomeSet>,
}
impl Default for CacheMetrics {
fn default() -> Self {
Self::new()
}
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
#[label(singleton = "cache")]
pub enum CacheKind {
NodeInfo,
ProjectInfoEndpoints,
ProjectInfoRoles,
Schema,
Pbkdf2,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
pub enum CacheRemovalCause {
Expired,
Explicit,
Replaced,
Size,
}
#[derive(LabelGroup)]
#[label(set = CacheEvictionSet)]
pub struct CacheEviction {
pub cache: CacheKind,
pub cause: CacheRemovalCause,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
pub enum CacheOutcome {
Hit,
Miss,
}
#[derive(LabelGroup)]
#[label(set = CacheOutcomeSet)]
pub struct CacheOutcomeGroup {
pub cache: CacheKind,
pub outcome: CacheOutcome,
}

View File

@@ -2,7 +2,7 @@ use thiserror::Error;
use crate::auth::Backend;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cache;
use crate::cache::common::Cache;
use crate::compute::{AuthInfo, ComputeConnection, ConnectionError, PostgresError};
use crate::config::ProxyConfig;
use crate::context::RequestContext;

View File

@@ -1,11 +1,12 @@
use tokio::time;
use tracing::{debug, info, warn};
use crate::cache::node_info::CachedNodeInfo;
use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection};
use crate::config::{ComputeConfig, ProxyConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::NodeInfo;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
@@ -17,7 +18,7 @@ use crate::types::Host;
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(skip_all)]
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
pub(crate) fn invalidate_cache(node_info: CachedNodeInfo) -> NodeInfo {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
@@ -37,7 +38,7 @@ pub(crate) trait ConnectMechanism {
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, compute::ConnectionError>;
}
@@ -66,7 +67,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<ComputeConnection, compute::ConnectionError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;

View File

@@ -15,15 +15,17 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
use tokio::time::Instant;
use tracing_test::traced_test;
use super::retry::CouldRetry;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache};
use crate::config::{CacheOptions, ComputeConfig, RetryConfig, TlsConfig};
use crate::context::RequestContext;
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::control_plane::{self, NodeInfo};
use crate::error::ErrorKind;
use crate::pglb::ERR_INSECURE_CONNECTION;
use crate::pglb::handshake::{HandshakeData, handshake};
@@ -417,12 +419,11 @@ impl TestConnectMechanism {
Self {
counter: Arc::new(std::sync::Mutex::new(0)),
sequence,
cache: Box::leak(Box::new(NodeInfoCache::new(
"test",
1,
Duration::from_secs(100),
false,
))),
cache: Box::leak(Box::new(NodeInfoCache::new(CacheOptions {
size: Some(1),
absolute_ttl: Some(Duration::from_secs(100)),
idle_ttl: None,
}))),
}
}
}
@@ -436,7 +437,7 @@ impl ConnectMechanism for TestConnectMechanism {
async fn connect_once(
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_node_info: &CachedNodeInfo,
_config: &ComputeConfig,
) -> Result<Self::Connection, compute::ConnectionError> {
let mut counter = self.counter.lock().unwrap();
@@ -501,7 +502,7 @@ impl TestControlPlaneClient for TestConnectMechanism {
details: Details {
error_info: None,
retry_info: Some(control_plane::messages::RetryInfo {
retry_delay_ms: 1,
retry_at: Instant::now() + Duration::from_millis(1),
}),
user_facing_message: None,
},
@@ -546,8 +547,11 @@ fn helper_create_uncached_node_info() -> NodeInfo {
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = helper_create_uncached_node_info();
let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
node2.map(|()| node)
cache.insert("key".into(), Ok(node.clone()));
CachedNodeInfo {
token: Some((cache, "key".into())),
value: node,
}
}
fn helper_create_connect_info(

View File

@@ -1,9 +1,9 @@
use async_trait::async_trait;
use tracing::{error, info};
use crate::cache::node_info::CachedNodeInfo;
use crate::config::RetryConfig;
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
use crate::control_plane::errors::{ControlPlaneError, WakeComputeError};
use crate::error::ReportableError;
use crate::metrics::{

View File

@@ -131,11 +131,11 @@ where
Ok(())
}
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
struct MessageHandler<C: Send + Sync + 'static> {
cache: Arc<C>,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
impl<C: Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
@@ -143,8 +143,8 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>) -> Self {
impl MessageHandler<ProjectInfoCache> {
pub(crate) fn new(cache: Arc<ProjectInfoCache>) -> Self {
Self { cache }
}
@@ -224,7 +224,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
fn invalidate_cache(cache: Arc<ProjectInfoCache>, msg: Notification) {
match msg {
Notification::EndpointSettingsUpdate(ids) => ids
.iter()
@@ -247,8 +247,8 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
}
}
async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
handler: MessageHandler<C>,
async fn handle_messages(
handler: MessageHandler<ProjectInfoCache>,
redis: ConnectionWithCredentialsProvider,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -284,13 +284,10 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
pub async fn task_main(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
cache: Arc<ProjectInfoCache>,
) -> anyhow::Result<Infallible> {
let handler = MessageHandler::new(cache);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.

84
proxy/src/scram/cache.rs Normal file
View File

@@ -0,0 +1,84 @@
use tokio::time::Instant;
use zeroize::Zeroize as _;
use super::pbkdf2;
use crate::cache::Cached;
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::metrics::{CacheKind, Metrics};
pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>);
pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>;
impl Cache for Pbkdf2Cache {
type Key = (EndpointIdInt, RoleNameInt);
type Value = Pbkdf2CacheEntry;
fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) {
self.0.invalidate(info);
}
}
/// To speed up password hashing for more active customers, we store the tail results of the
/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store
/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16
/// to determine the final result.
///
/// The suffix alone isn't enough to crack the password. The stored_key is still required.
/// While both are cached in memory, given they're in different locations is makes it much
/// harder to exploit, even if any such memory exploit exists in proxy.
#[derive(Clone)]
pub struct Pbkdf2CacheEntry {
/// corresponds to [`super::ServerSecret::cached_at`]
pub(super) cached_from: Instant,
pub(super) suffix: pbkdf2::Block,
}
impl Drop for Pbkdf2CacheEntry {
fn drop(&mut self) {
self.suffix.zeroize();
}
}
impl Pbkdf2Cache {
pub fn new() -> Self {
const SIZE: u64 = 100;
const TTL: std::time::Duration = std::time::Duration::from_secs(60);
let builder = moka::sync::Cache::builder()
.name("pbkdf2")
.max_capacity(SIZE)
// We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
.time_to_live(TTL);
Metrics::get()
.cache
.capacity
.set(CacheKind::Pbkdf2, SIZE as i64);
let builder =
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause));
Self(builder.build())
}
pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) {
count_cache_insert(CacheKind::Pbkdf2);
self.0.insert((endpoint, role), value);
}
fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option<Pbkdf2CacheEntry> {
count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role)))
}
pub fn get_entry(
&self,
endpoint: EndpointIdInt,
role: RoleNameInt,
) -> Option<CachedPbkdf2<'_>> {
self.get(endpoint, role).map(|value| Cached {
token: Some((self, (endpoint, role))),
value,
})
}
}

View File

@@ -4,10 +4,8 @@ use std::convert::Infallible;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tracing::{debug, trace};
use super::ScramKey;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
@@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2;
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use crate::intern::EndpointIdInt;
use super::{ScramKey, pbkdf2};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl::{self, ChannelBinding, Error as SaslError};
use crate::scram::cache::Pbkdf2CacheEntry;
/// The only channel binding mode we currently support.
#[derive(Debug)]
@@ -77,46 +77,113 @@ impl<'a> Exchange<'a> {
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await;
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes")
.chain_update(name)
.finalize();
<[u8; 32]>::from(key.into_bytes())
};
make_key(b"Client Key").into()
) -> pbkdf2::Block {
pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await
}
/// For cleartext flow, we need to derive the client key to
/// 1. authenticate the client.
/// 2. authenticate with compute.
pub(crate) async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
if secret.iterations > CACHED_ROUNDS {
exchange_with_cache(pool, endpoint, role, secret, password).await
} else {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
Ok(validate_pbkdf2(secret, &hash))
}
}
/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
/// which is not enough by itself to perform an offline brute force.
async fn exchange_with_cache(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
debug_assert!(
secret.iterations > CACHED_ROUNDS,
"we should not cache password data if there isn't enough rounds needed"
);
// compute the prefix of the pbkdf2 output.
let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
if let Some(entry) = pool.cache.get_entry(endpoint, role) {
// hot path: let's check the threadpool cache
if secret.cached_at == entry.cached_from {
// cache is valid. compute the full hash by adding the prefix to the suffix.
let mut hash = prefix;
pbkdf2::xor_assign(&mut hash, &entry.suffix);
let outcome = validate_pbkdf2(secret, &hash);
if matches!(outcome, sasl::Outcome::Success(_)) {
trace!("password validated from cache");
}
return Ok(outcome);
}
// cached key is no longer valid.
debug!("invalidating cached password");
entry.invalidate();
}
// slow path: full password hash.
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
let outcome = validate_pbkdf2(secret, &hash);
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
sasl::Outcome::Failure(_) => return Ok(outcome),
};
trace!("storing cached password");
// time to cache, compute the suffix by subtracting the prefix from the hash.
let mut suffix = hash;
pbkdf2::xor_assign(&mut suffix, &prefix);
pool.cache.insert(
endpoint,
role,
Pbkdf2CacheEntry {
cached_from: secret.cached_at,
suffix,
},
);
Ok(sasl::Outcome::Success(client_key))
}
fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
let client_key = super::ScramKey::client_key(&(*hash).into());
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
sasl::Outcome::Failure("password doesn't match")
} else {
Ok(sasl::Outcome::Success(client_key))
sasl::Outcome::Success(client_key)
}
}
const CACHED_ROUNDS: u32 = 16;
impl SaslInitial {
fn transition(
&self,

View File

@@ -1,6 +1,12 @@
//! Tools for client/server/stored key management.
use hmac::Mac as _;
use sha2::Digest as _;
use subtle::ConstantTimeEq;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_KEY_LEN: usize = 32;
@@ -14,6 +20,12 @@ pub(crate) struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl Drop for ScramKey {
fn drop(&mut self) {
self.bytes.zeroize();
}
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
@@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey {
impl ScramKey {
pub(crate) fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()
Metrics::get().proxy.sha_rounds.inc_by(1);
Self {
bytes: sha2::Sha256::digest(self.as_bytes()).into(),
}
}
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
self.bytes
}
pub(crate) fn client_key(b: &[u8; 32]) -> Self {
// Prf::new_from_slice will run 2 sha256 rounds.
// Update + Finalize run 2 sha256 rounds.
Metrics::get().proxy.sha_rounds.inc_by(4);
let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes");
prf.update(b"Client Key");
let client_key: [u8; 32] = prf.finalize().into_bytes().into();
client_key.into()
}
}
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {

View File

@@ -6,6 +6,7 @@
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod cache;
mod countmin;
mod exchange;
mod key;
@@ -18,10 +19,8 @@ pub mod threadpool;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
pub(crate) use exchange::{Exchange, exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
@@ -42,29 +41,13 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
Some(bytes)
}
/// This function essentially is `Hmac(sha256, key, input)`.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
parts.into_iter().for_each(|s| mac.update(s));
mac.finalize().into_bytes().into()
}
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut hasher = Sha256::new();
parts.into_iter().for_each(|s| hasher.update(s));
hasher.finalize().into()
}
#[cfg(test)]
mod tests {
use super::threadpool::ThreadPool;
use super::{Exchange, ServerSecret};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl::{Mechanism, Step};
use crate::types::EndpointId;
use crate::types::{EndpointId, RoleName};
#[test]
fn snapshot() {
@@ -114,23 +97,34 @@ mod tests {
);
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
async fn check(
pool: &ThreadPool,
scram_secret: &ServerSecret,
password: &[u8],
) -> Result<(), &'static str> {
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let role = RoleName::from("user");
let role = RoleNameInt::from(&role);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
let outcome = super::exchange(pool, ep, role, scram_secret, password)
.await
.unwrap();
match outcome {
crate::sasl::Outcome::Success(_) => {}
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
crate::sasl::Outcome::Success(_) => Ok(()),
crate::sasl::Outcome::Failure(r) => Err(r),
}
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
check(&pool, &scram_secret, client_password.as_bytes())
.await
.unwrap();
}
#[tokio::test]
async fn round_trip() {
run_round_trip_test("pencil", "pencil").await;
@@ -141,4 +135,27 @@ mod tests {
async fn failure() {
run_round_trip_test("pencil", "eraser").await;
}
#[tokio::test]
#[tracing_test::traced_test]
async fn password_cache() {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build("password").await.unwrap();
// wrong passwords are not added to cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("storing cached password"));
// correct passwords get cached
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("storing cached password"));
// wrong passwords do not match the cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("password validated from cache"));
// correct passwords match the cache
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("password validated from cache"));
}
}

View File

@@ -1,25 +1,50 @@
//! For postgres password authentication, we need to perform a PBKDF2 using
//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
use hmac::Mac as _;
use hmac::digest::consts::U32;
use hmac::digest::generic_array::GenericArray;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
pub type Prf = hmac::Hmac<sha2::Sha256>;
pub(crate) type Block = GenericArray<u8, U32>;
pub(crate) struct Pbkdf2 {
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
hmac: Prf,
/// U{r-1} for whatever iteration r we are currently on.
prev: Block,
/// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
hi: Block,
/// number of iterations left
iterations: u32,
}
impl Drop for Pbkdf2 {
fn drop(&mut self) {
self.prev.zeroize();
self.hi.zeroize();
}
}
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
// key the HMAC and derive the first block in-place
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
// U1 = PRF(Password, Salt + INT_32_BE(i))
// i = 1 since we only need 1 block of output.
hmac.update(salt);
hmac.update(&1u32.to_be_bytes());
let init_block = hmac.finalize_reset().into_bytes();
// Prf::new_from_slice will run 2 sha256 rounds.
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(4);
Self {
hmac,
// one iteration spent above
@@ -33,7 +58,11 @@ impl Pbkdf2 {
(self.iterations).clamp(0, 4096)
}
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
/// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
/// function that only executes a fixed number of iterations before continuing.
///
/// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
let Self {
hmac,
prev,
@@ -44,25 +73,37 @@ impl Pbkdf2 {
// only do up to 4096 iterations per turn for fairness
let n = (*iterations).clamp(0, 4096);
for _ in 0..n {
hmac.update(prev);
let block = hmac.finalize_reset().into_bytes();
for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
*hi_byte ^= b;
}
*prev = block;
let next = single_round(hmac, prev);
xor_assign(hi, &next);
*prev = next;
}
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
*iterations -= n;
if *iterations == 0 {
std::task::Poll::Ready((*hi).into())
std::task::Poll::Ready(*hi)
} else {
std::task::Poll::Pending
}
}
}
#[inline(always)]
pub fn xor_assign(x: &mut Block, y: &Block) {
for (x, &y) in std::iter::zip(x, y) {
*x ^= y;
}
}
#[inline(always)]
fn single_round(prf: &mut Prf, ui: &Block) -> Block {
// Ui = PRF(Password, Ui-1)
prf.update(ui);
prf.finalize_reset().into_bytes()
}
#[cfg(test)]
mod tests {
use pbkdf2::pbkdf2_hmac_array;
@@ -76,11 +117,11 @@ mod tests {
let pass = b"Ne0n_!5_50_C007";
let mut job = Pbkdf2::start(pass, salt, 60000);
let hash = loop {
let hash: [u8; 32] = loop {
let std::task::Poll::Ready(hash) = job.turn() else {
continue;
};
break hash;
break hash.into();
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);

View File

@@ -3,6 +3,7 @@
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use subtle::{Choice, ConstantTimeEq};
use tokio::time::Instant;
use super::base64_decode_array;
use super::key::ScramKey;
@@ -11,6 +12,9 @@ use super::key::ScramKey;
/// and is used throughout the authentication process.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct ServerSecret {
/// When this secret was cached.
pub(crate) cached_at: Instant,
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
@@ -34,6 +38,7 @@ impl ServerSecret {
params.split_once(':').zip(keys.split_once(':'))?;
let secret = ServerSecret {
cached_at: Instant::now(),
iterations: iterations.parse().ok()?,
salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
@@ -54,6 +59,7 @@ impl ServerSecret {
/// See `auth-scram.c : mock_scram_secret` for details.
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
Self {
cached_at: Instant::now(),
// this doesn't reveal much information as we're going to use
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.

View File

@@ -1,6 +1,10 @@
//! Tools for client/server signature management.
use hmac::Mac as _;
use super::key::{SCRAM_KEY_LEN, ScramKey};
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// A collection of message parts needed to derive the client's signature.
#[derive(Debug)]
@@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> {
impl SignatureBuilder<'_> {
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
let parts = [
self.client_first_message_bare.as_bytes(),
b",",
self.server_first_message.as_bytes(),
b",",
self.client_final_message_without_proof.as_bytes(),
];
// don't know exactly. this is a rough approx
Metrics::get().proxy.sha_rounds.inc_by(8);
super::hmac_sha256(key.as_ref(), parts).into()
let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes");
mac.update(self.client_first_message_bare.as_bytes());
mac.update(b",");
mac.update(self.server_first_message.as_bytes());
mac.update(b",");
mac.update(self.client_final_message_without_proof.as_bytes());
Signature {
bytes: mac.finalize().into_bytes().into(),
}
}
}

View File

@@ -15,6 +15,8 @@ use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::cache::Pbkdf2Cache;
use super::pbkdf2;
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
@@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch;
pub struct ThreadPool {
runtime: Option<tokio::runtime::Runtime>,
pub metrics: Arc<ThreadPoolMetrics>,
// we hash a lot of passwords.
// we keep a cache of partial hashes for faster validation.
pub(super) cache: Pbkdf2Cache,
}
/// How often to reset the sketch values
@@ -68,6 +74,7 @@ impl ThreadPool {
Self {
runtime: Some(runtime),
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
cache: Pbkdf2Cache::new(),
}
})
}
@@ -130,7 +137,7 @@ struct JobSpec {
}
impl Future for JobSpec {
type Output = [u8; 32];
type Output = pbkdf2::Block;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
STATE.with_borrow_mut(|state| {
@@ -166,10 +173,10 @@ impl Future for JobSpec {
}
}
pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);
impl Future for JobHandle {
type Output = [u8; 32];
type Output = pbkdf2::Block;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.0.poll_unpin(cx) {
@@ -203,10 +210,10 @@ mod tests {
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
.await;
let expected = [
let expected = &[
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
];
assert_eq!(actual, expected);
assert_eq!(actual.as_slice(), expected);
}
}

View File

@@ -4,6 +4,7 @@ use std::time::Duration;
use ed25519_dalek::SigningKey;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use jose_jwk::jose_b64;
use postgres_client::error::SqlState;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use rand_core::OsRng;
use tracing::field::display;
@@ -26,7 +27,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::StartupMessageParams;
use crate::proxy::{connect_auth, connect_compute};
use crate::rate_limiter::EndpointRateLimiter;
@@ -76,9 +77,11 @@ impl PoolingBackend {
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let role = RoleNameInt::from(&user_info.user);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
&self.config.authentication_config.scram_thread_pool,
ep,
role,
password,
secret,
)
@@ -457,15 +460,14 @@ impl ReportableError for HttpConnError {
match self {
HttpConnError::ConnectError(_) => ErrorKind::Compute,
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => {
if p.as_db_error().is_some() {
// postgres rejected the connection
ErrorKind::Postgres
} else {
// couldn't even reach postgres
ErrorKind::Compute
}
}
HttpConnError::PostgresConnectionError(p) => match p.as_db_error() {
// user provided a wrong database name
Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User,
// postgres rejected the connection
Some(_) => ErrorKind::Postgres,
// couldn't even reach postgres
None => ErrorKind::Compute,
},
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::ComputeCtl(_) => ErrorKind::Service,
HttpConnError::JwtPayloadError(_) => ErrorKind::User,

View File

@@ -1,5 +1,6 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use bytes::Bytes;
@@ -12,6 +13,7 @@ use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode};
use indexmap::IndexMap;
use moka::sync::Cache;
use ouroboros::self_referencing;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Deserializer};
@@ -53,12 +55,12 @@ use super::http_util::{
};
use super::json::JsonConversionError;
use crate::auth::backend::ComputeCredentialKeys;
use crate::cache::{Cached, TimedLru};
use crate::cache::common::{count_cache_insert, count_cache_outcome, eviction_listener};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::read_body_with_limit;
use crate::metrics::Metrics;
use crate::metrics::{CacheKind, Metrics};
use crate::serverless::sql_over_http::HEADER_VALUE_TRUE;
use crate::types::EndpointCacheKey;
use crate::util::deserialize_json_string;
@@ -138,8 +140,31 @@ pub struct ApiConfig {
}
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
pub(crate) type DbSchemaCache = TimedLru<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>;
pub(crate) struct DbSchemaCache(Cache<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>);
impl DbSchemaCache {
pub fn new(config: crate::config::CacheOptions) -> Self {
let builder = Cache::builder().name("schema");
let builder = config.moka(builder);
let metrics = &Metrics::get().cache;
if let Some(size) = config.size {
metrics.capacity.set(CacheKind::Schema, size as i64);
}
let builder =
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Schema, cause));
Self(builder.build())
}
pub async fn maintain(&self) -> Result<Infallible, anyhow::Error> {
let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60));
loop {
ticker.tick().await;
self.0.run_pending_tasks();
}
}
pub async fn get_cached_or_remote(
&self,
endpoint_id: &EndpointCacheKey,
@@ -149,8 +174,9 @@ impl DbSchemaCache {
ctx: &RequestContext,
config: &'static ProxyConfig,
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
match self.get_with_created_at(endpoint_id) {
Some(Cached { value: (v, _), .. }) => Ok(v),
let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id));
match cache_result {
Some(v) => Ok(v),
None => {
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
let remote_value = self
@@ -173,7 +199,8 @@ impl DbSchemaCache {
db_extra_search_path: None,
};
let value = Arc::new((api_config, schema_owned));
self.insert(endpoint_id.clone(), value);
count_cache_insert(CacheKind::Schema);
self.0.insert(endpoint_id.clone(), value);
return Err(e);
}
Err(e) => {
@@ -181,7 +208,8 @@ impl DbSchemaCache {
}
};
let value = Arc::new((api_config, schema_owned));
self.insert(endpoint_id.clone(), value.clone());
count_cache_insert(CacheKind::Schema);
self.0.insert(endpoint_id.clone(), value.clone());
Ok(value)
}
}

View File

@@ -192,34 +192,29 @@ pub(crate) async fn handle(
let line = get(db_error, |db| db.line().map(|l| l.to_string()));
let routine = get(db_error, |db| db.routine());
match &e {
SqlOverHttpError::Postgres(e)
if e.as_db_error().is_some() && error_kind == ErrorKind::User =>
{
// this error contains too much info, and it's not an error we care about.
if tracing::enabled!(Level::DEBUG) {
tracing::debug!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
} else {
tracing::info!(
kind = error_kind.to_metric_label(),
error = "bad query",
"forwarding error to user"
);
}
}
_ => {
tracing::info!(
if db_error.is_some() && error_kind == ErrorKind::User {
// this error contains too much info, and it's not an error we care about.
if tracing::enabled!(Level::DEBUG) {
debug!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
} else {
info!(
kind = error_kind.to_metric_label(),
error = "bad query",
"forwarding error to user"
);
}
} else {
info!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
}
json_response(

View File

@@ -4,6 +4,8 @@ use anyhow::bail;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use crate::metrics::{Metrics, ServiceInfo};
/// Handle unix signals appropriately.
pub async fn handle<F>(
token: CancellationToken,
@@ -28,10 +30,12 @@ where
// Shut down the whole application.
_ = interrupt.recv() => {
warn!("received SIGINT, exiting immediately");
Metrics::get().service.info.set_label(ServiceInfo::terminating());
bail!("interrupted");
}
_ = terminate.recv() => {
warn!("received SIGTERM, shutting down once all existing connections have closed");
Metrics::get().service.info.set_label(ServiceInfo::terminating());
token.cancel();
}
}

View File

@@ -102,7 +102,7 @@ pub struct ReportedError {
}
impl ReportedError {
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
pub fn new(e: impl UserFacingError + Into<anyhow::Error>) -> Self {
let error_kind = e.get_error_kind();
Self {
source: e.into(),

View File

@@ -12,7 +12,7 @@ use futures::stream::{self, FuturesOrdered};
use postgres_ffi::v14::xlog_utils::XLogSegNoOffsetToRecPtr;
use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo};
use remote_storage::{
DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata,
DownloadError, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata,
};
use safekeeper_api::models::PeerInfo;
use tokio::fs::File;
@@ -607,6 +607,9 @@ pub(crate) async fn copy_partial_segment(
storage.copy_object(source, destination, &cancel).await
}
const WAL_READ_WARN_THRESHOLD: u32 = 2;
const WAL_READ_MAX_RETRIES: u32 = 3;
pub async fn read_object(
storage: &GenericRemoteStorage,
file_path: &RemotePath,
@@ -620,12 +623,23 @@ pub async fn read_object(
byte_start: std::ops::Bound::Included(offset),
..Default::default()
};
let download = storage
.download(file_path, &opts, &cancel)
.await
.with_context(|| {
format!("Failed to open WAL segment download stream for remote path {file_path:?}")
})?;
// This retry only solves the connect errors: subsequent reads can still fail as this function returns
// a stream.
let download = backoff::retry(
|| async { storage.download(file_path, &opts, &cancel).await },
DownloadError::is_permanent,
WAL_READ_WARN_THRESHOLD,
WAL_READ_MAX_RETRIES,
"download WAL segment",
&cancel,
)
.await
.ok_or_else(|| DownloadError::Cancelled)
.and_then(|x| x)
.with_context(|| {
format!("Failed to open WAL segment download stream for remote path {file_path:?}")
})?;
let reader = tokio_util::io::StreamReader::new(download.download_stream);

View File

@@ -74,4 +74,4 @@ http-utils = { path = "../libs/http-utils/" }
utils = { path = "../libs/utils/" }
metrics = { path = "../libs/metrics/" }
control_plane = { path = "../control_plane" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -0,0 +1,7 @@
/// Type of the storage node (pageserver or safekeeper) that we are updating DNS records for. Different types of nodes will have
/// different-looking DNS names in the DNS zone.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum NodeType {
Pageserver,
Safekeeper,
}

View File

@@ -0,0 +1,433 @@
#![allow(dead_code, unused)]
use std::collections::{HashMap, HashSet};
use diesel::Queryable;
use diesel::dsl::min;
use diesel::prelude::*;
use diesel_async::AsyncConnection;
use diesel_async::AsyncPgConnection;
use diesel_async::RunQueryDsl;
use itertools::Itertools;
use pageserver_api::controller_api::SCSafekeeperTimelinesResponse;
use scoped_futures::ScopedFutureExt;
use serde::{Deserialize, Serialize};
use utils::id::{NodeId, TenantId, TimelineId};
use uuid::Uuid;
use crate::hadron_dns::NodeType;
use crate::hadron_requests::NodeConnectionInfo;
use crate::persistence::{DatabaseError, DatabaseResult};
use crate::schema::{hadron_safekeepers, nodes};
use crate::sk_node::SafeKeeperNode;
use std::str::FromStr;
// The Safe Keeper node database representation (for Diesel).
#[derive(
Clone, Serialize, Deserialize, Queryable, Selectable, Insertable, Eq, PartialEq, AsChangeset,
)]
#[diesel(table_name = crate::schema::hadron_safekeepers)]
pub(crate) struct HadronSafekeeperRow {
pub(crate) sk_node_id: i64,
pub(crate) listen_http_addr: String,
pub(crate) listen_http_port: i32,
pub(crate) listen_pg_addr: String,
pub(crate) listen_pg_port: i32,
}
#[derive(
Clone, Serialize, Deserialize, Queryable, Selectable, Insertable, Eq, PartialEq, AsChangeset,
)]
#[diesel(table_name = crate::schema::hadron_timeline_safekeepers)]
pub(crate) struct HadronTimelineSafekeeper {
pub(crate) timeline_id: String,
pub(crate) sk_node_id: i64,
pub(crate) legacy_endpoint_id: Option<Uuid>,
}
pub async fn execute_sk_upsert(
conn: &mut AsyncPgConnection,
sk_row: HadronSafekeeperRow,
) -> DatabaseResult<()> {
// SQL:
// INSERT INTO hadron_safekeepers (sk_node_id, listen_http_addr, listen_http_port, listen_pg_addr, listen_pg_port)
// VALUES ($1, $2, $3, $4, $5)
// ON CONFLICT (sk_node_id)
// DO UPDATE SET listen_http_addr = $2, listen_http_port = $3, listen_pg_addr = $4, listen_pg_port = $5;
use crate::schema::hadron_safekeepers::dsl::*;
diesel::insert_into(hadron_safekeepers)
.values(&sk_row)
.on_conflict(sk_node_id)
.do_update()
.set(&sk_row)
.execute(conn)
.await?;
Ok(())
}
// Load all safekeeper nodes and their associated timelines from the meta PG. This query is supposed
// to run only once on HCC startup and is used to construct the SafeKeeperScheduler state. Performs
// scans of the hadron_safekeepers and hadron_timeline_safekeepers tables.
pub async fn scan_safekeepers_and_scheduled_timelines(
conn: &mut AsyncPgConnection,
) -> DatabaseResult<HashMap<NodeId, SafeKeeperNode>> {
use crate::schema::hadron_safekeepers;
use crate::schema::hadron_timeline_safekeepers;
// We first scan the hadron_safekeepers table to constuct the SafeKeeperNode objects. We don't know anything about
// the timelines scheduled to the safekeepers after this step. We then scan the hadron_timeline_safekeepers table
// to populate the data structures in the SafeKeeperNode objects to reflect the timelines scheduled to the safekeepers.
let mut results: HashMap<NodeId, SafeKeeperNode> = hadron_safekeepers::table
.select((
hadron_safekeepers::sk_node_id,
hadron_safekeepers::listen_http_addr,
hadron_safekeepers::listen_http_port,
hadron_safekeepers::listen_pg_addr,
hadron_safekeepers::listen_pg_port,
))
.load::<HadronSafekeeperRow>(conn)
.await?
.into_iter()
.map(|row| {
let sk_node = SafeKeeperNode {
id: NodeId(row.sk_node_id as u64),
listen_http_addr: row.listen_http_addr.clone(),
listen_http_port: row.listen_http_port as u16,
listen_pg_addr: row.listen_pg_addr.clone(),
listen_pg_port: row.listen_pg_port as u16,
legacy_endpoints: HashMap::new(),
timelines: HashSet::new(),
};
(sk_node.id, sk_node)
})
.collect();
let timeline_sk_rows = hadron_timeline_safekeepers::table
.select((
hadron_timeline_safekeepers::sk_node_id,
hadron_timeline_safekeepers::timeline_id,
hadron_timeline_safekeepers::legacy_endpoint_id,
))
.load::<(i64, String, Option<Uuid>)>(conn)
.await?;
for (sk_node_id, timeline_id, legacy_endpoint_id) in timeline_sk_rows {
if let Some(sk_node) = results.get_mut(&NodeId(sk_node_id as u64)) {
let parsed_timeline_id =
TimelineId::from_str(&timeline_id).map_err(|e: hex::FromHexError| {
DatabaseError::Logical(format!("Failed to parse timeline IDs: {e}"))
})?;
sk_node.timelines.insert(parsed_timeline_id);
if let Some(legacy_endpoint_id) = legacy_endpoint_id {
sk_node
.legacy_endpoints
.insert(legacy_endpoint_id, parsed_timeline_id);
}
}
}
Ok(results)
}
// Queries the hadron_timeline_safekeepers table to get the safekeepers assigned to the passed
// timeline. If none are found, persists the input proposed safekeepers to the table and returns
// them.
pub async fn idempotently_persist_or_get_existing_timeline_safekeepers(
conn: &mut AsyncPgConnection,
timeline_id: TimelineId,
safekeepers: &[NodeId],
) -> DatabaseResult<Vec<NodeId>> {
use crate::schema::hadron_timeline_safekeepers;
// Confirm and persist the timeline-safekeeper mapping. If there are existing safekeepers
// assigned to the timeline in the database, treat those as the source of truth.
let existing_safekeepers: Vec<i64> = hadron_timeline_safekeepers::table
.select(hadron_timeline_safekeepers::sk_node_id)
.filter(hadron_timeline_safekeepers::timeline_id.eq(timeline_id.to_string()))
.load::<i64>(conn)
.await?;
let confirmed_safekeepers: Vec<NodeId> = if existing_safekeepers.is_empty() {
let proposed_safekeeper_endpoint_rows_result: Result<Vec<HadronTimelineSafekeeper>, _> =
safekeepers
.iter()
.map(|sk_node_id| {
i64::try_from(sk_node_id.0).map(|sk_node_id| HadronTimelineSafekeeper {
timeline_id: timeline_id.to_string(),
sk_node_id,
legacy_endpoint_id: None,
})
})
.collect();
let proposed_safekeeper_endpoint_rows =
proposed_safekeeper_endpoint_rows_result.map_err(|e| {
DatabaseError::Logical(format!("Failed to convert safekeeper IDs: {e}"))
})?;
diesel::insert_into(hadron_timeline_safekeepers::table)
.values(&proposed_safekeeper_endpoint_rows)
.execute(conn)
.await?;
safekeepers.to_owned()
} else {
let safekeeper_result: Result<Vec<NodeId>, _> = existing_safekeepers
.into_iter()
.map(|arg0: i64| u64::try_from(arg0).map(NodeId))
.collect();
safekeeper_result
.map_err(|e| DatabaseError::Logical(format!("Failed to convert safekeeper IDs: {e}")))?
};
Ok(confirmed_safekeepers)
}
pub async fn delete_timeline_safekeepers(
conn: &mut AsyncPgConnection,
timeline_id: TimelineId,
) -> DatabaseResult<()> {
use crate::schema::hadron_timeline_safekeepers;
diesel::delete(hadron_timeline_safekeepers::table)
.filter(hadron_timeline_safekeepers::timeline_id.eq(timeline_id.to_string()))
.execute(conn)
.await?;
Ok(())
}
pub(crate) async fn execute_safekeeper_list_timelines(
conn: &mut AsyncPgConnection,
safekeeper_id: i64,
) -> DatabaseResult<SCSafekeeperTimelinesResponse> {
use crate::schema::hadron_timeline_safekeepers;
use pageserver_api::controller_api::SCSafekeeperTimelinesResponse;
conn.transaction(|conn| {
async move {
let mut sk_timelines = SCSafekeeperTimelinesResponse {
timelines: Vec::new(),
safekeeper_peers: Vec::new(),
};
// Find all timelines <String>
let timeline_ids = hadron_timeline_safekeepers::table
.select(hadron_timeline_safekeepers::timeline_id)
.filter(hadron_timeline_safekeepers::sk_node_id.eq(safekeeper_id))
.load::<String>(conn)
.await
.into_iter()
.flatten()
.collect_vec();
// Find the peers for each timeline. <timeline_id, sk_node_id>
let timeline_peers = hadron_timeline_safekeepers::table
.select((
hadron_timeline_safekeepers::timeline_id,
hadron_timeline_safekeepers::sk_node_id,
))
.filter(hadron_timeline_safekeepers::timeline_id.eq_any(&timeline_ids))
.load::<(String, i64)>(conn)
.await
.into_iter()
.flatten()
.collect_vec();
let mut timeline_peers_map = HashMap::new();
let mut seen = HashSet::new();
let mut unique_sks = Vec::new();
for (timeline_id, sk_node_id) in timeline_peers {
timeline_peers_map
.entry(timeline_id)
.or_insert_with(Vec::new)
.push(sk_node_id);
if seen.insert(sk_node_id) {
unique_sks.push(sk_node_id);
}
}
// Find SK info.
let mut found_sk_nodes = HashSet::new();
hadron_safekeepers::table
.select((
hadron_safekeepers::sk_node_id,
hadron_safekeepers::listen_http_addr,
hadron_safekeepers::listen_http_port,
))
.filter(hadron_safekeepers::sk_node_id.eq_any(&unique_sks))
.load::<(i64, String, i32)>(conn)
.await
.into_iter()
.flatten()
.for_each(|(sk_node_id, listen_http_addr, http_port)| {
found_sk_nodes.insert(sk_node_id);
sk_timelines.safekeeper_peers.push(
pageserver_api::controller_api::TimelineSafekeeperPeer {
node_id: utils::id::NodeId(sk_node_id as u64),
listen_http_addr,
http_port,
},
);
});
// Prepare timeline response.
for timeline_id in timeline_ids {
if !timeline_peers_map.contains_key(&timeline_id) {
continue;
}
let peers = timeline_peers_map.get(&timeline_id).unwrap();
// Check peers exist.
if !peers
.iter()
.all(|sk_node_id| found_sk_nodes.contains(sk_node_id))
{
continue;
}
let timeline = pageserver_api::controller_api::SCSafekeeperTimeline {
timeline_id: TimelineId::from_str(&timeline_id).unwrap(),
peers: peers
.iter()
.map(|sk_node_id| utils::id::NodeId(*sk_node_id as u64))
.collect(),
};
sk_timelines.timelines.push(timeline);
}
Ok(sk_timelines)
}
.scope_boxed()
})
.await
}
/// Stores details about connecting to pageserver and safekeeper nodes for a given tenant and
/// timeline.
pub struct PageserverAndSafekeeperConnectionInfo {
pub pageserver_conn_info: Vec<NodeConnectionInfo>,
pub safekeeper_conn_info: Vec<NodeConnectionInfo>,
}
/// Retrieves the connection information for the pageserver and safekeepers associated with the
/// given tenant and timeline.
pub async fn get_pageserver_and_safekeeper_connection_info(
conn: &mut AsyncPgConnection,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> DatabaseResult<PageserverAndSafekeeperConnectionInfo> {
conn.transaction(|conn| {
async move {
// Fetch details about pageserver, which is associated with the input tenant.
let pageserver_conn_info =
get_pageserver_connection_info(conn, &tenant_id.to_string()).await?;
// Fetch details about safekeepers, which are associated with the input timeline.
let safekeeper_conn_info =
get_safekeeper_connection_info(conn, &timeline_id.to_string()).await?;
Ok(PageserverAndSafekeeperConnectionInfo {
pageserver_conn_info,
safekeeper_conn_info,
})
}
.scope_boxed()
})
.await
}
async fn get_safekeeper_connection_info(
conn: &mut AsyncPgConnection,
timeline_id: &str,
) -> DatabaseResult<Vec<NodeConnectionInfo>> {
use crate::schema::hadron_safekeepers;
use crate::schema::hadron_timeline_safekeepers;
Ok(hadron_timeline_safekeepers::table
.inner_join(
hadron_safekeepers::table
.on(hadron_timeline_safekeepers::sk_node_id.eq(hadron_safekeepers::sk_node_id)),
)
.select((
hadron_safekeepers::sk_node_id,
hadron_safekeepers::listen_pg_addr,
hadron_safekeepers::listen_pg_port,
))
.filter(hadron_timeline_safekeepers::timeline_id.eq(timeline_id.to_string()))
.load::<(i64, String, i32)>(conn)
.await?
.into_iter()
.map(|(node_id, addr, port)| {
NodeConnectionInfo::new(
NodeType::Safekeeper,
NodeId(node_id as u64),
addr,
port as u16,
)
})
.collect())
}
async fn get_pageserver_connection_info(
conn: &mut AsyncPgConnection,
tenant_id: &str,
) -> DatabaseResult<Vec<NodeConnectionInfo>> {
use crate::schema::tenant_shards;
// When the tenant is being split, it'll contain both old shards and new shards. Until the tenant split is committed,
// we should always use the old shards.
// NOTE: we only support tenant split without tennat merge. Thus shard count could only increase.
let min_shard_count = match tenant_shards::table
.select(min(tenant_shards::shard_count))
.filter(tenant_shards::tenant_id.eq(tenant_id))
.first::<Option<i32>>(conn)
.await
.optional()?
{
Some(Some(count)) => count,
Some(None) => {
// Tenant doesn't exist. It's possible that it was deleted before we got the request.
return Ok(vec![]);
}
None => {
// This is never supposed to happen because `SELECT min()` should always return one row.
return Err(DatabaseError::Logical(format!(
"Unexpected empty query result for min(shard_count) query. Tenant ID {tenant_id}"
)));
}
};
let shards: Vec<NodeConnectionInfo> = nodes::table
.inner_join(
tenant_shards::table.on(nodes::node_id
.nullable()
.eq(tenant_shards::generation_pageserver)),
)
.select((nodes::node_id, nodes::listen_pg_addr, nodes::listen_pg_port))
.filter(tenant_shards::tenant_id.eq(&tenant_id.to_string()))
.order(tenant_shards::shard_number.asc())
.filter(tenant_shards::shard_count.eq(min_shard_count))
.load::<(i64, String, i32)>(conn)
.await?
.into_iter()
.map(|(node_id, addr, port)| {
NodeConnectionInfo::new(
NodeType::Pageserver,
NodeId(node_id as u64),
addr,
port as u16,
)
})
.collect();
if !shards.is_empty() && !shards.len().is_power_of_two() {
return Err(DatabaseError::Logical(format!(
"Tenant {} has unexpected shard count {} (not a power of 2)",
tenant_id,
shards.len()
)));
}
Ok(shards)
}

View File

@@ -0,0 +1,34 @@
use utils::id::NodeId;
use crate::hadron_dns::NodeType;
/// Internal representation of how a compute node should connect to a PS or SK node. HCC uses this struct to
/// construct connection strings that are passed to the compute node via the compute spec. This struct is never
/// serialized or sent over the wire.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct NodeConnectionInfo {
// Type of the node.
node_type: NodeType,
// Node ID. Unique for each node type.
pub(crate) node_id: NodeId,
// The hostname reported by the node when it registers. This is the hostname we store in the meta PG, and is
// typically the k8s cluster DNS name of the node. Note that this may not be resolvable from compute nodes running
// on dblet. For this reason, this hostname is usually not communicated to the compute node. Instead, HCC computes
// a DNS name of the node in the Cloud DNS hosted zone based on `node_type` and `node_id` and advertise the DNS name
// to compute nodes. This hostname here is used as a fallback in tests or other scenarios where we do not have the
// Cloud DNS hosted zone available.
registration_hostname: String,
// The PG wire protocol port on the PS or SK node.
port: u16,
}
impl NodeConnectionInfo {
pub(crate) fn new(node_type: NodeType, node_id: NodeId, hostname: String, port: u16) -> Self {
NodeConnectionInfo {
node_type,
node_id,
registration_hostname: hostname,
port,
}
}
}

Some files were not shown because too many files have changed in this diff Show More