Compare commits

..

64 Commits

Author SHA1 Message Date
Heikki Linnakangas
c18b291244 Refactor cancellation Session to be more flexible.
Instead of with_session that calls a Future with the session, have a
more conventional constructor function, `new_session`, which returns a
Session. The session is automatically removed from the cancellation
map in Drop. This makes it nicer to use.
2023-03-13 11:27:03 +02:00
Joonas Koivunen
8699342249 Ondemand rx bytes and layer count (#3777)
Adds two new *global* metrics:
- pageserver_remote_ondemand_downloaded_layers_total
- pageserver_remote_ondemand_downloaded_bytes_total

An existing test is repurposed once more to check that we do get some
reasonable counts. These are to replace guessing from the nic RX bytes
metric how much was on-demand downloaded.

First part of #3745: This does not add the "(un)?avoidable" metric,
which I plan to add as a new metric, which will be a subset of the
counts of the metrics added here.
2023-03-13 09:26:49 +02:00
Joonas Koivunen
ce8fbbd910 Fix allowed error again (#3790)
Fixes #3360 again, this time checking all other "Error processing HTTP
request" messages and aligning the regex with the two others.
2023-03-10 17:44:12 +00:00
Vadim Kharitonov
1401021b21 Be able to get number of CPUs (#3774)
After enabling autoscaling, we faced the issue that customers are not
able to get the number of CPUs they use at this moment. Therefore I've
added these two options:

1. Postgresql function to allow customers to call it whenever they want
2. `compute_ctl` endpoint to show these number in console
2023-03-10 19:00:20 +02:00
Stas Kelvich
252b3685a2 Use unsafe-postgres feature to build pgx extension
Recently added `unsafe-postgres` feature allows to build pgx extensions
against postgres forks that decided to change their ABI name (like us).
With that we can build extensions without forking them and using stock
pgx. As this feature is new few manual version bumps were required.
2023-03-10 17:40:45 +02:00
Heikki Linnakangas
34d3385b2e Add unit tests for JWT encoding and decoding. 2023-03-10 16:09:32 +02:00
Heikki Linnakangas
b00530df2a Add section in internal docs on the JWT payload.
Just copied from the code comments. Could be improved, but this is a
start.
2023-03-10 16:09:32 +02:00
Heikki Linnakangas
bebf76c461 Accept RS384 and RS512 JWT tokens.
Previously, we only accepted RS256. Seems like a pointless limitation,
and when I was testing it with RS512 tokens, it took me a while to
understand why it wasn't working.
2023-03-10 16:09:32 +02:00
Vadim Kharitonov
2ceef91da1 Compile pg_tiktoken extension 2023-03-10 14:59:26 +01:00
Anastasia Lubennikova
b7fddfa70d Add branch_id field to proxy_io_bytes_per_client metric.
Since we allow switching endpoints between different branches, it is important to use composite key.
Otherwise, we may try to calculate delta between metric values for two different branches.
2023-03-10 15:00:48 +02:00
Heikki Linnakangas
d1537a49fa Fix escaping in postgresql.conf that we generate at compute startup
If there are any config options that contain single quotes or backslashes,
they need to be escaped
2023-03-10 14:59:21 +02:00
Heikki Linnakangas
856d01ff68 Add newline at end of postgresql.conf 2023-03-10 14:59:21 +02:00
Heikki Linnakangas
42ec79fb0d Make expected test output nicer to read.
By using Rust raw string literal.
2023-03-10 14:59:21 +02:00
Rory de Zoete
3c4f5af1b9 Try depot.dev for image building (#3768)
To see if it is faster. Run side-by-side for a while so we can gather
enough data.
2023-03-10 11:11:39 +01:00
Arseny Sher
290884ea3b Fix too many arguments in read_network clippy complain. 2023-03-10 10:50:03 +03:00
Arseny Sher
965837df53 Log connection ids in safekeeper instead of thread ids.
Fixes build on macOS (which doesn't have nix gettid) after 0d8ced8534.
2023-03-10 10:50:03 +03:00
Arseny Sher
d1a0f2f0eb Fix example why manual-range-contains is disabled. 2023-03-10 00:03:17 +03:00
Konstantin Knizhnik
a34e78d084 Retry attempt to connect to pageserver in order to make pageserver restart transparent for clients (#3700)
…start transparent for clients

## Describe your changes

Try to reestablish connection with pageserver if send is failed to be
able to make pageserver restart transparent for client

## Issue ticket number and link

https://github.com/neondatabase/neon/issues/1138

## Checklist before requesting a review
- [ ] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

---------

Co-authored-by: Heikki Linnakangas <heikki@neon.tech>
2023-03-09 22:15:46 +02:00
Arseny Sher
b80fe41af3 Refactor postgres protocol parsing.
1) Remove allocation and data copy during each message read. Instead, parsing
functions now accept BytesMut from which data they form messages, with
pointers (e.g. in CopyData) pointing directly into BytesMut buffer. Accordingly,
move ConnectionError containing IO error subtype into framed.rs providing this
and leave in pq_proto only ProtocolError.

2) Remove anyhow from pq_proto.

3) Move FeStartupPacket out of FeMessage. Now FeStartupPacket::parse returns it
directly, eliminating dead code where user wants startup packet but has to match
for others.

proxy stream.rs is adapted to framed.rs with minimal changes. It also benefits
from framed.rs improvements described above.
2023-03-09 20:45:56 +03:00
Arseny Sher
0d8ced8534 Remove sync postgres_backend, tidy up its split usage.
- Add support for splitting async postgres_backend into read and write halfes.
  Safekeeper needs this for bidirectional streams. To this end, encapsulate
  reading-writing postgres messages to framed.rs with split support without any
  additional changes (relying on BufRead for reading and BytesMut out buffer for
  writing).
- Use async postgres_backend throughout safekeeper (and in proxy auth link
  part).
- In both safekeeper COPY streams, do read-write from the same thread/task with
  select! for easier error handling.
- Tidy up finishing CopyBoth streams in safekeeper sending and receiving WAL
  -- join split parts back catching errors from them before returning.

Initially I hoped to do that read-write without split at all, through polling
IO:
https://github.com/neondatabase/neon/pull/3522
However that turned out to be more complicated than I initially expected
due to 1) borrow checking and 2) anon Future types. 1) required Rc<Refcell<...>>
which is Send construct just to satisfy the checker; 2) can be workaround with
transmute. But this is so messy that I decided to leave split.
2023-03-09 20:45:56 +03:00
Arseny Sher
7627d85345 Move async postgres_backend to its own crate.
To untie cyclic dependency between sync and async versions of postgres_backend,
copy QueryError and some logging/error routines to postgres_backend.rs. This is
temporal glue to make commits smaller, sync version will be dropped by the
upcoming commit completely.
2023-03-09 20:45:56 +03:00
Arseny Sher
3f11a647c0 Rename write_message to write_message_noflush in postgres_backend_async.rs
To make it unifrom across the project; proxy stream.rs and older
postgres_backend uses write_message_noflush.
2023-03-09 20:45:56 +03:00
Alexey Kondratov
e43c413a3f [compute_tools] Add /insights endpoint to compute_ctl (#3704)
This commit adds a basic HTTP API endpoint that allows scraping the
`pg_stat_statements` data and getting a list of slow queries. New
insights like cache hit rate and so on could be added later.

Extension `pg_stat_statements` is checked / created only if compute
tries to load the corresponding shared library. The latter is configured
by control-plane and currently covered with feature flag.

Co-authored by Eduard Dyckman (bird.duskpoet@gmail.com)
2023-03-09 14:21:10 +01:00
Heikki Linnakangas
8459e0265e Add performance test for compaction and image layer creation 2023-03-09 14:30:12 +02:00
Kirill Bulatov
03a2ce9d13 Add tracing spans with request_id into pageserver management API handlers (#3755)
Adds a newtype that creates a span with request_id from
https://github.com/neondatabase/neon/pull/3708 for every HTTP request
served.

Moves request logging and error handlers under the new wrapper, so every request-related event now is logged under the request span.
For compatibility reasons, error handler is left on the general router, since not every service uses the new handler wrappers yet.
2023-03-09 09:24:01 +02:00
Heikki Linnakangas
ccf92df4da Remove deprecated support to handle ZENITH_AUTH_TOKEN.
It's not used anywhere anymore.
2023-03-09 00:53:13 +02:00
Vadim Kharitonov
37bc6d9be4 Compile plpgsql_check extension 2023-03-08 23:20:24 +01:00
Vadim Kharitonov
177f986795 Compile hll extension 2023-03-08 12:03:09 +01:00
Heikki Linnakangas
fb1581d0b9 Fix setting "image_creation_threshold" setting in tenant config. (#3762)
We have a few tests that try to set image_creation_threshold, but it
didn't actually have any effect because we were missing some critical
code to load the setting from config file into memory.

The two modified tests in `test_remote_storage.py perform
compaction and GC, and assert that GC removes some layers. That
only happens if new image layers are created by the
compaction. The tests explicitly disabled image layer creation by
setting image_creation_threshold to a high value, but it didn't
take effect because reading image_creation_threshold from config
file was broken, which is why the test worked. Fix the test to
set image_creation_threshold low, instead, so that GC has work to
do.

Change 'test_tenant_conf.py' so that it exercises the added code.

This might explain why we're apparently missing test coverage for GC
(issue #3415), although I didn't try to address that here, nor did I
check if this improves the it.
2023-03-08 11:39:30 +02:00
Sasha Krassovsky
02b8e0e5af Add OpenAPI spec for do_gc (#3756)
## Describe your changes
Adds a field to the OpenAPI spec for the page server which describes the
`do_gc` command.
## Issue ticket number and link
#3669
## Checklist before requesting a review
- [x] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.
2023-03-07 09:08:46 -08:00
Vadim Kharitonov
1b16de0d0f Compile prefix extension 2023-03-07 17:42:53 +01:00
Stas Kelvich
069b5b0a06 Make postgres --wal-redo more embeddable.
* Stop allocating and maintaining 128MB hash table for last written

  LSN cache as it is not needed in wal-redo.

* Do not require access to the initialized data directory. That
  saves few dozens megabytes of empty but initialized data directory.
  Currently such directories do occupy about 10% of the disk space
  on the pageservers as most of tenants are empty.

* Move shmem-initialization code to the extension instead of postgres
2023-03-07 15:01:14 +02:00
Joonas Koivunen
b05e94e4ff fix: allow ERROR log to appear per allowed failure (#3696)
The test already allows the background thread trying to checkpoint to
fail, however the resulting log message is currently not allowed thus
causing flakyness.
2023-03-07 12:44:04 +00:00
Arseny Sher
0acf9ace9a Return 404 if timeline is not found in safekeeper HTTP API. 2023-03-07 16:34:20 +04:00
Arseny Sher
ca85646df4 Max peer_horizon_lsn before adopting it.
Before this patch, persistent peer_horizon_lsn was always sent to walproposer,
making it initially calculate it equal to max of persistent values and in turn
pulling back the in memory value. Send instead in memory value and take max when
safekeeper sets it.

closes https://github.com/neondatabase/neon/issues/3752
2023-03-07 10:16:54 +04:00
Shany Pozin
7b9057ad01 Add timeout to download copy (#3675)
## Describe your changes
Adding a timeout handling for the remote download of layers of 120
seconds for each operation
Note that these downloads are being retried for N times 

## Issue ticket number and link
Fixes: #3672

## Checklist before requesting a review
- [x] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

---------

Co-authored-by: Joonas Koivunen <joonas@neon.tech>
2023-03-06 18:52:59 +02:00
Konstantin Knizhnik
96f65fad68 Handle crash of walredo process and retry applying wal records (#3739)
## Describe your changes

Restart walredo process an d retry applying walredo records i case of
abnormal walredo process termination

## Issue ticket number and link

See #1700


## Checklist before requesting a review
- [ ] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.
2023-03-06 10:10:58 +02:00
Christian Schwarz
9cada8b59d fix benchmarks, broken by PR #3737
Benchmarks only run on `main` branch, so, the pre-commit tests didn't
catch these.
2023-03-03 18:47:57 +01:00
Christian Schwarz
66a5159511 fix: compaction: no index upload scheduled if no on-demand downloads
Commit

    0cf7fd0fb8
    Compaction with on-demand download (#3598)

introduced a subtle bug: if we don't have to do on-demand downloads,
we only take one ROUND in fn compact() and exit early.
Thereby, we miss scheduling the index part upload for any layers
created by fn compact_inner().

Before that commit, we didn't have this problem.
So, this patch fixes it.

Since no regression test caught this, I went ahead and extended the
timeline size tests to assert that, if remote storage is configured,
1. pageserver_remote_physical_size matches the other physical sizes
2. file sizes reported by the layer map info endpoint match the other
   physical size metrics

Without the pageserver code fix, the regression test would
fail at the physical size assertion, complaining that
any of the resident physical size != remote physical size metric
50790400.0 != 18399232.0
I figured out what the problem is by comparing the remote storage
and local directories like so, and noticed that the image layer
in the local directory wasn't present on the remote side.
It's size was exactly the difference
    50790400.0 - 18399232.0  =32391168.0

fixes https://github.com/neondatabase/neon/issues/3738
2023-03-03 16:11:54 +01:00
Christian Schwarz
d1a0a907ff tests: use parse_metrics everywhere (#3737)
- use parse_metrics() in all places where we parse Prometheus metrics
- query_all: make `filter` argument optional
- encourage using properly parsed, typed metrics by changing get_metrics()
  to return already-parsed metrics. The new get_metric_str() method,
  like in the Safekeeper type, returns the raw text response.
2023-03-03 14:53:27 +01:00
Christian Schwarz
1b780fa752 timeline_checkpoint_handler: add span with tenant and timeline id
Before this patch, the logs written by freeze_and_flush() and compact()
didn't have any span, which made the test logs annoying to read.
2023-03-03 12:10:24 +01:00
Christian Schwarz
38022ff11c gc: only decrement resident size if GC'd layer is resident
Before this patch, GC would call PersistentLayer::delete()
on every GC'ed layer.
RemoteLayer::delete() returned Ok(()) unconditionally.
GC would then proceed by decrementing the resident size metric,
even though the layer is a RemoteLayer.

This patch makes the following changes:
- Rename PersistentLayer::delete() to delete_resident_layer_file().
  That name is unambiguous.
- Make RemoteLayer::delete_resident_layer_file return an Err().
  We would have uncovered this bug if we had done that from the start.
- Change GC / Timeline::delete_historic_layer check whether
  the layer is remote or not, and only call delete_resident_layer_file()
  if it's not remote. This brings us in line with how eviction does it.
- Add a regression test.

fixes https://github.com/neondatabase/neon/issues/3722
2023-03-03 12:10:24 +01:00
Christian Schwarz
1b9b9d60d4 eviction: add comment explaining resident size decrement on error
https://github.com/neondatabase/neon/issues/3722
2023-03-03 12:10:24 +01:00
Christian Schwarz
68141a924d eviction: remove needless if-let around resident size decrement
The branch was always taken at runtime, so, this should not
constitute a behavioral change.

refs https://github.com/neondatabase/neon/issues/3722
2023-03-03 12:10:24 +01:00
Christian Schwarz
764d27f696 fix checkpoint_timeout serialization in TenantConf
Without this change, when actually setting this conf opt, the tenant
would become Broken next time we load it.
Why?
The serde_toml representation that persist_tenant_conf would write out
would be a TOML inline table of `secs` and `nsecs`.
But our hand-rolled TenantConf parser expects a TOML string.

I checked that all other `Duration` values in TenantConfOpt
use the humantime serialization.

Issues like this would likely be systematically prevent by
https://github.com/neondatabase/neon/issues/3682
2023-03-03 12:10:24 +01:00
Arthur Petukhovsky
b23742e09c Create /v1/debug_dump safekeepers endpoint (#3710)
Add HTTP endpoint to get full safekeeper state of all existing timelines
(all in-memory values and info about all files stored on disk).

Example:
https://gist.github.com/petuhovskiy/3cbb8f870401e9f486731d145161c286
2023-03-03 14:01:05 +03:00
Vadim Kharitonov
5e514b8465 Compile pgTAP extension 2023-03-03 10:14:38 +01:00
Vadim Kharitonov
a60f687ce2 Compile rum extension 2023-03-02 12:28:20 +01:00
sharnoff
8dae879994 Disable VM cgroup shenanigans (#3730)
As discussed - temporary, so it can unblock releasing autoscaling.
Cleaner to fully remove, then add back rather than commenting it out.
2023-03-01 23:58:43 -08:00
Shany Pozin
d19c5248c9 Add UUID header to mgmt API (#3708)
## Describe your changes

## Issue ticket number and link
#3479
## Checklist before requesting a review
- [x] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.
2023-03-01 18:09:08 +02:00
sharnoff
1360361f60 Fix missing VM cgconfig.conf (#3718)
It was being added to the wrong stage in the dockerfile. This should fix
it, and resolves an ongoing issue on staging.
2023-02-28 21:11:00 -08:00
Alexander Bayandin
000eb1b069 Bump tempfile from 3.3.0 to 3.4.0 (#3709)
Update `tempfile` crate to get rid of `remove_dir_all` dependency
Ref https://github.com/neondatabase/neon/security/dependabot/15
2023-02-27 12:44:08 +00:00
Heikki Linnakangas
f51b48fa49 Fix UNLOGGED tables.
Instead of trying to create missing files on the way, send init fork contents as
main fork from pageserver during basebackup. Add test for that. Call
put_rel_drop for init forks; previously they weren't removed. Bump
vendor/postgres to revert previous approach on Postgres side.

Co-authored-by: Arseny Sher <sher-ars@yandex.ru>

ref https://github.com/neondatabase/postgres/pull/264
ref https://github.com/neondatabase/postgres/pull/259
ref https://github.com/neondatabase/neon/issues/1222
2023-02-24 23:30:02 +04:00
Sergey Melnikov
9f906ff236 Add pageserver-2.us-east-2.aws.neon.tech (#3701) 2023-02-23 19:56:21 +01:00
Sam Kleinman
c79dd8d458 compute_ctl: support for fetching spec from control plane (#3610) 2023-02-23 13:19:39 -05:00
Vadim Kharitonov
ec4ecdd543 Enable postgres SPI extensions 2023-02-23 16:49:37 +01:00
MMeent
20a4d817ce Update vendored PostgreSQL versions to 14.7 and 15.2 (#3581)
## Describe your changes
Rebase vendored PostgreSQL onto 14.7 and 15.2

## Issue ticket number and link

#3579

## Checklist before requesting a review
- [x] I have performed a self-review of my code.
- [x] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [x] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.
    ```
The version of PostgreSQL that we use is updated to 14.7 for PostgreSQL
14 and 15.2 for PostgreSQL 15.
    ```
2023-02-23 16:10:22 +02:00
Vadim Kharitonov
5ebf7e5619 Fix pg_jsonschema and pg_graphql 2023-02-23 10:43:46 +01:00
Arseny Sher
0692fffbf3 Bump vendor/postgres to include hotfix for unlogged tables with indexes.
https://github.com/neondatabase/postgres/pull/259
https://github.com/neondatabase/postgres/pull/262
2023-02-23 01:34:59 +04:00
Vadim Kharitonov
093570af20 Compile pg_hashids extension 2023-02-22 21:00:25 +01:00
Dmitry Rodionov
eb403da814 Use debug level for successful GET http requests (#3681)
We started rather frequently scrap some apis for metadata. This includes
layer eviction tester, I believe console does that too.

It should eliminate these logs:
https://neonprod.grafana.net/goto/rr_ace1Vz?orgId=1 (Note the rate
around 2k messages per minute)
2023-02-22 22:19:05 +03:00
Vadim Kharitonov
f3ad635911 Compile pgrouting extension 2023-02-22 20:16:11 +01:00
Vadim Kharitonov
a8d7360881 Compile hypopg extension 2023-02-22 20:14:30 +01:00
Lassi Pölönen
b0311cfdeb Change the production neon-proxy-scram update strategy to RollingUpdate (#3683)
## Describe your changes
The same change in production as was done in staging by
https://github.com/neondatabase/neon/pull/3678

## Issue ticket number and link
https://github.com/neondatabase/neon/issues/3333
2023-02-22 20:15:37 +02:00
122 changed files with 5145 additions and 3435 deletions

View File

@@ -27,6 +27,8 @@ storage:
ansible_host: i-062227ba7f119eb8c
pageserver-1.us-east-2.aws.neon.tech:
ansible_host: i-0b3ec0afab5968938
pageserver-2.us-east-2.aws.neon.tech:
ansible_host: i-0d7a1c4325e71421d
safekeepers:
hosts:

View File

@@ -1,6 +1,22 @@
# Helm chart values for neon-proxy-scram.
# This is a YAML-formatted file.
deploymentStrategy:
type: RollingUpdate
rollingUpdate:
maxSurge: 100%
maxUnavailable: 50%
# Delay the kill signal by 7 days (7 * 24 * 60 * 60)
# The pod(s) will stay in Terminating, keeps the existing connections
# but doesn't receive new ones
containerLifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 604800"]
terminationGracePeriodSeconds: 604800
image:
repository: neondatabase/neon

View File

@@ -1,6 +1,22 @@
# Helm chart values for neon-proxy-scram.
# This is a YAML-formatted file.
deploymentStrategy:
type: RollingUpdate
rollingUpdate:
maxSurge: 100%
maxUnavailable: 50%
# Delay the kill signal by 7 days (7 * 24 * 60 * 60)
# The pod(s) will stay in Terminating, keeps the existing connections
# but doesn't receive new ones
containerLifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 604800"]
terminationGracePeriodSeconds: 604800
image:
repository: neondatabase/neon

View File

@@ -1,6 +1,22 @@
# Helm chart values for neon-proxy-scram.
# This is a YAML-formatted file.
deploymentStrategy:
type: RollingUpdate
rollingUpdate:
maxSurge: 100%
maxUnavailable: 50%
# Delay the kill signal by 7 days (7 * 24 * 60 * 60)
# The pod(s) will stay in Terminating, keeps the existing connections
# but doesn't receive new ones
containerLifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 604800"]
terminationGracePeriodSeconds: 604800
image:
repository: neondatabase/neon

View File

@@ -1,6 +1,22 @@
# Helm chart values for neon-proxy-scram.
# This is a YAML-formatted file.
deploymentStrategy:
type: RollingUpdate
rollingUpdate:
maxSurge: 100%
maxUnavailable: 50%
# Delay the kill signal by 7 days (7 * 24 * 60 * 60)
# The pod(s) will stay in Terminating, keeps the existing connections
# but doesn't receive new ones
containerLifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 604800"]
terminationGracePeriodSeconds: 604800
image:
repository: neondatabase/neon

View File

@@ -551,6 +551,48 @@ jobs:
- name: Cleanup ECR folder
run: rm -rf ~/.ecr
neon-image-depot:
# For testing this will run side-by-side for a few merges.
# This action is not really optimized yet, but gets the job done
runs-on: [ self-hosted, gen3, small ]
needs: [ tag ]
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
permissions:
contents: read
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 0
- name: Setup go
uses: actions/setup-go@v3
with:
go-version: '1.19'
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Install Crane & ECR helper
run: go install github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cli/docker-credential-ecr-login@69c85dc22db6511932bbf119e1a0cc5c90c69a7f # v0.6.0
- name: Configure ECR login
run: |
mkdir /github/home/.docker/
echo "{\"credsStore\":\"ecr-login\"}" > /github/home/.docker/config.json
- name: Build and push
uses: depot/build-push-action@v1
with:
# if no depot.json file is at the root of your repo, you must specify the project id
project: nrdv0s4kcs
push: true
tags: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:depot-${{needs.tag.outputs.build-tag}}
compute-tools-image:
runs-on: [ self-hosted, gen3, large ]
needs: [ tag ]

72
Cargo.lock generated
View File

@@ -851,9 +851,11 @@ dependencies = [
"futures",
"hyper",
"notify",
"num_cpus",
"opentelemetry",
"postgres",
"regex",
"reqwest",
"serde",
"serde_json",
"tar",
@@ -912,6 +914,7 @@ dependencies = [
"once_cell",
"pageserver_api",
"postgres",
"postgres_backend",
"postgres_connection",
"regex",
"reqwest",
@@ -2453,6 +2456,7 @@ dependencies = [
"postgres",
"postgres-protocol",
"postgres-types",
"postgres_backend",
"postgres_connection",
"postgres_ffi",
"pq_proto",
@@ -2675,6 +2679,28 @@ dependencies = [
"postgres-protocol",
]
[[package]]
name = "postgres_backend"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bytes",
"futures",
"once_cell",
"pq_proto",
"rustls",
"rustls-pemfile",
"serde",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"tracing",
"workspace_hack",
]
[[package]]
name = "postgres_connection"
version = "0.1.0"
@@ -2722,7 +2748,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
name = "pq_proto"
version = "0.1.0"
dependencies = [
"anyhow",
"byteorder",
"bytes",
"pin-project-lite",
"postgres-protocol",
@@ -2897,6 +2923,7 @@ dependencies = [
"opentelemetry",
"parking_lot",
"pin-project-lite",
"postgres_backend",
"pq_proto",
"prometheus",
"rand",
@@ -3066,15 +3093,6 @@ dependencies = [
"workspace_hack",
]
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi",
]
[[package]]
name = "reqwest"
version = "0.11.14"
@@ -3285,15 +3303,6 @@ dependencies = [
"base64 0.21.0",
]
[[package]]
name = "rustls-split"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3"
dependencies = [
"rustls",
]
[[package]]
name = "rustversion"
version = "1.0.11"
@@ -3315,6 +3324,7 @@ dependencies = [
"async-trait",
"byteorder",
"bytes",
"chrono",
"clap 4.1.4",
"const_format",
"crc32c",
@@ -3324,11 +3334,11 @@ dependencies = [
"humantime",
"hyper",
"metrics",
"nix",
"once_cell",
"parking_lot",
"postgres",
"postgres-protocol",
"postgres_backend",
"postgres_ffi",
"pq_proto",
"regex",
@@ -3848,16 +3858,15 @@ dependencies = [
[[package]]
name = "tempfile"
version = "3.3.0"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4"
checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95"
dependencies = [
"cfg-if",
"fastrand",
"libc",
"redox_syscall",
"remove_dir_all",
"winapi",
"rustix",
"windows-sys 0.42.0",
]
[[package]]
@@ -4514,7 +4523,7 @@ dependencies = [
"byteorder",
"bytes",
"criterion",
"git-version",
"futures",
"heapless",
"hex",
"hex-literal",
@@ -4523,12 +4532,8 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"pq_proto",
"rand",
"routerify",
"rustls",
"rustls-pemfile",
"rustls-split",
"sentry",
"serde",
"serde_json",
@@ -4539,10 +4544,10 @@ dependencies = [
"tempfile",
"thiserror",
"tokio",
"tokio-rustls",
"tracing",
"tracing-subscriber",
"url",
"uuid",
"workspace_hack",
]
@@ -4842,15 +4847,19 @@ name = "workspace_hack"
version = "0.1.0"
dependencies = [
"anyhow",
"byteorder",
"bytes",
"chrono",
"clap 4.1.4",
"crossbeam-utils",
"digest",
"either",
"fail",
"futures",
"futures-channel",
"futures-core",
"futures-executor",
"futures-sink",
"futures-util",
"hashbrown 0.12.3",
"indexmap",
@@ -4875,6 +4884,7 @@ dependencies = [
"socket2",
"syn",
"tokio",
"tokio-rustls",
"tokio-util",
"tonic",
"tower",

View File

@@ -64,6 +64,7 @@ md5 = "0.7.0"
memoffset = "0.8"
nix = "0.26"
notify = "5.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.18.0"
@@ -133,6 +134,7 @@ heapless = { default-features=false, features=[], git = "https://github.com/japa
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
@@ -150,7 +152,7 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" }
criterion = "0.4"
rcgen = "0.10"
rstest = "0.16"
tempfile = "3.2"
tempfile = "3.4"
tonic-build = "0.8"
# This is only needed for proxy's tests.

View File

@@ -39,7 +39,7 @@ ARG CACHEPOT_BUCKET=neon-github-dev
COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server
COPY . .
COPY --chown=nonroot . .
# Show build caching stats to check if it was used in the end.
# Has to be the part of the same RUN since cachepot daemon is killed in the end of this RUN, losing the compilation stats.

View File

@@ -32,11 +32,15 @@ RUN cd postgres && \
make MAKELEVEL=0 -j $(getconf _NPROCESSORS_ONLN) -s -C src/include install && \
make MAKELEVEL=0 -j $(getconf _NPROCESSORS_ONLN) -s -C src/interfaces/libpq install && \
# Enable some of contrib extensions
echo 'trusted = true' >> /usr/local/pgsql/share/extension/autoinc.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/bloom.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgrowlocks.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/intagg.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgstattuple.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/earthdistance.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/insert_username.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/intagg.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/moddatetime.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgrowlocks.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgstattuple.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/refint.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control
#########################################################################################
@@ -60,10 +64,11 @@ RUN wget https://gitlab.com/Oslandia/SFCGAL/-/archive/v1.3.10/SFCGAL-v1.3.10.tar
DESTDIR=/sfcgal make install -j $(getconf _NPROCESSORS_ONLN) && \
make clean && cp -R /sfcgal/* /
ENV PATH "/usr/local/pgsql/bin:$PATH"
RUN wget https://download.osgeo.org/postgis/source/postgis-3.3.2.tar.gz -O postgis.tar.gz && \
mkdir postgis-src && cd postgis-src && tar xvzf ../postgis.tar.gz --strip-components=1 -C . && \
./autogen.sh && \
export PATH="/usr/local/pgsql/bin:$PATH" && \
./configure --with-sfcgal=/usr/local/bin/sfcgal-config && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
cd extensions/postgis && \
@@ -77,6 +82,15 @@ RUN wget https://download.osgeo.org/postgis/source/postgis-3.3.2.tar.gz -O postg
echo 'trusted = true' >> /usr/local/pgsql/share/extension/address_standardizer.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/address_standardizer_data_us.control
RUN wget https://github.com/pgRouting/pgrouting/archive/v3.4.2.tar.gz -O pgrouting.tar.gz && \
mkdir pgrouting-src && cd pgrouting-src && tar xvzf ../pgrouting.tar.gz --strip-components=1 -C . && \
mkdir build && \
cd build && \
cmake .. && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgrouting.control
#########################################################################################
#
# Layer "plv8-build"
@@ -181,6 +195,111 @@ RUN wget https://github.com/michelp/pgjwt/archive/9742dab1b2f297ad3811120db7b214
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgjwt.control
#########################################################################################
#
# Layer "hypopg-pg-build"
# compile hypopg extension
#
#########################################################################################
FROM build-deps AS hypopg-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/HypoPG/hypopg/archive/refs/tags/1.3.1.tar.gz -O hypopg.tar.gz && \
mkdir hypopg-src && cd hypopg-src && tar xvzf ../hypopg.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/hypopg.control
#########################################################################################
#
# Layer "pg-hashids-pg-build"
# compile pg_hashids extension
#
#########################################################################################
FROM build-deps AS pg-hashids-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \
mkdir pg_hashids-src && cd pg_hashids-src && tar xvzf ../pg_hashids.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pg_hashids.control
#########################################################################################
#
# Layer "rum-pg-build"
# compile rum extension
#
#########################################################################################
FROM build-deps AS rum-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/postgrespro/rum/archive/refs/tags/1.3.13.tar.gz -O rum.tar.gz && \
mkdir rum-src && cd rum-src && tar xvzf ../rum.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/rum.control
#########################################################################################
#
# Layer "pgtap-pg-build"
# compile pgTAP extension
#
#########################################################################################
FROM build-deps AS pgtap-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/theory/pgtap/archive/refs/tags/v1.2.0.tar.gz -O pgtap.tar.gz && \
mkdir pgtap-src && cd pgtap-src && tar xvzf ../pgtap.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgtap.control
#########################################################################################
#
# Layer "prefix-pg-build"
# compile Prefix extension
#
#########################################################################################
FROM build-deps AS prefix-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.9.tar.gz -O prefix.tar.gz && \
mkdir prefix-src && cd prefix-src && tar xvzf ../prefix.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/prefix.control
#########################################################################################
#
# Layer "hll-pg-build"
# compile hll extension
#
#########################################################################################
FROM build-deps AS hll-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.17.tar.gz -O hll.tar.gz && \
mkdir hll-src && cd hll-src && tar xvzf ../hll.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/hll.control
#########################################################################################
#
# Layer "plpgsql-check-pg-build"
# compile plpgsql_check extension
#
#########################################################################################
FROM build-deps AS plpgsql-check-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.3.2.tar.gz -O plpgsql_check.tar.gz && \
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plpgsql_check.control
#########################################################################################
#
# Layer "rust extensions"
@@ -204,7 +323,7 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
chmod +x rustup-init && \
./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \
rm rustup-init && \
cargo install --git https://github.com/vadim2404/pgx --branch neon_abi_v0.6.1 --locked cargo-pgx && \
cargo install --locked --version 0.7.3 cargo-pgx && \
/bin/bash -c 'cargo pgx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
@@ -218,8 +337,10 @@ USER root
FROM rust-extensions-build AS pg-jsonschema-pg-build
RUN git clone --depth=1 --single-branch --branch neon_abi_v0.1.4 https://github.com/vadim2404/pg_jsonschema/ && \
cd pg_jsonschema && \
# there is no release tag yet, but we need it due to the superuser fix in the control file
RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421ec66466a3abbb37b7ee6.tar.gz -O pg_jsonschema.tar.gz && \
mkdir pg_jsonschema-src && cd pg_jsonschema-src && tar xvzf ../pg_jsonschema.tar.gz --strip-components=1 -C . && \
sed -i 's/pgx = "0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
cargo pgx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_jsonschema.control
@@ -232,11 +353,32 @@ RUN git clone --depth=1 --single-branch --branch neon_abi_v0.1.4 https://github.
FROM rust-extensions-build AS pg-graphql-pg-build
RUN git clone --depth=1 --single-branch --branch neon_abi_v1.1.0 https://github.com/vadim2404/pg_graphql && \
cd pg_graphql && \
# Currently pgx version bump to >= 0.7.2 causes "call to unsafe function" compliation errors in
# pgx-contrib-spiext. There is a branch that removes that dependency, so use it. It is on the
# same 1.1 version we've used before.
RUN git clone -b remove-pgx-contrib-spiext --single-branch https://github.com/yrashk/pg_graphql && \
cd pg_graphql && \
sed -i 's/pgx = "~0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
sed -i 's/pgx-tests = "~0.7.1"/pgx-tests = "0.7.3"/g' Cargo.toml && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_graphql.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_graphql.control
#########################################################################################
#
# Layer "pg-tiktoken-build"
# Compile "pg_tiktoken" extension
#
#########################################################################################
FROM rust-extensions-build AS pg-tiktoken-pg-build
RUN git clone --depth=1 --single-branch https://github.com/kelvich/pg_tiktoken && \
cd pg_tiktoken && \
cargo pgx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_tiktoken.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -254,11 +396,23 @@ COPY --from=vector-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgjwt-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-jsonschema-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-graphql-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-tiktoken-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hypopg-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-hashids-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=rum-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgtap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=prefix-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hll-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=plpgsql-check-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon \
-s install && \
make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon_utils \
-s install
#########################################################################################
@@ -313,7 +467,7 @@ COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-deb
# Install:
# libreadline8 for psql
# libicu67, locales for collations (including ICU)
# libicu67, locales for collations (including ICU and plpgsql_check)
# libossp-uuid16 for extension ossp-uuid
# libgeos, libgdal, libsfcgal1, libproj and libprotobuf-c1 for PostGIS
# libxml2, libxslt1.1 for xml2

View File

@@ -10,23 +10,16 @@ RUN set -e \
&& rm -f /etc/inittab \
&& touch /etc/inittab
ADD vm-cgconfig.conf /etc/cgconfig.conf
RUN set -e \
&& echo "::sysinit:cgconfigparser -l /etc/cgconfig.conf -s 1664" >> /etc/inittab \
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart --cgroup=neon-postgres'" >> /etc/inittab
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart'" >> /etc/inittab
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base
# Temporarily set user back to root so we can run apt update and adduser
# Temporarily set user back to root so we can run adduser
USER root
RUN apt update && \
apt install --no-install-recommends -y \
cgroup-tools
RUN adduser vm-informant --disabled-password --no-create-home
USER postgres
COPY --from=informant /etc/inittab /etc/inittab
COPY --from=informant /usr/bin/vm-informant /usr/local/bin/vm-informant
ENTRYPOINT ["/usr/sbin/cgexec", "-g", "*:neon-postgres", "/usr/local/bin/compute_ctl"]

View File

@@ -133,6 +133,11 @@ neon-pg-ext-%: postgres-%
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install
+@echo "Compiling neon_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install
.PHONY: neon-pg-ext-clean-%
neon-pg-ext-clean-%:
@@ -145,6 +150,9 @@ neon-pg-ext-clean-%:
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile clean
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
.PHONY: neon-pg-ext
neon-pg-ext: \

View File

@@ -11,12 +11,14 @@ clap.workspace = true
futures.workspace = true
hyper = { workspace = true, features = ["full"] }
notify.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
postgres.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
tar.workspace = true
reqwest = { workspace = true, features = ["json"] }
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tokio-postgres.workspace = true
tracing.workspace = true

View File

@@ -65,6 +65,9 @@ fn main() -> Result<()> {
let spec = matches.get_one::<String>("spec");
let spec_path = matches.get_one::<String>("spec-path");
let compute_id = matches.get_one::<String>("compute-id");
let control_plane_uri = matches.get_one::<String>("control-plane-uri");
// Try to use just 'postgres' if no path is provided
let pgbin = matches.get_one::<String>("pgbin").unwrap();
@@ -77,8 +80,27 @@ fn main() -> Result<()> {
let path = Path::new(sp);
let file = File::open(path)?;
serde_json::from_reader(file)?
} else if let Some(id) = compute_id {
if let Some(cp_base) = control_plane_uri {
let cp_uri = format!("{cp_base}/management/api/v1/{id}/spec");
let jwt: String = match std::env::var("NEON_CONSOLE_JWT") {
Ok(v) => v,
Err(_) => "".to_string(),
};
reqwest::blocking::Client::new()
.get(cp_uri)
.header("Authorization", jwt)
.send()?
.json()?
} else {
panic!(
"must specify --control-plane-uri \"{:#?}\" and --compute-id \"{:#?}\"",
control_plane_uri, compute_id
);
}
} else {
panic!("cluster spec should be provided via --spec or --spec-path argument");
panic!("compute spec should be provided via --spec or --spec-path argument");
}
}
};
@@ -227,6 +249,18 @@ fn cli() -> clap::Command {
.long("spec-path")
.value_name("SPEC_PATH"),
)
.arg(
Arg::new("compute-id")
.short('i')
.long("compute-id")
.value_name("COMPUTE_ID"),
)
.arg(
Arg::new("control-plane-uri")
.short('p')
.long("control-plane-uri")
.value_name("CONTROL_PLANE"),
)
}
#[test]

View File

@@ -25,6 +25,7 @@ use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use postgres::{Client, NoTls};
use serde::{Serialize, Serializer};
use tokio_postgres;
use tracing::{info, instrument, warn};
use crate::checker::create_writability_check_data;
@@ -284,6 +285,7 @@ impl ComputeNode {
handle_role_deletions(self, &mut client)?;
handle_grants(self, &mut client)?;
create_writability_check_data(&mut client)?;
handle_extensions(&self.spec, &mut client)?;
// 'Close' connection
drop(client);
@@ -400,4 +402,43 @@ impl ComputeNode {
Ok(())
}
/// Select `pg_stat_statements` data and return it as a stringified JSON
pub async fn collect_insights(&self) -> String {
let mut result_rows: Vec<String> = Vec::new();
let connect_result = tokio_postgres::connect(self.connstr.as_str(), NoTls).await;
let (client, connection) = connect_result.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let result = client
.simple_query(
"SELECT
row_to_json(pg_stat_statements)
FROM
pg_stat_statements
WHERE
userid != 'cloud_admin'::regrole::oid
ORDER BY
(mean_exec_time + mean_plan_time) DESC
LIMIT 100",
)
.await;
if let Ok(raw_rows) = result {
for message in raw_rows.iter() {
if let postgres::SimpleQueryMessage::Row(row) = message {
if let Some(json) = row.get(0) {
result_rows.push(json.to_string());
}
}
}
format!("{{\"pg_stat_statements\": [{}]}}", result_rows.join(","))
} else {
"{{\"pg_stat_statements\": []}}".to_string()
}
}
}

View File

@@ -7,6 +7,7 @@ use crate::compute::ComputeNode;
use anyhow::Result;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use num_cpus;
use serde_json;
use tracing::{error, info};
use tracing_utils::http::OtelName;
@@ -33,6 +34,13 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
Response::new(Body::from(serde_json::to_string(&compute.metrics).unwrap()))
}
// Collect Postgres current usage insights
(&Method::GET, "/insights") => {
info!("serving /insights GET request");
let insights = compute.collect_insights().await;
Response::new(Body::from(insights))
}
(&Method::POST, "/check_writability") => {
info!("serving /check_writability POST request");
let res = crate::checker::check_writability(compute).await;
@@ -42,6 +50,17 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::GET, "/info") => {
let num_cpus = num_cpus::get_physical();
info!("serving /info GET request. num_cpus: {}", num_cpus);
Response::new(Body::from(
serde_json::json!({
"num_cpus": num_cpus,
})
.to_string(),
))
}
// Return the `404 Not Found` for any other routes.
_ => {
let mut not_found = Response::new(Body::from("404 Not Found"));

View File

@@ -10,12 +10,12 @@ paths:
/status:
get:
tags:
- "info"
- Info
summary: Get compute node internal status
description: ""
operationId: getComputeStatus
responses:
"200":
200:
description: ComputeState
content:
application/json:
@@ -25,27 +25,58 @@ paths:
/metrics.json:
get:
tags:
- "info"
- Info
summary: Get compute node startup metrics in JSON format
description: ""
operationId: getComputeMetricsJSON
responses:
"200":
200:
description: ComputeMetrics
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeMetrics"
/insights:
get:
tags:
- Info
summary: Get current compute insights in JSON format
description: |
Note, that this doesn't include any historical data
operationId: getComputeInsights
responses:
200:
description: Compute insights
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeInsights"
/info:
get:
tags:
- "info"
summary: Get info about the compute Pod/VM
description: ""
operationId: getInfo
responses:
"200":
description: Info
content:
application/json:
schema:
$ref: "#/components/schemas/Info"
/check_writability:
post:
tags:
- "check"
- Check
summary: Check that we can write new data on this compute
description: ""
operationId: checkComputeWritability
responses:
"200":
200:
description: Check result
content:
text/plain:
@@ -80,6 +111,15 @@ components:
total_startup_ms:
type: integer
Info:
type: object
description: Information about VM/Pod
required:
- num_cpus
properties:
num_cpus:
type: integer
ComputeState:
type: object
required:
@@ -96,6 +136,15 @@ components:
type: string
description: Text of the error during compute startup, if any
ComputeInsights:
type: object
properties:
pg_stat_statements:
description: Contains raw output from pg_stat_statements in JSON format
type: array
items:
type: object
ComputeStatus:
type: string
enum:

View File

@@ -47,12 +47,23 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
/// Escape a string for including it in a SQL literal
fn escape_literal(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
/// Escape a string so that it can be used in postgresql.conf.
/// Same as escape_literal, currently.
fn escape_conf_value(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
impl GenericOption {
/// Represent `GenericOption` as SQL statement parameter.
pub fn to_pg_option(&self) -> String {
if let Some(val) = &self.value {
match self.vartype.as_ref() {
"string" => format!("{} '{}'", self.name, val),
"string" => format!("{} '{}'", self.name, escape_literal(val)),
_ => format!("{} {}", self.name, val),
}
} else {
@@ -63,6 +74,8 @@ impl GenericOption {
/// Represent `GenericOption` as configuration option.
pub fn to_pg_setting(&self) -> String {
if let Some(val) = &self.value {
// TODO: check in the console DB that we don't have these settings
// set for any non-deleted project and drop this override.
let name = match self.name.as_str() {
"safekeepers" => "neon.safekeepers",
"wal_acceptor_reconnect" => "neon.safekeeper_reconnect_timeout",
@@ -71,7 +84,7 @@ impl GenericOption {
};
match self.vartype.as_ref() {
"string" => format!("{} = '{}'", name, val),
"string" => format!("{} = '{}'", name, escape_conf_value(val)),
_ => format!("{} = {}", name, val),
}
} else {
@@ -107,6 +120,7 @@ impl PgOptionsSerialize for GenericOptions {
.map(|op| op.to_pg_setting())
.collect::<Vec<String>>()
.join("\n")
+ "\n" // newline after last setting
} else {
"".to_string()
}

View File

@@ -515,3 +515,18 @@ pub fn handle_grants(node: &ComputeNode, client: &mut Client) -> Result<()> {
Ok(())
}
/// Create required system extensions
#[instrument(skip_all)]
pub fn handle_extensions(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
if libs.contains("pg_stat_statements") {
// Create extension only if this compute really needs it
let query = "CREATE EXTENSION IF NOT EXISTS pg_stat_statements";
info!("creating system extensions with query: {}", query);
client.simple_query(query)?;
}
}
Ok(())
}

View File

@@ -178,6 +178,11 @@
"name": "neon.pageserver_connstring",
"value": "host=127.0.0.1 port=6400",
"vartype": "string"
},
{
"name": "test.escaping",
"value": "here's a backslash \\ and a quote ' and a double-quote \" hooray",
"vartype": "string"
}
]
},

View File

@@ -28,7 +28,30 @@ mod pg_helpers_tests {
assert_eq!(
spec.cluster.settings.as_pg_settings(),
"fsync = off\nwal_level = replica\nhot_standby = on\nneon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'\nwal_log_hints = on\nlog_connections = on\nshared_buffers = 32768\nport = 55432\nmax_connections = 100\nmax_wal_senders = 10\nlisten_addresses = '0.0.0.0'\nwal_sender_timeout = 0\npassword_encryption = md5\nmaintenance_work_mem = 65536\nmax_parallel_workers = 8\nmax_worker_processes = 8\nneon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'\nmax_replication_slots = 10\nneon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'\nshared_preload_libraries = 'neon'\nsynchronous_standby_names = 'walproposer'\nneon.pageserver_connstring = 'host=127.0.0.1 port=6400'"
r#"fsync = off
wal_level = replica
hot_standby = on
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on
shared_buffers = 32768
port = 55432
max_connections = 100
max_wal_senders = 10
listen_addresses = '0.0.0.0'
wal_sender_timeout = 0
password_encryption = md5
maintenance_work_mem = 65536
max_parallel_workers = 8
max_worker_processes = 8
neon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'
max_replication_slots = 10
neon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'
shared_preload_libraries = 'neon'
synchronous_standby_names = 'walproposer'
neon.pageserver_connstring = 'host=127.0.0.1 port=6400'
test.escaping = 'here''s a backslash \\ and a quote '' and a double-quote " hooray'
"#
);
}

View File

@@ -24,6 +24,7 @@ url.workspace = true
# Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api
# instead, so that recompile times are better.
pageserver_api.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true

View File

@@ -17,6 +17,7 @@ use pageserver_api::{
DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR,
DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR,
};
use postgres_backend::AuthType;
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -30,7 +31,6 @@ use utils::{
auth::{Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
project_git_version,
};

View File

@@ -11,10 +11,10 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use postgres_backend::AuthType;
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};

View File

@@ -5,6 +5,7 @@
use anyhow::{bail, ensure, Context};
use postgres_backend::AuthType;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
@@ -19,7 +20,6 @@ use std::process::{Command, Stdio};
use utils::{
auth::{encode_from_key_file, Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
postgres_backend::AuthType,
};
use crate::safekeeper::SafekeeperNode;

View File

@@ -11,6 +11,7 @@ use anyhow::{bail, Context};
use pageserver_api::models::{
TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo,
};
use postgres_backend::AuthType;
use postgres_connection::{parse_host_port, PgConnectionConfig};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
@@ -20,7 +21,6 @@ use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::{background_process, local_env::LocalEnv};

View File

@@ -29,6 +29,41 @@ These components should not have access to the private key and may only get toke
The key pair is generated once for an installation of compute/pageserver/safekeeper, e.g. by `neon_local init`.
There is currently no way to rotate the key without bringing down all components.
### Token format
The JWT tokens in Neon use RSA as the algorithm. Example:
Header:
```
{
"alg": "RS512", # RS256, RS384, or RS512
"typ": "JWT"
}
```
Payload:
```
{
"scope": "tenant", # "tenant", "pageserverapi", or "safekeeperdata"
"tenant_id": "5204921ff44f09de8094a1390a6a50f6",
}
```
Meanings of scope:
"tenant": Provides access to all data for a specific tenant
"pageserverapi": Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
Should only be used e.g. for status check/tenant creation/list.
"safekeeperdata": Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
Should only be used e.g. for status check.
Currently also used for connection from any pageserver to any safekeeper.
### CLI
CLI generates a key pair during call to `neon_local init` with the following commands:

View File

@@ -98,6 +98,15 @@ impl RelTag {
name
}
pub fn with_forknum(&self, forknum: u8) -> Self {
RelTag {
forknum,
spcnode: self.spcnode,
dbnode: self.dbnode,
relnode: self.relnode,
}
}
}
///

View File

@@ -0,0 +1,26 @@
[package]
name = "postgres_backend"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
async-trait.workspace = true
anyhow.workspace = true
bytes.workspace = true
futures.workspace = true
rustls.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true
rustls-pemfile.workspace = true
tokio-postgres.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -0,0 +1,910 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use anyhow::Context;
use bytes::Bytes;
use futures::pin_mut;
use serde::{Deserialize, Serialize};
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Poll};
use std::{fmt, io};
use std::{future::Future, str::FromStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace};
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
use pq_proto::{
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
SQLSTATE_SUCCESSFUL_COMPLETION,
};
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Io(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
/// Nothing happened yet.
Initialization,
/// Encryption handshake is done; waiting for encrypted Startup message.
Encrypted,
/// Waiting for password (auth token).
Authentication,
/// Performed handshake and auth, ReadyForQuery is issued.
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream {
Unencrypted(tokio::net::TcpStream),
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
}
impl AsyncWrite for MaybeTlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
impl AsyncRead for MaybeTlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
/// Either full duplex Framed or write only half; the latter is left in
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of splitted handles, but that would force to to pay splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly {
Full(Framed<MaybeTlsStream>),
WriteOnly(FramedWriter<MaybeTlsStream>),
Broken, // temporary value palmed off during the split
}
impl MaybeWriteOnly {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
match self {
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn flush(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.flush().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn shutdown(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
}
pub struct PostgresBackend {
framed: MaybeWriteOnly,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
if let ProtoState::Closed = self.state {
Ok(None)
} else {
let m = self.framed.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
}
/// Write message into internal output buffer, doesn't flush it. Technically
/// error type can be only ProtocolError here (if, unlikely, serialization
/// fails), but callers typically wrap it anyway.
pub fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.framed.write_message_noflush(message)?;
trace!("wrote msg {:?}", message);
Ok(self)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
self.framed.flush().await
}
/// Polling version of `flush()`, saves the caller need to pin.
pub fn poll_flush(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let flush_fut = self.flush();
pin_mut!(flush_fut);
flush_fut.poll(cx)
}
/// Write message into internal output buffer and flush it to the stream.
pub async fn write_message(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
CopyDataWriter { pgb: self }
}
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
// socket might be already closed, e.g. if previously received error,
// so ignore result.
self.framed.shutdown().await.ok();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = self.handshake(handler) => {
// Handshake complete.
result?;
if self.state == ProtoState::Closed {
return Ok(()); // EOF during handshake
}
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
let tls_stream = acceptor.accept(s).await?;
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
}
MaybeTlsStream::Tls(_) => {
anyhow::bail!("TLS already started");
}
}
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
// temporary replace stream with fake to cook TLS one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let tls_config = self
.tls_config
.as_ref()
.context("start_tls called without conf")?
.clone();
let tls_framed = framed
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
.await?;
// push back ready TLS stream
self.framed = MaybeWriteOnly::Full(tls_framed);
Ok(())
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("TLS upgrade attempt in split state")
}
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
}
}
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let (reader, writer) = framed.split();
self.framed = MaybeWriteOnly::WriteOnly(writer);
Ok(PostgresBackendReader(reader))
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("PostgresBackend is already split")
}
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
}
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
anyhow::bail!("PostgresBackend is not split")
}
MaybeWriteOnly::WriteOnly(writer) => {
let joined = Framed::unsplit(reader.0, writer);
self.framed = MaybeWriteOnly::Full(joined);
Ok(())
}
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
}
}
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
self.process_startup_message(handler, msg).await?;
}
None => {
trace!(
"postgres backend to {:?} received EOF during handshake",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
// Perform auth, if needed.
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(self.auth_type == AuthType::NeonJWT);
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
Some(m) => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected message {:?} while waiting for handshake",
m
)));
}
None => {
trace!(
"postgres backend to {:?} received EOF during auth",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
Ok(())
}
/// Process startup packet:
/// - transition to Established if auth type is trust
/// - transition to Authentication if auth type is NeonJWT.
/// - or perform TLS handshake -- then need to call this again to receive
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
let have_tls = self.tls_config.is_some();
match msg {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))
.await?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))
.await?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
.await?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &msg)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected CancelRequest message during handshake"
)));
}
}
Ok(())
}
async fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message_noflush(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message_noflush(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message_noflush(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message_noflush(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_)
| FeMessage::CopyDone
| FeMessage::CopyFail
| FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}",
)));
}
}
Ok(ProcessMsgResult::Continue)
}
/// Log as info/error result of handling COPY stream and send back
/// ErrorResponse if that makes sense. Shutdown the stream if we got
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
/// close.
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
use CopyStreamHandlerEnd::*;
let expected_end = match &end {
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
if is_expected_io_error(io_error) =>
{
true
}
_ => false,
};
if expected_end {
info!("terminated: {:#}", end);
} else {
error!("terminated: {:?}", end);
}
// Note: no current usages ever send this
if let CopyDone = &end {
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
error!("failed to send CopyDone: {}", e);
}
}
if let Terminate = &end {
self.state = ProtoState::Closed;
}
let err_to_send_and_errcode = match &end {
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)),
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
// PG walsender; evidently and per my docs reading client should
// finish it with CopyDone). It is not a problem to recover from it
// finishing the stream in both directions like we do, but note that
// sync rust-postgres client (which we don't use anymore) hangs if
// socket is not closed here.
// https://github.com/sfackler/rust-postgres/issues/755
// https://github.com/neondatabase/neon/issues/935
//
// Currently, the version of tokio_postgres replication patch we use
// sends this when it closes the stream (e.g. pageserver decided to
// switch conn to another safekeeper and client gets dropped).
// Moreover, seems like 'connection' task errors with 'unexpected
// message from server' when it receives ErrorResponse (anything but
// CopyData/CopyDone) back.
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
_ => None,
};
if let Some((err, errcode)) = err_to_send_and_errcode {
if let Err(ee) = self
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
.await
{
error!("failed to send ErrorResponse: {}", ee);
}
}
}
}
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
impl PostgresBackendReader {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
let m = self.0.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
/// Get CopyData contents of the next message in COPY stream or error
/// closing it. The error type is wider than actual errors which can happen
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
/// current callers.
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
match self.read_message().await? {
Some(msg) => match msg {
FeMessage::CopyData(m) => Ok(m),
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
))),
},
None => Err(CopyStreamHandlerEnd::EOF),
}
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
return Poll::Ready(Err(err));
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb
.write_message_noflush(&BeMessage::CopyData(buf))
// write_message only writes to the buffer, so it can fail iff the
// message is invaid, but CopyData can't be invalid.
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
/// Handler initiates the end of streaming.
#[error("{0}")]
ServerInitiated(String),
#[error("received CopyDone")]
CopyDone,
#[error("received CopyFail")]
CopyFail,
#[error("received Terminate")]
Terminate,
#[error("EOF on COPY stream")]
EOF,
/// The connection was lost
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}

View File

@@ -0,0 +1,139 @@
/// Test postgres_backend_async with tokio_postgres
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
use tokio_postgres_rustls::MakeRustlsConnect;
// generate client, server test streams
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).await.unwrap();
let (server_stream, _) = listener.accept().await.unwrap();
(client_stream, server_stream)
}
struct TestHandler {}
#[async_trait::async_trait]
impl Handler for TestHandler {
// return single col 'hey' for any query
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
b"hey",
)]))?
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
Ok(())
}
}
// test that basic select works
#[tokio::test]
async fn simple_select() {
let (client_sock, server_sock) = make_tcp_pair().await;
// create and run pgbackend
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let conf = Config::new();
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
// test that basic select with ssl works
#[tokio::test]
async fn simple_select_ssl() {
let (client_sock, server_sock) = make_tcp_pair().await;
let server_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(server_cfg));
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let client_cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
&mut make_tls_connect,
"localhost",
)
.expect("make_tls_connect");
let mut conf = Config::new();
conf.ssl_mode(SslMode::Require);
let (client, connection) = conf
.connect_raw(client_sock, tls_connect)
.await
.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}

View File

@@ -5,8 +5,8 @@ edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
bytes.workspace = true
byteorder.workspace = true
pin-project-lite.workspace = true
postgres-protocol.workspace = true
rand.workspace = true

244
libs/pq_proto/src/framed.rs Normal file
View File

@@ -0,0 +1,244 @@
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
//! the async stream based on (and buffered with) BytesMut. All functions are
//! cancellation safe.
//!
//! It is similar to what tokio_util::codec::Framed with appropriate codec
//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
//! separately without using split from futures::stream::StreamExt (which
//! allocates box[1] in polling internally). tokio::io::split is used for splitting
//! instead. Plus we customize error messages more than a single type for all io
//! calls.
//!
//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
use bytes::{Buf, BytesMut};
use std::{
future::Future,
io::{self, ErrorKind},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
const INITIAL_CAPACITY: usize = 8 * 1024;
/// Error on postgres connection: either IO (physical transport error) or
/// protocol violation.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Protocol(#[from] ProtocolError),
}
impl ConnectionError {
/// Proxy stream.rs uses only io::Error; provide it.
pub fn into_io_error(self) -> io::Error {
match self {
ConnectionError::Io(io) => io,
ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
}
}
}
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
/// messages.
pub struct Framed<S> {
stream: S,
read_buf: BytesMut,
write_buf: BytesMut,
}
impl<S> Framed<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
}
}
/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Extract the underlying stream.
pub fn into_inner(self) -> S {
self.stream
}
/// Return new Framed with stream type transformed by async f, for TLS
/// upgrade.
pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
where
F: FnOnce(S) -> Fut,
Fut: Future<Output = Result<S2, E>>,
{
let stream = f(self.stream).await?;
Ok(Framed {
stream,
read_buf: self.read_buf,
write_buf: self.write_buf,
})
}
}
impl<S: AsyncRead + Unpin> Framed<S> {
pub async fn read_startup_message(
&mut self,
) -> Result<Option<FeStartupPacket>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
}
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
impl<S: AsyncWrite + Unpin> Framed<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
/// Split into owned read and write parts. Beware of potential issues with
/// using halves in different tasks on TLS stream:
/// https://github.com/tokio-rs/tls/issues/40
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
let (read_half, write_half) = tokio::io::split(self.stream);
let reader = FramedReader {
stream: read_half,
read_buf: self.read_buf,
};
let writer = FramedWriter {
stream: write_half,
write_buf: self.write_buf,
};
(reader, writer)
}
/// Join read and write parts back.
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
Self {
stream: reader.stream.unsplit(writer.stream),
read_buf: reader.read_buf,
write_buf: writer.write_buf,
}
}
}
/// Read-only version of `Framed`.
pub struct FramedReader<S> {
stream: ReadHalf<S>,
read_buf: BytesMut,
}
impl<S: AsyncRead + Unpin> FramedReader<S> {
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
/// Write-only version of `Framed`.
pub struct FramedWriter<S> {
stream: WriteHalf<S>,
write_buf: BytesMut,
}
impl<S: AsyncWrite + Unpin> FramedWriter<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
/// Read next message from the stream. Returns Ok(None), if EOF happened and we
/// don't have remaining data in the buffer. This function is cancellation safe:
/// you can drop future which is not yet complete and finalize reading message
/// with the next call.
///
/// Parametrized to allow reading startup or usual message, having different
/// format.
async fn read_message<S: AsyncRead + Unpin, M, P>(
stream: &mut S,
read_buf: &mut BytesMut,
parse: P,
) -> Result<Option<M>, ConnectionError>
where
P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
{
loop {
if let Some(msg) = parse(read_buf)? {
return Ok(Some(msg));
}
// If we can't build a frame yet, try to read more data and try again.
// Make sure we've got room for at least one byte to read to ensure
// that we don't get a spurious 0 that looks like EOF.
read_buf.reserve(1);
if stream.read_buf(read_buf).await? == 0 {
if read_buf.has_remaining() {
return Err(io::Error::new(
ErrorKind::UnexpectedEof,
"EOF with unprocessed data in the buffer",
)
.into());
} else {
return Ok(None); // clean EOF
}
}
}
}
async fn flush<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
while write_buf.has_remaining() {
let bytes_written = stream.write(write_buf.chunk()).await?;
if bytes_written == 0 {
return Err(io::Error::new(
ErrorKind::WriteZero,
"failed to write message",
));
}
// The advanced part will be garbage collected, likely during shifting
// data left on next attempt to write to buffer when free space is not
// enough.
write_buf.advance(bytes_written);
}
write_buf.clear();
stream.flush().await
}
async fn shutdown<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
flush(stream, write_buf).await?;
stream.shutdown().await
}

View File

@@ -2,24 +2,18 @@
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
//! on message formats.
// Tools for calling certain async methods in sync contexts.
pub mod sync;
pub mod framed;
use anyhow::{ensure, Context, Result};
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
collections::HashMap,
fmt,
future::Future,
io::{self, Cursor},
str,
fmt, io, str,
time::{Duration, SystemTime},
};
use sync::{AsyncishRead, SyncFuture};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
pub type Oid = u32;
@@ -31,7 +25,6 @@ pub const TEXT_OID: Oid = 25;
#[derive(Debug)]
pub enum FeMessage {
StartupPacket(FeStartupPacket),
// Simple query.
Query(Bytes),
// Extended query protocol.
@@ -191,260 +184,207 @@ pub struct FeExecuteMessage {
#[derive(Debug)]
pub struct FeCloseMessage;
/// Retry a read on EINTR
///
/// This runs the enclosed expression, and if it returns
/// Err(io::ErrorKind::Interrupted), retries it.
macro_rules! retry_read {
( $x:expr ) => {
loop {
match $x {
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
res => break res,
}
}
};
}
/// An error occured during connection being open.
/// An error occured while parsing or serializing raw stream into Postgres
/// messages.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
/// IO error during writing to or reading from the connection socket.
#[error("Socket IO error: {0}")]
Socket(std::io::Error),
/// Invalid packet was received from client
pub enum ProtocolError {
/// Invalid packet was received from the client (e.g. unexpected message
/// type or broken len).
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse a protocol mesage
/// Failed to parse or, (unlikely), serialize a protocol message.
#[error("Message parse error: {0}")]
MessageParse(anyhow::Error),
BadMessage(String),
}
impl From<anyhow::Error> for ConnectionError {
fn from(e: anyhow::Error) -> Self {
Self::MessageParse(e)
}
}
impl ConnectionError {
impl ProtocolError {
/// Proxy stream.rs uses only io::Error; provide it.
pub fn into_io_error(self) -> io::Error {
match self {
ConnectionError::Socket(io) => io,
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
}
io::Error::new(io::ErrorKind::Other, self.to_string())
}
}
impl FeMessage {
/// Read one message from the stream.
/// This function returns `Ok(None)` in case of EOF.
/// One way to handle this properly:
/// Read and parse one message from the `buf` input buffer. If there is at
/// least one valid message, returns it, advancing `buf`; redundant copies
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
/// directly into the `buf` (processed data is garbage collected after
/// parsed message is dropped).
///
/// ```
/// # use std::io;
/// # use pq_proto::FeMessage;
/// #
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
/// # Ok(())
/// # };
/// #
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
/// while let Some(msg) = FeMessage::read(stream)? {
/// process_message(msg)?;
/// }
/// Returns None if `buf` doesn't contain enough data for a single message.
/// For efficiency, tries to reserve large enough space in `buf` for the
/// next message in this case to save the repeated calls.
///
/// Ok(())
/// }
/// ```
#[inline(never)]
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Returns Error if message is malformed, the only possible ErrorKind is
/// InvalidInput.
//
// Inspired by rust-postgres Message::parse.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
// Every message contains message type byte and 4 bytes len; can't do
// much without them.
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
/// Read one message from the stream.
/// See documentation for `Self::read`.
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
// AsyncReadExt methods of the stream.
SyncFuture::new(async move {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match retry_read!(stream.read_u8().await) {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(ProtocolError::Protocol(format!(
"invalid message length {}",
len
)));
}
// The message length includes itself, so it better be at least 4.
let len = retry_read!(stream.read_u32().await)
.map_err(ConnectionError::Socket)?
.checked_sub(4)
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
// length field includes itself, but not message type.
let total_len = len as usize + 1;
if buf.len() < total_len {
// Don't have full message yet.
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let body = {
let mut buffer = vec![0u8; len as usize];
stream
.read_exact(&mut buffer)
.await
.map_err(ConnectionError::Socket)?;
Bytes::from(buffer)
};
// got the message, advance buffer
let mut msg = buf.split_to(total_len).freeze();
msg.advance(5); // consume message type and len
match tag {
b'Q' => Ok(Some(FeMessage::Query(body))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => {
return Err(ConnectionError::Protocol(format!(
"unknown message tag: {tag},'{body:?}'"
)))
}
match tag {
b'Q' => Ok(Some(FeMessage::Query(msg))),
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(msg))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
tag => {
return Err(ProtocolError::Protocol(format!(
"unknown message tag: {tag},'{msg:?}'"
)))
}
})
}
}
}
impl FeStartupPacket {
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
/// Read and parse startup message from the `buf` input buffer. It is
/// different from [`FeMessage::parse`] because startup messages don't have
/// message type byte; otherwise, its comments apply.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
SyncFuture::new(async move {
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match retry_read!(stream.read_u32().await) {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// need at least 4 bytes with packet len
if buf.len() < 4 {
let to_read = 4 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
#[allow(clippy::manual_range_contains)]
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ConnectionError::Protocol(format!(
"invalid message length {len}"
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ProtocolError::Protocol(format!(
"invalid startup packet message length {}",
len
)));
}
if buf.len() < len {
// Don't have full message yet.
let to_read = len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// got the message, advance buffer
let mut msg = buf.split_to(len).freeze();
msg.advance(4); // consume len
let request_code = msg.get_u32();
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if msg.remaining() != 8 {
return Err(ProtocolError::BadMessage(
"CancelRequest message is malformed, backend PID / secret key missing"
.to_owned(),
));
}
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: msg.get_i32(),
cancel_key: msg.get_i32(),
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// StartupMessage
let request_code =
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&msg)
.map_err(|_e| {
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
})?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream
.read_exact(params_bytes.as_mut())
.await
.map_err(ConnectionError::Socket)?;
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
// Parse params depending on request code
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if params_len != 8 {
return Err(ConnectionError::Protocol(
"expected 8 bytes for CancelRequest params".to_string(),
));
}
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
})
params.insert(name.to_owned(), value.to_owned());
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ConnectionError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&params_bytes)
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
params.insert(name.to_owned(), value.to_owned());
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
}
};
Ok(Some(FeMessage::StartupPacket(message)))
})
}
};
Ok(Some(message))
}
}
impl FeParseMessage {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
// FIXME: the rust-postgres driver uses a named prepared statement
// for copy_out(). We're not prepared to handle that correctly. For
// now, just ignore the statement name, assuming that the client never
@@ -452,55 +392,82 @@ impl FeParseMessage {
let _pstmt_name = read_cstr(&mut buf)?;
let query_string = read_cstr(&mut buf)?;
if buf.remaining() < 2 {
return Err(ProtocolError::BadMessage(
"Parse message is malformed, nparams missing".to_string(),
));
}
let nparams = buf.get_i16();
ensure!(nparams == 0, "query params not implemented");
if nparams != 0 {
return Err(ProtocolError::BadMessage(
"query params not implemented".to_string(),
));
}
Ok(FeMessage::Parse(FeParseMessage { query_string }))
}
}
impl FeDescribeMessage {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
let kind = buf.get_u8();
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
ensure!(
kind == b'S',
"only prepared statemement Describe is implemented"
);
if kind != b'S' {
return Err(ProtocolError::BadMessage(
"only prepared statemement Describe is implemented".to_string(),
));
}
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
}
}
impl FeExecuteMessage {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
let portal_name = read_cstr(&mut buf)?;
if buf.remaining() < 4 {
return Err(ProtocolError::BadMessage(
"FeExecuteMessage message is malformed, maxrows missing".to_string(),
));
}
let maxrows = buf.get_i32();
ensure!(portal_name.is_empty(), "named portals not implemented");
ensure!(maxrows == 0, "row limit in Execute message not implemented");
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
if maxrows != 0 {
return Err(ProtocolError::BadMessage(
"row limit in Execute message not implemented".to_string(),
));
}
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
}
}
impl FeBindMessage {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
let portal_name = read_cstr(&mut buf)?;
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
ensure!(portal_name.is_empty(), "named portals not implemented");
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
Ok(FeMessage::Bind(FeBindMessage))
}
}
impl FeCloseMessage {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
@@ -529,6 +496,7 @@ pub enum BeMessage<'a> {
CloseComplete,
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
// None errcode means internal_error will be sent.
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
/// Single byte - used in response to SSLRequest/GSSENCRequest.
EncryptionResponse(bool),
@@ -559,6 +527,11 @@ impl<'a> BeMessage<'a> {
value: b"UTF8",
};
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
name: b"integer_datetimes",
value: b"on",
};
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
pub fn server_version(version: &'a str) -> Self {
Self::ParameterStatus {
@@ -637,7 +610,7 @@ impl RowDescriptor<'_> {
#[derive(Debug)]
pub struct XLogDataBody<'a> {
pub wal_start: u64,
pub wal_end: u64,
pub wal_end: u64, // current end of WAL on the server
pub timestamp: i64,
pub data: &'a [u8],
}
@@ -677,12 +650,11 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
}
/// Safe write of s into buf as cstring (String in the protocol).
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
let bytes = s.as_ref();
if bytes.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
return Err(ProtocolError::BadMessage(
"string contains embedded null".to_owned(),
));
}
buf.put_slice(bytes);
@@ -690,22 +662,27 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
Ok(())
}
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let pos = buf.iter().position(|x| *x == 0);
let result = buf.split_to(pos.context("missing terminator")?);
/// Read cstring from buf, advancing it.
fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
let pos = buf
.iter()
.position(|x| *x == 0)
.ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
let result = buf.split_to(pos);
buf.advance(1); // drop the null terminator
Ok(result)
}
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
impl<'a> BeMessage<'a> {
/// Write message to the given buf.
// Unlike the reading side, we use BytesMut
// here as msg len precedes its body and it is handy to write it down first
// and then fill the length. With Write we would have to either calc it
// manually or have one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
/// Serialize `message` to the given `buf`.
/// Apart from smart memory managemet, BytesMut is good here as msg len
/// precedes its body and it is handy to write it down first and then fill
/// the length. With Write we would have to either calc it manually or have
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
match message {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
@@ -750,7 +727,7 @@ impl<'a> BeMessage<'a> {
buf.put_slice(extra);
}
}
Ok::<_, io::Error>(())
Ok(())
})?;
}
@@ -854,7 +831,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg, buf)?;
buf.put_u8(0); // terminator
Ok::<_, io::Error>(())
Ok(())
})?;
}
@@ -877,7 +854,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok::<_, io::Error>(())
Ok(())
})?;
}
@@ -932,7 +909,7 @@ impl<'a> BeMessage<'a> {
buf.put_i32(-1); /* typmod */
buf.put_i16(0); /* format code */
}
Ok::<_, io::Error>(())
Ok(())
})?;
}
@@ -999,7 +976,7 @@ impl ReplicationFeedback {
// null-terminated string - key,
// uint32 - value length in bytes
// value itself
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
pub fn serialize(&self, buf: &mut BytesMut) {
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
buf.put_slice(b"current_timeline_size\0");
buf.put_i32(8);
@@ -1024,7 +1001,6 @@ impl ReplicationFeedback {
buf.put_slice(b"ps_replytime\0");
buf.put_i32(8);
buf.put_i64(timestamp);
Ok(())
}
// Deserialize ReplicationFeedback message
@@ -1092,7 +1068,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data).unwrap();
rf.serialize(&mut data);
let rf_parsed = ReplicationFeedback::parse(data.freeze());
assert_eq!(rf, rf_parsed);
@@ -1107,7 +1083,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data).unwrap();
rf.serialize(&mut data);
// Add an extra field to the buffer and adjust number of keys
if let Some(first) = data.first_mut() {
@@ -1149,15 +1125,6 @@ mod tests {
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
}
// Make sure that `read` is sync/async callable
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
let _ = FeMessage::read(&mut [].as_ref());
let _ = FeMessage::read_fut(stream).await;
let _ = FeStartupPacket::read(&mut [].as_ref());
let _ = FeStartupPacket::read_fut(stream).await;
}
}
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {

View File

@@ -1,179 +0,0 @@
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::{io, task};
pin_project! {
/// We use this future to mark certain methods
/// as callable in both sync and async modes.
#[repr(transparent)]
pub struct SyncFuture<S, T: Future> {
#[pin]
inner: T,
_marker: PhantomData<S>,
}
}
/// This wrapper lets us synchronously wait for inner future's completion
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
/// For instance, `S` may be substituted with types implementing
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
impl<S, T: Future> SyncFuture<S, T> {
/// NOTE: caller should carefully pick a type for `S`,
/// because we don't want to enable [`SyncFuture::wait`] when
/// it's in fact impossible to run the future synchronously.
/// Violation of this contract will not cause UB, but
/// panics and async event loop freezes won't please you.
///
/// Example:
///
/// ```
/// # use pq_proto::sync::SyncFuture;
/// # use std::future::Future;
/// # use tokio::io::AsyncReadExt;
/// #
/// // Parse a pair of numbers from a stream
/// pub fn parse_pair<Reader>(
/// stream: &mut Reader,
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
/// where
/// Reader: tokio::io::AsyncRead + Unpin,
/// {
/// // If `Reader` is a `SyncProof`, this will give caller
/// // an opportunity to use `SyncFuture::wait`, because
/// // `.await` will always result in `Poll::Ready`.
/// SyncFuture::new(async move {
/// let x = stream.read_u32().await?;
/// let y = stream.read_u64().await?;
/// Ok((x, y))
/// })
/// }
/// ```
pub fn new(inner: T) -> Self {
Self {
inner,
_marker: PhantomData,
}
}
}
impl<S, T: Future> Future for SyncFuture<S, T> {
type Output = T::Output;
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
/// Postulates that we can call [`SyncFuture::wait`].
/// If implementer is also a [`Future`], it should always
/// return [`task::Poll::Ready`] from [`Future::poll`].
///
/// Each implementation should document which futures
/// specifically are being declared sync-proof.
pub trait SyncPostulate {}
impl<T: SyncPostulate> SyncPostulate for &T {}
impl<T: SyncPostulate> SyncPostulate for &mut T {}
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
/// Synchronously wait for future completion.
pub fn wait(mut self) -> T::Output {
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
std::ptr::null(),
&task::RawWakerVTable::new(
|_| RAW_WAKER,
|_| panic!("SyncFuture: failed to wake"),
|_| panic!("SyncFuture: failed to wake by ref"),
|_| { /* drop is no-op */ },
),
);
// SAFETY: We never move `self` during this call;
// furthermore, it will be dropped in the end regardless of panics
let this = unsafe { Pin::new_unchecked(&mut self) };
// SAFETY: This waker doesn't do anything apart from panicking
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
let context = &mut task::Context::from_waker(&waker);
match this.poll(context) {
task::Poll::Ready(res) => res,
_ => panic!("SyncFuture: unexpected pending!"),
}
}
}
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
/// NOTE: you **should not** use this in async code.
#[repr(transparent)]
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
/// and allows the future to await on any of the [`AsyncRead`]
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
#[inline(always)]
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
task::Poll::Ready(
// `Read::read` will block, meaning we don't need a real event loop!
self.0
.read(buf.initialize_unfilled())
.map(|sz| buf.advance(sz)),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
fn bytes_add<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
SyncFuture::new(async move {
let a = stream.read_u32().await?;
let b = stream.read_u32().await?;
Ok(a + b)
})
}
#[test]
fn test_sync() {
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
.wait()
.unwrap();
assert_eq!(res, 300);
}
// We need a single-threaded executor for this test
#[tokio::test(flavor = "current_thread")]
async fn test_async() {
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
let write = async move {
tx.write_u32(100).await?;
tx.write_u32(200).await?;
Ok(())
};
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
assert_eq!(res, 300);
}
}

View File

@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
}
pub struct Download {
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
/// Extra key-value data, associated with the current remote file.
pub metadata: Option<StorageMetadata>,
}

View File

@@ -12,41 +12,36 @@ anyhow.workspace = true
bincode.workspace = true
bytes.workspace = true
heapless.workspace = true
hex = { workspace = true, features = ["serde"] }
hyper = { workspace = true, features = ["full"] }
futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true
once_cell.workspace = true
routerify.workspace = true
serde.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }
nix.workspace = true
signal-hook.workspace = true
rand.workspace = true
jsonwebtoken.workspace = true
hex = { workspace = true, features = ["serde"] }
rustls.workspace = true
rustls-split.workspace = true
git-version.workspace = true
serde_with.workspace = true
once_cell.workspace = true
strum.workspace = true
strum_macros.workspace = true
url.workspace = true
uuid = { version = "1.2", features = ["v4", "serde"] }
metrics.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
url.workspace = true
[dev-dependencies]
byteorder.workspace = true
bytes.workspace = true
criterion.workspace = true
hex-literal.workspace = true
tempfile.workspace = true
criterion.workspace = true
rustls-pemfile.workspace = true
[[bench]]
name = "benchmarks"

View File

@@ -9,16 +9,28 @@ use std::path::Path;
use anyhow::Result;
use jsonwebtoken::{
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
decode, encode, Algorithm, Algorithm::*, DecodingKey, EncodingKey, Header, TokenData,
Validation,
};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use crate::id::TenantId;
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
/// Algorithms accepted during validation.
///
/// Accept all RSA-based algorithms. We pass this list to jsonwebtoken::decode,
/// which checks that the algorithm in the token is one of these.
///
/// XXX: It also fails the validation if there are any algorithms in this list that belong
/// to different family than the token's algorithm. In other words, we can *not* list any
/// non-RSA algorithms here, or the validation always fails with InvalidAlgorithm error.
const ACCEPTED_ALGORITHMS: &[Algorithm] = &[RS256, RS384, RS512];
#[derive(Debug, Serialize, Deserialize, Clone)]
/// Algorithm to use when generating a new token in [`encode_from_key_file`]
const ENCODE_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
// Provides access to all data for a specific tenant (specified in `struct Claims` below)
@@ -33,8 +45,9 @@ pub enum Scope {
SafekeeperData,
}
/// JWT payload. See docs/authentication.md for the format
#[serde_as]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Claims {
#[serde(default)]
#[serde_as(as = "Option<DisplayFromStr>")]
@@ -55,7 +68,8 @@ pub struct JwtAuth {
impl JwtAuth {
pub fn new(decoding_key: DecodingKey) -> Self {
let mut validation = Validation::new(JWT_ALGORITHM);
let mut validation = Validation::default();
validation.algorithms = ACCEPTED_ALGORITHMS.into();
// The default 'required_spec_claims' is 'exp'. But we don't want to require
// expiration.
validation.required_spec_claims = [].into();
@@ -86,5 +100,113 @@ impl std::fmt::Debug for JwtAuth {
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_rsa_pem(key_data)?;
Ok(encode(&Header::new(JWT_ALGORITHM), claims, &key)?)
Ok(encode(&Header::new(ENCODE_ALGORITHM), claims, &key)?)
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
// generated with:
//
// openssl genpkey -algorithm rsa -out storage-auth-priv.pem
// openssl pkey -in storage-auth-priv.pem -pubout -out storage-auth-pub.pem
const TEST_PUB_KEY_RSA: &[u8] = br#"
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAy6OZ+/kQXcueVJA/KTzO
v4ljxylc/Kcb0sXWuXg1GB8k3nDA1gK66LFYToH0aTnqrnqG32Vu6wrhwuvqsZA7
jQvP0ZePAbWhpEqho7EpNunDPcxZ/XDy5TQlB1P58F9I3lkJXDC+DsHYLuuzwhAv
vo2MtWRdYlVHblCVLyZtANHhUMp2HUhgjHnJh5UrLIKOl4doCBxkM3rK0wjKsNCt
M92PCR6S9rvYzldfeAYFNppBkEQrXt2CgUqZ4KaS4LXtjTRUJxljijA4HWffhxsr
euRu3ufq8kVqie7fum0rdZZSkONmce0V0LesQ4aE2jB+2Sn48h6jb4dLXGWdq8TV
wQIDAQAB
-----END PUBLIC KEY-----
"#;
const TEST_PRIV_KEY_RSA: &[u8] = br#"
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDLo5n7+RBdy55U
kD8pPM6/iWPHKVz8pxvSxda5eDUYHyTecMDWArrosVhOgfRpOequeobfZW7rCuHC
6+qxkDuNC8/Rl48BtaGkSqGjsSk26cM9zFn9cPLlNCUHU/nwX0jeWQlcML4Owdgu
67PCEC++jYy1ZF1iVUduUJUvJm0A0eFQynYdSGCMecmHlSssgo6Xh2gIHGQzesrT
CMqw0K0z3Y8JHpL2u9jOV194BgU2mkGQRCte3YKBSpngppLgte2NNFQnGWOKMDgd
Z9+HGyt65G7e5+ryRWqJ7t+6bSt1llKQ42Zx7RXQt6xDhoTaMH7ZKfjyHqNvh0tc
ZZ2rxNXBAgMBAAECggEAVz3u4Wlx3o02dsoZlSQs+xf0PEX3RXKeU+1YMbtTG9Nz
6yxpIQaoZrpbt76rJE2gwkFR+PEu1NmjoOuLb6j4KlQuI4AHz1auOoGSwFtM6e66
K4aZ4x95oEJ3vqz2fkmEIWYJwYpMUmwvnuJx76kZm0xvROMLsu4QHS2+zCVtO5Tr
hvS05IMVuZ2TdQBZw0+JaFdwXbgDjQnQGY5n9MoTWSx1a4s/FF4Eby65BbDutcpn
Vt3jQAOmO1X2kbPeWSGuPJRzyUs7Kg8qfeglBIR3ppGP3vPYAdWX+ho00bmsVkSp
Q8vjul6C3WiM+kjwDxotHSDgbl/xldAl7OqPh0bfAQKBgQDnycXuq14Vg8nZvyn9
rTnvucO8RBz5P6G+FZ+44cAS2x79+85onARmMnm+9MKYLSMo8fOvsK034NDI68XM
04QQ/vlfouvFklMTGJIurgEImTZbGCmlMYCvFyIxaEWixon8OpeI4rFe4Hmbiijh
PxhxWg221AwvBS2sco8J/ylEkQKBgQDg6Rh2QYb/j0Wou1rJPbuy3NhHofd5Rq35
4YV3f2lfVYcPrgRhwe3T9SVII7Dx8LfwzsX5TAlf48ESlI3Dzv40uOCDM+xdtBRI
r96SfSm+jup6gsXU3AsdNkrRK3HoOG9Z/TkrUp213QAIlVnvIx65l4ckFMlpnPJ0
lo1LDXZWMQKBgFArzjZ7N5OhfdO+9zszC3MLgdRAivT7OWqR+CjujIz5FYMr8Xzl
WfAvTUTrS9Nu6VZkObFvHrrRG+YjBsuN7YQjbQXTSFGSBwH34bgbn2fl9pMTjHQC
50uoaL9GHa/rlBaV/YvvPQJgCi/uXa1rMX0jdNLkDULGO8IF7cu7Yf7BAoGBAIUU
J29BkpmAst0GDs/ogTlyR18LTR0rXyHt+UUd1MGeH859TwZw80JpWWf4BmkB4DTS
hH3gKePdJY7S65ci0XNsuRupC4DeXuorde0DtkGU2tUmr9wlX0Ynq9lcdYfMbMa4
eK1TsxG69JwfkxlWlIWITWRiEFM3lJa7xlrUWmLhAoGAFpKWF/hn4zYg3seU9gai
EYHKSbhxA4mRb+F0/9IlCBPMCqFrL5yftUsYIh2XFKn8+QhO97Nmk8wJSK6TzQ5t
ZaSRmgySrUUhx4nZ/MgqWCFv8VUbLM5MBzwxPKhXkSTfR4z2vLYLJwVY7Tb4kZtp
8ismApXVGHpOCstzikV9W7k=
-----END PRIVATE KEY-----
"#;
#[test]
fn test_decode() -> Result<(), anyhow::Error> {
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
// Here are tokens containing the following payload, signed using TEST_PRIV_KEY_RSA
// using RS512, RS384 and RS256 algorithms:
//
// ```
// {
// "scope": "tenant",
// "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
// "iss": "neon.controlplane",
// "exp": 1709200879,
// "iat": 1678442479
// }
// ```
//
// These were encoded with the online debugger at https://jwt.io
//
let encoded_rs512 = "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.QmqfteDQmDGoxQ5EFkasbt35Lx0W0Nh63muQnYZvFq93DSh4ZbOG9Mc4yaiXZoiS5HgeKtFKv3mbWkDqjz3En06aY17hWwguBtAsGASX48lYeCPADYGlGAuaWnOnVRwe3iiOC7tvPFvwX_45S84X73sNUXyUiXv6nLdcDqVXudtNrGST_DnZDnjuUJX11w7sebtKqQQ8l9-iGHiXOl5yevpMCoB1OcTWcT6DfDtffoNuMHDC3fyhmEGG5oKAt1qBybqAIiyC9-UBAowRZXhdfxrzUl-I9jzKWvk85c5ulhVRwbPeP6TTTlPKwFzBNHg1i2U-1GONew5osQ3aoptwsA";
let encoded_rs384 = "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.qqk4nkxKzOJP38c_g57_w_SfdQVmCsDT_bsLmdFj_N6LIB22gr6U6_P_5mvk3pIAsp0VCTDwPrCU908TxqjibEkwvQoJwbogHamSGHpD7eJBxGblSnA-Nr3MlEMxpFtec8QokSm6C5mH7DoBYjB2xzeOlxAmpR2GAzInKiMkU4kZ_OcqqrmVcMXY_6VnbxZWMekuw56zE1-PP_qNF1HvYOH-P08ONP8qdo5UPtBG7QBEFlCqZXJZCFihQaI4Vzil9rDuZGCm3I7xQJ8-yh1PX3BTbGo8EzqLdRyBeTpr08UTuRbp_MJDWevHpP3afvJetAItqZXIoZQrbJjcByHqKw";
let encoded_rs256 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.dF2N9KXG8ftFKHYbd5jQtXMQqv0Ej8FISGp1b_dmqOCotXj5S1y2AWjwyB_EXHM77JXfbEoJPAPrFFBNfd8cWtkCSTvpxWoHaecGzegDFGv5ZSc5AECFV1Daahc3PI3jii9wEiGkFOiwiBNfZ5INomOAsV--XXxlqIwKbTcgSYI7lrOTfecXAbAHiMKQlQYiIBSGnytRCgafhRkyGzPAL8ismthFJ9RHfeejyskht-9GbVHURw02bUyijuHEulpf9eEY3ZiB28de6jnCdU7ftIYaUMaYWt0nZQGkzxKPSfSLZNy14DTOYLDS04DVstWQPqnCUW_ojg0wJETOOfo9Zw";
// Check that RS512, RS384 and RS256 tokens can all be validated
let auth = JwtAuth::new(DecodingKey::from_rsa_pem(TEST_PUB_KEY_RSA)?);
for encoded in [encoded_rs512, encoded_rs384, encoded_rs256] {
let claims_from_token = auth.decode(encoded)?.claims;
assert_eq!(claims_from_token, expected_claims);
}
Ok(())
}
#[test]
fn test_encode() -> Result<(), anyhow::Error> {
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_RSA)?;
// decode it back
let auth = JwtAuth::new(DecodingKey::from_rsa_pem(TEST_PUB_KEY_RSA)?);
let decoded = auth.decode(&encoded)?;
assert_eq!(decoded.claims, claims);
Ok(())
}
}

View File

@@ -3,14 +3,14 @@ use crate::http::error;
use anyhow::{anyhow, Context};
use hyper::header::{HeaderName, AUTHORIZATION};
use hyper::http::HeaderValue;
use hyper::Method;
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
use once_cell::sync::Lazy;
use routerify::ext::RequestExt;
use routerify::RequestInfo;
use routerify::{Middleware, Router, RouterBuilder, RouterService};
use routerify::{Middleware, RequestInfo, Router, RouterBuilder, RouterService};
use tokio::task::JoinError;
use tracing::info;
use tracing::{self, debug, info, info_span, warn, Instrument};
use std::future::Future;
use std::net::TcpListener;
@@ -26,9 +26,83 @@ static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
.expect("failed to define a metric")
});
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
info!("{} {} {}", info.method(), info.uri().path(), res.status(),);
Ok(res)
static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
#[derive(Debug, Default, Clone)]
struct RequestId(String);
/// Adds a tracing info_span! instrumentation around the handler events,
/// logs the request start and end events for non-GET requests and non-200 responses.
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// in this type will get request info logged in the wrapping span, including the unique request ID.
///
/// There could be other ways to implement similar functionality:
///
/// * procmacros placed on top of all handler methods
/// With all the drawbacks of procmacros, brings no difference implementation-wise,
/// and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
/// implemented for [`RouterBuilder`].
/// Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
///
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
/// later, in a post-response middleware.
/// Due to suspendable nature of the futures, would give contradictive results which is exactly the opposite of what `tracing-futures`
/// tries to achive with its `.instrument` used in the current approach.
///
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
pub struct RequestSpan<E, R, H>(pub H)
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static;
impl<E, R, H> RequestSpan<E, R, H>
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static,
{
/// Creates a tracing span around inner request handler and executes the request handler in the contex of that span.
/// Use as `|r| RequestSpan(my_handler).handle(r)` instead of `my_handler` as the request handler to get the span enabled.
pub async fn handle(self, request: Request<Body>) -> Result<Response<Body>, E> {
let request_id = request.context::<RequestId>().unwrap_or_default().0;
let method = request.method();
let path = request.uri().path();
let request_span = info_span!("request", %method, %path, %request_id);
let log_quietly = method == Method::GET;
async move {
if log_quietly {
debug!("Handling request");
} else {
info!("Handling request");
}
// Note that we reuse `error::handler` here and not returning and error at all,
// yet cannot use `!` directly in the method signature due to `routerify::RouterBuilder` limitation.
// Usage of the error handler also means that we expect only the `ApiError` errors to be raised in this call.
//
// Panics are not handled separately, there's a `tracing_panic_hook` from another module to do that globally.
match (self.0)(request).await {
Ok(response) => {
let response_status = response.status();
if log_quietly && response_status.is_success() {
debug!("Request handled, status: {response_status}");
} else {
info!("Request handled, status: {response_status}");
}
Ok(response)
}
Err(e) => Ok(error::handler(e.into()).await),
}
}
.instrument(request_span)
.await
}
}
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -55,10 +129,48 @@ async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body
Ok(response)
}
pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
) -> Middleware<B, ApiError> {
Middleware::pre(move |req| async move {
let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
Some(request_id) => request_id
.to_str()
.expect("extract request id value")
.to_owned(),
None => {
let request_id = uuid::Uuid::new_v4();
request_id.to_string()
}
};
req.set_context(RequestId(request_id));
Ok(req)
})
}
async fn add_request_id_header_to_response(
mut res: Response<Body>,
req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
if let Some(request_id) = req_info.context::<RequestId>() {
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
};
Ok(res)
}
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
Router::builder()
.middleware(Middleware::post_with_info(logger))
.get("/metrics", prometheus_metrics_handler)
.middleware(add_request_id_middleware())
.middleware(Middleware::post_with_info(
add_request_id_header_to_response,
))
.get("/metrics", |r| {
RequestSpan(prometheus_metrics_handler).handle(r)
})
.err_handler(error::handler)
}
@@ -68,40 +180,43 @@ pub fn attach_openapi_ui(
spec_mount_path: &'static str,
ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
router_builder.get(spec_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(spec)).unwrap())
}).get(ui_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
})
router_builder
.get(spec_mount_path, move |r| {
RequestSpan(move |_| async move { Ok(Response::builder().body(Body::from(spec)).unwrap()) })
.handle(r)
})
.get(ui_mount_path, move |r| RequestSpan( move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
}).handle(r))
}
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
@@ -163,7 +278,7 @@ where
async move {
let headers = response.headers_mut();
if headers.contains_key(&name) {
tracing::warn!(
warn!(
"{} response already contains header {:?}",
request_info.uri(),
&name,
@@ -223,3 +338,48 @@ where
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use futures::future::poll_fn;
use hyper::service::Service;
use routerify::RequestServiceBuilder;
use std::net::{IpAddr, SocketAddr};
#[tokio::test]
async fn test_request_id_returned() {
let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
let mut service = builder.build(remote_addr);
if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
panic!("request service is not ready: {:?}", e);
}
let mut req: Request<Body> = Request::default();
req.headers_mut()
.append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
assert!(header_val == "42", "response header mismatch");
}
#[tokio::test]
async fn test_request_id_empty() {
let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
let mut service = builder.build(remote_addr);
if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
panic!("request service is not ready: {:?}", e);
}
let req: Request<Body> = Request::default();
let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
assert_ne!(header_val, None, "response header should NOT be empty");
}
}

View File

@@ -1,7 +1,9 @@
use std::fmt::Display;
use anyhow::Context;
use bytes::Buf;
use hyper::{header, Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Serialize, Serializer};
use super::error::ApiError;
@@ -31,3 +33,12 @@ pub fn json_response<T: Serialize>(
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
/// Serialize through Display trait.
pub fn display_serialize<S, F>(z: &F, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
F: Display,
{
s.serialize_str(&format!("{}", z))
}

View File

@@ -13,8 +13,6 @@ pub mod simple_rcu;
pub mod vec_map;
pub mod bin_ser;
pub mod postgres_backend;
pub mod postgres_backend_async;
// helper functions for creating and fsyncing
pub mod crashsafe;
@@ -27,9 +25,6 @@ pub mod id;
// http endpoint utils
pub mod http;
// socket splitting utils
pub mod sock_split;
// common log initialisation routine
pub mod logging;

View File

@@ -1,485 +0,0 @@
//! Server-side synchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend_async::{log_query_error, short_error, QueryError};
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
use anyhow::Context;
use bytes::{Bytes, BytesMut};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::io::{self, Write};
use std::net::{Shutdown, SocketAddr, TcpStream};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::*;
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
fn is_shutdown_requested(&self) -> bool {
false
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Bidirectional(BidiStream),
WriteOnly(WriteStream),
}
impl Stream {
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
}
}
}
impl io::Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
Self::WriteOnly(write_stream) => write_stream.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
Self::WriteOnly(write_stream) => write_stream.flush(),
}
}
}
pub struct PostgresBackend {
stream: Option<Stream>,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Helper function for socket read loops
pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
for cause in error.chain() {
if let Some(io_error) = cause.downcast_ref::<io::Error>() {
if io_error.kind() == std::io::ErrorKind::WouldBlock {
return true;
}
}
}
false
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
set_read_timeout: bool,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
if set_read_timeout {
socket
.set_read_timeout(Some(Duration::from_secs(5)))
.unwrap();
}
Ok(Self {
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn into_stream(self) -> Stream {
self.stream.unwrap()
}
/// Get direct reference (into the Option) to the read stream.
fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> {
match &mut self.stream {
Some(Stream::Bidirectional(stream)) => Ok(stream),
_ => anyhow::bail!("reader taken"),
}
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
let stream = self.stream.take();
match stream {
Some(Stream::Bidirectional(bidi_stream)) => {
let (read, write) = bidi_stream.split();
self.stream = Some(Stream::WriteOnly(write));
Some(read)
}
stream => {
self.stream = stream;
None
}
}
}
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
let (state, stream) = (self.state, self.get_stream_in()?);
use ProtoState::*;
match state {
Initialization | Encrypted => FeStartupPacket::read(stream),
Authentication | Established => FeMessage::read(stream),
}
.map_err(QueryError::from)
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Flush output buffer into the socket.
pub fn flush(&mut self) -> io::Result<&mut Self> {
let stream = self.stream.as_mut().unwrap();
stream.write_all(&self.buf_out)?;
self.buf_out.clear();
Ok(self)
}
/// Write message into internal buffer and flush it.
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
let ret = self.run_message_loop(handler);
if let Some(stream) = self.stream.as_mut() {
let _ = stream.shutdown(Shutdown::Both);
}
ret
}
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
trace!("postgres backend to {:?} started", self.peer_addr);
let mut unnamed_query_string = Bytes::new();
while !handler.is_shutdown_requested() {
match self.read_message() {
Ok(message) => {
if let Some(msg) = message {
trace!("got message {msg:?}");
match self.process_message(handler, msg, &mut unnamed_query_string)? {
ProcessMsgResult::Continue => continue,
ProcessMsgResult::Break => break,
}
} else {
break;
}
}
Err(e) => {
if let QueryError::Other(e) = &e {
if is_socket_read_timed_out(e) {
continue;
}
}
return Err(e);
}
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
pub fn start_tls(&mut self) -> anyhow::Result<()> {
match self.stream.take() {
Some(Stream::Bidirectional(bidi_stream)) => {
let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?;
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?));
Ok(())
}
stream => {
self.stream = stream;
anyhow::bail!("can't start TLs without bidi stream");
}
}
}
fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state < ProtoState::Established
&& !matches!(
msg,
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
)
{
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls()?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}"
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}

View File

@@ -1,634 +0,0 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend::AuthType;
use anyhow::Context;
use bytes::{Buf, Bytes, BytesMut};
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::{future::Future, task::ready};
use tracing::{debug, error, info, trace};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio_rustls::TlsAcceptor;
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Socket(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Unencrypted(BufReader<tokio::net::TcpStream>),
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
Broken,
}
impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Broken => unreachable!(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
Self::Broken => unreachable!(),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Broken => unreachable!(),
}
}
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Broken => unreachable!(),
}
}
}
pub struct PostgresBackend {
stream: Stream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
// The data between 0 and "current position" as tracked by the bytes::Buf
// implementation of BytesMut, have already been written.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
Ok(Self {
stream: Stream::Unencrypted(BufReader::new(socket)),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is closed.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
use ProtoState::*;
match self.state {
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
Closed => Ok(None),
}
.map_err(QueryError::from)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
while self.buf_out.has_remaining() {
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
self.buf_out.advance(bytes_written);
}
self.buf_out.clear();
Ok(())
}
/// Write message into internal output buffer.
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
CopyDataWriter { pgb: self }
}
/// A polling function that tries to write all the data from 'buf_out' to the
/// underlying stream.
fn poll_write_buf(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
while self.buf_out.has_remaining() {
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
Ok(bytes_written) => self.buf_out.advance(bytes_written),
Err(err) => return Poll::Ready(Err(err)),
}
}
Poll::Ready(Ok(()))
}
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
let _ = self.stream.shutdown();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = async {
while self.state < ProtoState::Established {
if let Some(msg) = self.read_message().await? {
trace!("got message {msg:?} during handshake");
match self.process_handshake_message(handler, msg).await? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
} else {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
Ok::<(), QueryError>(())
} => {
// Handshake complete.
result?;
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
if let Stream::Unencrypted(plain_stream) =
std::mem::replace(&mut self.stream, Stream::Broken)
{
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
let tls_stream = acceptor.accept(plain_stream).await?;
self.stream = Stream::Tls(Box::new(tls_stream));
return Ok(());
};
anyhow::bail!("TLS already started");
}
async fn process_handshake_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
) -> Result<ProcessMsgResult, QueryError> {
assert!(self.state < ProtoState::Established);
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
_ => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
Ok(ProcessMsgResult::Continue)
}
async fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {:?}",
msg
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb.write_message(&BeMessage::CopyData(buf))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
pub(super) fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}

View File

@@ -1,206 +0,0 @@
use std::{
io::{self, BufReader, Write},
net::{Shutdown, TcpStream},
sync::Arc,
};
use rustls::Connection;
/// Wrapper supporting reads of a shared TcpStream.
pub struct ArcTcpRead(Arc<TcpStream>);
impl io::Read for ArcTcpRead {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
(&*self.0).read(buf)
}
}
impl std::ops::Deref for ArcTcpRead {
type Target = TcpStream;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
/// Wrapper around a TCP Stream supporting buffered reads.
pub struct BufStream(BufReader<ArcTcpRead>);
impl io::Read for BufStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl io::Write for BufStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.get_ref().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.get_ref().flush()
}
}
impl BufStream {
/// Unwrap into the internal BufReader.
fn into_reader(self) -> BufReader<ArcTcpRead> {
self.0
}
/// Returns a reference to the underlying TcpStream.
fn get_ref(&self) -> &TcpStream {
&self.0.get_ref().0
}
}
pub enum ReadStream {
Tcp(BufReader<ArcTcpRead>),
Tls(rustls_split::ReadHalf),
}
impl io::Read for ReadStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(reader) => reader.read(buf),
Self::Tls(read_half) => read_half.read(buf),
}
}
}
impl ReadStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
pub enum WriteStream {
Tcp(Arc<TcpStream>),
Tls(rustls_split::WriteHalf),
}
impl WriteStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
impl io::Write for WriteStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.as_ref().write(buf),
Self::Tls(write_half) => write_half.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.as_ref().flush(),
Self::Tls(write_half) => write_half.flush(),
}
}
}
type TlsStream<T> = rustls::StreamOwned<rustls::ServerConnection, T>;
pub enum BidiStream {
Tcp(BufStream),
/// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`].
Tls(Box<TlsStream<BufStream>>),
}
impl BidiStream {
pub fn from_tcp(stream: TcpStream) -> Self {
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
}
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(tls_boxed) => {
if how == Shutdown::Read {
tls_boxed.sock.get_ref().shutdown(how)
} else {
tls_boxed.conn.send_close_notify();
let res = tls_boxed.flush();
tls_boxed.sock.get_ref().shutdown(how)?;
res
}
}
}
}
/// Split the bi-directional stream into two owned read and write halves.
pub fn split(self) -> (ReadStream, WriteStream) {
match self {
Self::Tcp(stream) => {
let reader = stream.into_reader();
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
}
Self::Tls(tls_boxed) => {
let reader = tls_boxed.sock.into_reader();
let buffer_data = reader.buffer().to_owned();
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
// TODO would be nice to avoid the Arc here
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
let (read_half, write_half) = rustls_split::split(
socket,
Connection::Server(tls_boxed.conn),
read_buf_cfg,
write_buf_cfg,
);
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
}
}
}
pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result<Self> {
match self {
Self::Tcp(mut stream) => {
conn.complete_io(&mut stream)?;
assert!(!conn.is_handshaking());
Ok(Self::Tls(Box::new(TlsStream::new(conn, stream))))
}
Self::Tls { .. } => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"TLS is already started on this stream",
)),
}
}
}
impl io::Read for BidiStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::Tls(tls_boxed) => tls_boxed.read(buf),
}
}
}
impl io::Write for BidiStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
Self::Tls(tls_boxed) => tls_boxed.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
Self::Tls(tls_boxed) => tls_boxed.flush(),
}
}
}

View File

@@ -1,238 +0,0 @@
use std::{
collections::HashMap,
io::{Cursor, Read, Write},
net::{TcpListener, TcpStream},
sync::Arc,
};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use once_cell::sync::Lazy;
use utils::{
postgres_backend::{AuthType, Handler, PostgresBackend},
postgres_backend_async::QueryError,
};
fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).unwrap();
let (server_stream, _) = listener.accept().unwrap();
(server_stream, client_stream)
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
#[test]
// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274),
// we resize the vector so doing some modifications after all
#[allow(clippy::read_zero_byte_vec)]
fn ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
const QUERY: &str = "hello world";
let client_jh = std::thread::spawn(move || {
// SSLRequest
client_sock.write_u32::<BigEndian>(8).unwrap();
client_sock.write_u32::<BigEndian>(80877103).unwrap();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'S', ssl_response);
let cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let client_config = Arc::new(cfg);
let dns_name = "localhost".try_into().unwrap();
let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap();
conn.complete_io(&mut client_sock).unwrap();
assert!(!conn.is_handshaking());
let mut stream = rustls::Stream::new(&mut conn, &mut client_sock);
// StartupMessage
stream.write_u32::<BigEndian>(9).unwrap();
stream.write_u32::<BigEndian>(196608).unwrap();
stream.write_u8(0).unwrap();
stream.flush().unwrap();
// wait for ReadyForQuery
let mut msg_buf = Vec::new();
loop {
let msg = stream.read_u8().unwrap();
let size = stream.read_u32::<BigEndian>().unwrap() - 4;
msg_buf.resize(size as usize, 0);
stream.read_exact(&mut msg_buf).unwrap();
if msg == b'Z' {
// ReadyForQuery
break;
}
}
// Query
stream.write_u8(b'Q').unwrap();
stream
.write_u32::<BigEndian>(4u32 + QUERY.len() as u32)
.unwrap();
stream.write_all(QUERY.as_ref()).unwrap();
stream.flush().unwrap();
// ReadyForQuery
let msg = stream.read_u8().unwrap();
assert_eq!(msg, b'Z');
});
struct TestHandler {
got_query: bool,
}
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError> {
self.got_query = query_string == QUERY;
Ok(())
}
}
let mut handler = TestHandler { got_query: false };
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
pgb.run(&mut handler).unwrap();
assert!(handler.got_query);
client_jh.join().unwrap();
// TODO consider shutdown behavior
}
#[test]
fn no_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
let mut buf = BytesMut::new();
// SSLRequest
buf.put_u32(8);
buf.put_u32(80877103);
client_sock.write_all(&buf).unwrap();
buf.clear();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'N', ssl_response);
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap();
pgb.run(&mut handler).unwrap();
client_jh.join().unwrap();
}
#[test]
fn server_forces_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
// StartupMessage
client_sock.write_u32::<BigEndian>(9).unwrap();
client_sock.write_u32::<BigEndian>(196608).unwrap();
client_sock.write_u8(0).unwrap();
client_sock.flush().unwrap();
// ErrorResponse
assert_eq!(client_sock.read_u8().unwrap(), b'E');
let len = client_sock.read_u32::<BigEndian>().unwrap() - 4;
let mut body = vec![0; len as usize];
client_sock.read_exact(&mut body).unwrap();
let mut body = Bytes::from(body);
let mut errors = HashMap::new();
loop {
let field_type = body.get_u8();
if field_type == 0u8 {
break;
}
let end_idx = body.iter().position(|&b| b == 0u8).unwrap();
let mut value = body.split_to(end_idx + 1);
assert_eq!(value[end_idx], 0u8);
value.truncate(end_idx);
let old = errors.insert(field_type, value);
assert!(old.is_none());
}
assert!(!body.has_remaining());
assert_eq!("must connect with TLS", errors.get(&b'M').unwrap());
// TODO read failure
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
let res = pgb.run(&mut handler).unwrap_err();
assert_eq!("client did not connect with TLS", format!("{}", res));
client_jh.join().unwrap();
// TODO consider shutdown behavior
}

View File

@@ -37,6 +37,7 @@ num-traits.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
postgres.workspace = true
postgres_backend.workspace = true
postgres-protocol.workspace = true
postgres-types.workspace = true
rand.workspace = true

View File

@@ -33,6 +33,7 @@ use pageserver_api::reltag::{RelTag, SlruKind};
use postgres_ffi::pg_constants::{DEFAULTTABLESPACE_OID, GLOBALTABLESPACE_OID};
use postgres_ffi::pg_constants::{PGDATA_SPECIAL_FILES, PGDATA_SUBDIRS, PG_HBA};
use postgres_ffi::relfile_utils::{INIT_FORKNUM, MAIN_FORKNUM};
use postgres_ffi::TransactionId;
use postgres_ffi::XLogFileName;
use postgres_ffi::PG_TLI;
@@ -190,14 +191,31 @@ where
{
self.add_dbdir(spcnode, dbnode, has_relmap_file).await?;
// Gather and send relational files in each database if full backup is requested.
if self.full_backup {
for rel in self
.timeline
.list_rels(spcnode, dbnode, self.lsn, self.ctx)
.await?
{
self.add_rel(rel).await?;
// If full backup is requested, include all relation files.
// Otherwise only include init forks of unlogged relations.
let rels = self
.timeline
.list_rels(spcnode, dbnode, self.lsn, self.ctx)
.await?;
for &rel in rels.iter() {
// Send init fork as main fork to provide well formed empty
// contents of UNLOGGED relations. Postgres copies it in
// `reinit.c` during recovery.
if rel.forknum == INIT_FORKNUM {
// I doubt we need _init fork itself, but having it at least
// serves as a marker relation is unlogged.
self.add_rel(rel, rel).await?;
self.add_rel(rel, rel.with_forknum(MAIN_FORKNUM)).await?;
continue;
}
if self.full_backup {
if rel.forknum == MAIN_FORKNUM && rels.contains(&rel.with_forknum(INIT_FORKNUM))
{
// skip this, will include it when we reach the init fork
continue;
}
self.add_rel(rel, rel).await?;
}
}
}
@@ -220,15 +238,16 @@ where
Ok(())
}
async fn add_rel(&mut self, tag: RelTag) -> anyhow::Result<()> {
/// Add contents of relfilenode `src`, naming it as `dst`.
async fn add_rel(&mut self, src: RelTag, dst: RelTag) -> anyhow::Result<()> {
let nblocks = self
.timeline
.get_rel_size(tag, self.lsn, false, self.ctx)
.get_rel_size(src, self.lsn, false, self.ctx)
.await?;
// If the relation is empty, create an empty file
if nblocks == 0 {
let file_name = tag.to_segfile_name(0);
let file_name = dst.to_segfile_name(0);
let header = new_tar_header(&file_name, 0)?;
self.ar.append(&header, &mut io::empty()).await?;
return Ok(());
@@ -244,12 +263,12 @@ where
for blknum in startblk..endblk {
let img = self
.timeline
.get_rel_page_at_lsn(tag, blknum, self.lsn, false, self.ctx)
.get_rel_page_at_lsn(src, blknum, self.lsn, false, self.ctx)
.await?;
segment_data.extend_from_slice(&img[..]);
}
let file_name = tag.to_segfile_name(seg as u32);
let file_name = dst.to_segfile_name(seg as u32);
let header = new_tar_header(&file_name, segment_data.len() as u64)?;
self.ar.append(&header, segment_data.as_slice()).await?;

View File

@@ -23,11 +23,10 @@ use pageserver::{
tenant::mgr,
virtual_file,
};
use postgres_backend::AuthType;
use utils::{
auth::JwtAuth,
logging,
postgres_backend::AuthType,
project_git_version,
logging, project_git_version,
sentry_init::init_sentry,
signals::{self, Signal},
tcp_listener,
@@ -281,33 +280,17 @@ fn start_pageserver(
};
info!("Using auth: {:#?}", conf.auth_type);
// TODO: remove ZENITH_AUTH_TOKEN once it's not used anywhere in development/staging/prod configuration.
match (var("ZENITH_AUTH_TOKEN"), var("NEON_AUTH_TOKEN")) {
(old, Ok(v)) => {
match var("NEON_AUTH_TOKEN") {
Ok(v) => {
info!("Loaded JWT token for authentication with Safekeeper");
if let Ok(v_old) = old {
warn!(
"JWT token for Safekeeper is specified twice, ZENITH_AUTH_TOKEN is deprecated"
);
if v_old != v {
warn!("JWT token for Safekeeper has two different values, choosing NEON_AUTH_TOKEN");
}
}
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
(Ok(v), _) => {
info!("Loaded JWT token for authentication with Safekeeper");
warn!("Please update pageserver configuration: the JWT token should be NEON_AUTH_TOKEN, not ZENITH_AUTH_TOKEN");
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
(_, Err(VarError::NotPresent)) => {
Err(VarError::NotPresent) => {
info!("No JWT token for authentication with Safekeeper detected");
}
(_, Err(e)) => {
Err(e) => {
return Err(e).with_context(|| {
"Failed to either load to detect non-present NEON_AUTH_TOKEN environment variable"
})

View File

@@ -21,10 +21,10 @@ use std::time::Duration;
use toml_edit;
use toml_edit::{Document, Item};
use postgres_backend::AuthType;
use utils::{
id::{NodeId, TenantId, TimelineId},
logging::LogFormat,
postgres_backend::AuthType,
};
use crate::tenant::config::TenantConf;
@@ -698,6 +698,12 @@ impl PageServerConf {
Some(parse_toml_u64("compaction_threshold", compaction_threshold)?.try_into()?);
}
if let Some(image_creation_threshold) = item.get("image_creation_threshold") {
t_conf.image_creation_threshold = Some(
parse_toml_u64("image_creation_threshold", image_creation_threshold)?.try_into()?,
);
}
if let Some(gc_horizon) = item.get("gc_horizon") {
t_conf.gc_horizon = Some(parse_toml_u64("gc_horizon", gc_horizon)?);
}

View File

@@ -245,6 +245,53 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/do_gc:
parameters:
- name: tenant_id
in: path
required: true
schema:
type: string
format: hex
- name: timeline_id
in: path
required: true
schema:
type: string
format: hex
put:
description: Garbage collect given timeline
responses:
"200":
description: OK
content:
application/json:
schema:
type: string
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/attach:
parameters:
- name: tenant_id

View File

@@ -10,6 +10,7 @@ use remote_storage::GenericRemoteStorage;
use tenant_size_model::{SizeResult, StorageModel};
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::http::endpoint::RequestSpan;
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
use super::models::{
@@ -971,19 +972,22 @@ async fn timeline_checkpoint_handler(request: Request<Body>) -> Result<Response<
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
check_permission(&request, Some(tenant_id))?;
async {
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
timeline
.freeze_and_flush()
.await
.map_err(ApiError::InternalServerError)?;
timeline
.compact(&ctx)
.await
.map_err(ApiError::InternalServerError)?;
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
timeline
.freeze_and_flush()
.await
.map_err(ApiError::InternalServerError)?;
timeline
.compact(&ctx)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
json_response(StatusCode::OK, ())
}
.instrument(info_span!("manual_checkpoint", tenant_id = %tenant_id, timeline_id = %timeline_id))
.await
}
async fn timeline_download_remote_layers_handler_post(
@@ -1088,7 +1092,8 @@ pub fn make_router(
let handler = $handler;
#[cfg(not(feature = "testing"))]
let handler = cfg_disabled;
handler
move |r| RequestSpan(handler).handle(r)
}};
}
@@ -1096,35 +1101,55 @@ pub fn make_router(
.data(Arc::new(
State::new(conf, auth, remote_storage).context("Failed to initialize router state")?,
))
.get("/v1/status", status_handler)
.get("/v1/status", |r| RequestSpan(status_handler).handle(r))
.put(
"/v1/failpoints",
testing_api!("manage failpoints", failpoints_handler),
)
.get("/v1/tenant", tenant_list_handler)
.post("/v1/tenant", tenant_create_handler)
.get("/v1/tenant/:tenant_id", tenant_status)
.get("/v1/tenant/:tenant_id/synthetic_size", tenant_size_handler)
.put("/v1/tenant/config", update_tenant_config_handler)
.get("/v1/tenant/:tenant_id/config", get_tenant_config_handler)
.get("/v1/tenant/:tenant_id/timeline", timeline_list_handler)
.post("/v1/tenant/:tenant_id/timeline", timeline_create_handler)
.post("/v1/tenant/:tenant_id/attach", tenant_attach_handler)
.post("/v1/tenant/:tenant_id/detach", tenant_detach_handler)
.post("/v1/tenant/:tenant_id/load", tenant_load_handler)
.post("/v1/tenant/:tenant_id/ignore", tenant_ignore_handler)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_detail_handler,
)
.get("/v1/tenant", |r| RequestSpan(tenant_list_handler).handle(r))
.post("/v1/tenant", |r| {
RequestSpan(tenant_create_handler).handle(r)
})
.get("/v1/tenant/:tenant_id", |r| {
RequestSpan(tenant_status).handle(r)
})
.get("/v1/tenant/:tenant_id/synthetic_size", |r| {
RequestSpan(tenant_size_handler).handle(r)
})
.put("/v1/tenant/config", |r| {
RequestSpan(update_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/config", |r| {
RequestSpan(get_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_list_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_create_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/attach", |r| {
RequestSpan(tenant_attach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/detach", |r| {
RequestSpan(tenant_detach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/load", |r| {
RequestSpan(tenant_load_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/ignore", |r| {
RequestSpan(tenant_ignore_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_detail_handler).handle(r)
})
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/get_lsn_by_timestamp",
get_lsn_by_timestamp_handler,
)
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc",
timeline_gc_handler,
|r| RequestSpan(get_lsn_by_timestamp_handler).handle(r),
)
.put("/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc", |r| {
RequestSpan(timeline_gc_handler).handle(r)
})
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/compact",
testing_api!("run timeline compaction", timeline_compact_handler),
@@ -1135,28 +1160,26 @@ pub fn make_router(
)
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
timeline_download_remote_layers_handler_post,
|r| RequestSpan(timeline_download_remote_layers_handler_post).handle(r),
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
timeline_download_remote_layers_handler_get,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_delete_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer",
layer_map_info_handler,
|r| RequestSpan(timeline_download_remote_layers_handler_get).handle(r),
)
.delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_delete_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id/layer", |r| {
RequestSpan(layer_map_info_handler).handle(r)
})
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
layer_download_handler,
|r| RequestSpan(layer_download_handler).handle(r),
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
evict_timeline_layer_handler,
|r| RequestSpan(evict_timeline_layer_handler).handle(r),
)
.get("/v1/panic", always_panic_handler)
.get("/v1/panic", |r| RequestSpan(always_panic_handler).handle(r))
.any(handler_404))
}

View File

@@ -123,6 +123,22 @@ static REMOTE_PHYSICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub static REMOTE_ONDEMAND_DOWNLOADED_LAYERS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_layers_total",
"Total on-demand downloaded layers"
)
.unwrap()
});
pub static REMOTE_ONDEMAND_DOWNLOADED_BYTES: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_bytes_total",
"Total bytes of layers on-demand downloaded",
)
.unwrap()
});
static CURRENT_LOGICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_current_logical_size",

View File

@@ -20,7 +20,8 @@ use pageserver_api::models::{
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
PagestreamNblocksRequest, PagestreamNblocksResponse,
};
use pq_proto::ConnectionError;
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
use pq_proto::framed::ConnectionError;
use pq_proto::FeStartupPacket;
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
use std::io;
@@ -35,8 +36,6 @@ use utils::{
auth::{Claims, JwtAuth, Scope},
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
postgres_backend_async::{self, is_expected_io_error, PostgresBackend, QueryError},
simple_rcu::RcuReadGuard,
};
@@ -64,11 +63,11 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
_ = task_mgr::shutdown_watcher() => {
// We were requested to shut down.
let msg = format!("pageserver is shutting down");
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None));
let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None));
Err(QueryError::Other(anyhow::anyhow!(msg)))
}
msg = pgb.read_message() => { msg }
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
};
match msg {
@@ -79,14 +78,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
FeMessage::Sync => continue,
FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY";
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break;
}
m => {
let msg = format!("unexpected message {m:?}");
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?;
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
Err(io::Error::new(io::ErrorKind::Other, msg))?;
break;
}
@@ -96,16 +97,17 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
}
Ok(None) => {
let msg = "client closed connection during COPY";
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
pgb.flush().await?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
}
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(io_error)?;
}
Err(other) => {
Err(io::Error::new(io::ErrorKind::Other, other))?;
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
}
};
}
@@ -212,7 +214,7 @@ async fn page_service_conn_main(
// we've been requested to shut down
Ok(())
}
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
if is_expected_io_error(&io_error) {
info!("Postgres client disconnected ({io_error})");
Ok(())
@@ -311,7 +313,7 @@ impl PageServerHandler {
let timeline = tenant.get_timeline(timeline_id, true)?;
// switch client to COPYBOTH
pgb.write_message(&BeMessage::CopyBothResponse)?;
pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
pgb.flush().await?;
let metrics = PageRequestMetrics::new(&tenant_id, &timeline_id);
@@ -380,7 +382,7 @@ impl PageServerHandler {
})
});
pgb.write_message(&BeMessage::CopyData(&response.serialize()))?;
pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
pgb.flush().await?;
}
Ok(())
@@ -416,7 +418,7 @@ impl PageServerHandler {
// Import basebackup provided via CopyData
info!("importing basebackup");
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb));
@@ -468,7 +470,7 @@ impl PageServerHandler {
// Import wal provided via CopyData
info!("importing wal");
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb));
let mut reader = tokio_util::io::StreamReader::new(&mut copyin_stream);
@@ -678,7 +680,7 @@ impl PageServerHandler {
}
// switch client to COPYOUT
pgb.write_message(&BeMessage::CopyOutResponse)?;
pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
pgb.flush().await?;
// Send a tarball of the latest layer on the timeline
@@ -695,7 +697,7 @@ impl PageServerHandler {
.await?;
}
pgb.write_message(&BeMessage::CopyDone)?;
pgb.write_message_noflush(&BeMessage::CopyDone)?;
pgb.flush().await?;
info!("basebackup complete");
@@ -721,7 +723,7 @@ impl PageServerHandler {
}
#[async_trait::async_trait]
impl postgres_backend_async::Handler for PageServerHandler {
impl postgres_backend::Handler for PageServerHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
@@ -812,7 +814,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, None, false, ctx)
.await?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// return pair of prev_lsn and last_lsn
else if query_string.starts_with("get_last_record_rlsn ") {
@@ -835,15 +837,15 @@ impl postgres_backend_async::Handler for PageServerHandler {
let end_of_timeline = timeline.get_last_record_rlsn();
pgb.write_message(&BeMessage::RowDescription(&[
pgb.write_message_noflush(&BeMessage::RowDescription(&[
RowDescriptor::text_col(b"prev_lsn"),
RowDescriptor::text_col(b"last_lsn"),
]))?
.write_message(&BeMessage::DataRow(&[
.write_message_noflush(&BeMessage::DataRow(&[
Some(end_of_timeline.prev.to_string().as_bytes()),
Some(end_of_timeline.last.to_string().as_bytes()),
]))?
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// same as basebackup, but result includes relational data as well
else if query_string.starts_with("fullbackup ") {
@@ -884,7 +886,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, prev_lsn, true, ctx)
.await?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("import basebackup ") {
// Import the `base` section (everything but the wal) of a basebackup.
// Assumes the tenant already exists on this pageserver.
@@ -929,10 +931,10 @@ impl postgres_backend_async::Handler for PageServerHandler {
)
.await
{
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}");
pgb.write_message(&BeMessage::ErrorResponse(
pgb.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -965,10 +967,10 @@ impl postgres_backend_async::Handler for PageServerHandler {
.handle_import_wal(pgb, tenant_id, timeline_id, start_lsn, end_lsn, ctx)
.await
{
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}");
pgb.write_message(&BeMessage::ErrorResponse(
pgb.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -977,7 +979,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
} else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("show ") {
// show <tenant_id>
let (_, params_raw) = query_string.split_at("show ".len());
@@ -993,7 +995,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
self.check_permission(Some(tenant_id))?;
let tenant = get_active_tenant_with_timeout(tenant_id, &ctx).await?;
pgb.write_message(&BeMessage::RowDescription(&[
pgb.write_message_noflush(&BeMessage::RowDescription(&[
RowDescriptor::int8_col(b"checkpoint_distance"),
RowDescriptor::int8_col(b"checkpoint_timeout"),
RowDescriptor::int8_col(b"compaction_target_size"),
@@ -1004,7 +1006,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
RowDescriptor::int8_col(b"image_creation_threshold"),
RowDescriptor::int8_col(b"pitr_interval"),
]))?
.write_message(&BeMessage::DataRow(&[
.write_message_noflush(&BeMessage::DataRow(&[
Some(tenant.get_checkpoint_distance().to_string().as_bytes()),
Some(
tenant
@@ -1027,7 +1029,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
Some(tenant.get_image_creation_threshold().to_string().as_bytes()),
Some(tenant.get_pitr_interval().as_secs().to_string().as_bytes()),
]))?
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else {
return Err(QueryError::Other(anyhow::anyhow!(
"unknown command {query_string}"
@@ -1055,7 +1057,7 @@ impl From<GetActiveTenantError> for QueryError {
fn from(e: GetActiveTenantError) -> Self {
match e {
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
),
GetActiveTenantError::Other(e) => QueryError::Other(e),
}

View File

@@ -103,6 +103,7 @@ pub struct TenantConfOpt {
pub checkpoint_distance: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "humantime_serde")]
#[serde(default)]
pub checkpoint_timeout: Option<Duration>,

View File

@@ -218,9 +218,10 @@ use tracing::{debug, info, warn};
use tracing::{info_span, Instrument};
use utils::lsn::Lsn;
use crate::metrics::RemoteOpFileKind;
use crate::metrics::RemoteOpKind;
use crate::metrics::{MeasureRemoteOp, RemoteTimelineClientMetrics};
use crate::metrics::{
MeasureRemoteOp, RemoteOpFileKind, RemoteOpKind, RemoteTimelineClientMetrics,
REMOTE_ONDEMAND_DOWNLOADED_BYTES, REMOTE_ONDEMAND_DOWNLOADED_LAYERS,
};
use crate::tenant::remote_timeline_client::index::LayerFileMetadata;
use crate::{
config::PageServerConf,
@@ -446,6 +447,10 @@ impl RemoteTimelineClient {
);
}
}
REMOTE_ONDEMAND_DOWNLOADED_LAYERS.inc();
REMOTE_ONDEMAND_DOWNLOADED_BYTES.inc_by(downloaded_size);
Ok(downloaded_size)
}

View File

@@ -6,11 +6,13 @@
use std::collections::HashSet;
use std::future::Future;
use std::path::Path;
use std::time::Duration;
use anyhow::{anyhow, Context};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tracing::{error, info, warn};
use tracing::{info, warn};
use crate::config::PageServerConf;
use crate::tenant::storage_layer::LayerFileName;
@@ -26,6 +28,8 @@ async fn fsync_path(path: impl AsRef<std::path::Path>) -> Result<(), std::io::Er
fs::File::open(path).await?.sync_all().await
}
static MAX_DOWNLOAD_DURATION: Duration = Duration::from_secs(120);
///
/// If 'metadata' is given, we will validate that the downloaded file's size matches that
/// in the metadata. (In the future, we might do more cross-checks, like CRC validation)
@@ -64,22 +68,28 @@ pub async fn download_layer_file<'a>(
// TODO: this doesn't use the cached fd for some reason?
let mut destination_file = fs::File::create(&temp_file_path).await.with_context(|| {
format!(
"Failed to create a destination file for layer '{}'",
"create a destination file for layer '{}'",
temp_file_path.display()
)
})
.map_err(DownloadError::Other)?;
let mut download = storage.download(&remote_path).await.with_context(|| {
format!(
"Failed to open a download stream for layer with remote storage path '{remote_path:?}'"
"open a download stream for layer with remote storage path '{remote_path:?}'"
)
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::io::copy(&mut download.download_stream, &mut destination_file).await.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::time::timeout(MAX_DOWNLOAD_DURATION, tokio::io::copy(&mut download.download_stream, &mut destination_file))
.await
.map_err(|e| DownloadError::Other(anyhow::anyhow!("Timed out {:?}", e)))?
.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
Ok((destination_file, bytes_amount))
},
&format!("download {remote_path:?}"),
).await?;
@@ -300,7 +310,7 @@ where
}
Err(DownloadError::Other(ref err)) => {
// Operation failed FAILED_DOWNLOAD_RETRIES times. Time to give up.
error!("{description} still failed after {attempts} retries, giving up: {err:?}");
warn!("{description} still failed after {attempts} retries, giving up: {err:?}");
return result;
}
}

View File

@@ -364,7 +364,7 @@ pub trait PersistentLayer: Layer {
}
/// Permanently remove this layer from disk.
fn delete(&self) -> Result<()>;
fn delete_resident_layer_file(&self) -> Result<()>;
fn downcast_remote_layer(self: Arc<Self>) -> Option<std::sync::Arc<RemoteLayer>> {
None

View File

@@ -438,7 +438,7 @@ impl PersistentLayer for DeltaLayer {
))
}
fn delete(&self) -> Result<()> {
fn delete_resident_layer_file(&self) -> Result<()> {
// delete underlying file
fs::remove_file(self.path())?;
Ok(())

View File

@@ -252,7 +252,7 @@ impl PersistentLayer for ImageLayer {
unimplemented!();
}
fn delete(&self) -> Result<()> {
fn delete_resident_layer_file(&self) -> Result<()> {
// delete underlying file
fs::remove_file(self.path())?;
Ok(())

View File

@@ -155,8 +155,8 @@ impl PersistentLayer for RemoteLayer {
bail!("cannot iterate a remote layer");
}
fn delete(&self) -> Result<()> {
Ok(())
fn delete_resident_layer_file(&self) -> Result<()> {
bail!("remote layer has no layer file");
}
fn downcast_remote_layer<'a>(self: Arc<Self>) -> Option<std::sync::Arc<RemoteLayer>> {

View File

@@ -662,8 +662,8 @@ impl Timeline {
// update the index file on next flush iteration too. But it
// could take a while until that happens.
//
// Additionally, only do this on the terminal round before sleeping.
if last_round {
// Additionally, only do this once before we return from this function.
if last_round || res.is_ok() {
if let Some(remote_client) = &self.remote_client {
remote_client.schedule_index_upload_for_file_changes()?;
}
@@ -1047,11 +1047,12 @@ impl Timeline {
return Ok(false);
}
let layer_metadata = LayerFileMetadata::new(
local_layer
.file_size()
.expect("Local layer should have a file size"),
);
let layer_file_size = local_layer
.file_size()
.expect("Local layer should have a file size");
let layer_metadata = LayerFileMetadata::new(layer_file_size);
let new_remote_layer = Arc::new(match local_layer.filename() {
LayerFileName::Image(image_name) => RemoteLayer::new_img(
self.tenant_id,
@@ -1075,15 +1076,22 @@ impl Timeline {
let replaced = match batch_updates.replace_historic(local_layer, new_remote_layer)? {
Replacement::Replaced { .. } => {
let layer_size = local_layer.file_size();
if let Err(e) = local_layer.delete() {
if let Err(e) = local_layer.delete_resident_layer_file() {
error!("failed to remove layer file on evict after replacement: {e:#?}");
}
if let Some(layer_size) = layer_size {
self.metrics.resident_physical_size_gauge.sub(layer_size);
}
// Always decrement the physical size gauge, even if we failed to delete the file.
// Rationale: we already replaced the layer with a remote layer in the layer map,
// and any subsequent download_remote_layer will
// 1. overwrite the file on disk and
// 2. add the downloaded size to the resident size gauge.
//
// If there is no re-download, and we restart the pageserver, then load_layer_map
// will treat the file as a local layer again, count it towards resident size,
// and it'll be like the layer removal never happened.
// The bump in resident size is perhaps unexpected but overall a robust behavior.
self.metrics
.resident_physical_size_gauge
.sub(layer_file_size);
true
}
@@ -1942,11 +1950,14 @@ impl Timeline {
layer: Arc<dyn PersistentLayer>,
updates: &mut BatchedUpdates<'_, dyn PersistentLayer>,
) -> anyhow::Result<()> {
let layer_size = layer.file_size();
layer.delete()?;
if let Some(layer_size) = layer_size {
self.metrics.resident_physical_size_gauge.sub(layer_size);
if !layer.is_remote_layer() {
layer.delete_resident_layer_file()?;
let layer_file_size = layer
.file_size()
.expect("Local layer should have a file size");
self.metrics
.resident_physical_size_gauge
.sub(layer_file_size);
}
// TODO Removing from the bottom of the layer map is expensive.
@@ -3808,7 +3819,7 @@ impl Timeline {
remote_layer.ongoing_download.close();
} else {
// Keep semaphore open. We'll drop the permit at the end of the function.
info!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
error!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
}
// Don't treat it as an error if the task that triggered the download

View File

@@ -33,10 +33,11 @@ use crate::{
walingest::WalIngest,
walrecord::DecodedWALRecord,
};
use postgres_backend::is_expected_io_error;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::waldecoder::WalStreamDecoder;
use pq_proto::ReplicationFeedback;
use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error};
use utils::lsn::Lsn;
/// Status of the connection.
#[derive(Debug, Clone, Copy)]
@@ -353,7 +354,7 @@ pub async fn handle_walreceiver_connection(
debug!("neon_status_update {status_update:?}");
let mut data = BytesMut::new();
status_update.serialize(&mut data)?;
status_update.serialize(&mut data);
physical_stream
.as_mut()
.zenith_status_update(data.len() as u64, &data)
@@ -434,8 +435,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result<postgres:
{
return Ok(pg_error);
} else if let Some(db_error) = pg_error.as_db_error() {
if db_error.code() == &SqlState::CONNECTION_FAILURE
&& db_error.message().contains("end streaming")
if db_error.code() == &SqlState::SUCCESSFUL_COMPLETION
&& db_error.message().contains("ending streaming")
{
return Ok(pg_error);
}

View File

@@ -37,7 +37,7 @@ use crate::walrecord::*;
use crate::ZERO_PAGE;
use pageserver_api::reltag::{RelTag, SlruKind};
use postgres_ffi::pg_constants;
use postgres_ffi::relfile_utils::{FSM_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM};
use postgres_ffi::relfile_utils::{FSM_FORKNUM, INIT_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM};
use postgres_ffi::v14::nonrelfile_utils::mx_offset_to_member_segment;
use postgres_ffi::v14::xlog_utils::*;
use postgres_ffi::v14::CheckPoint;
@@ -762,7 +762,7 @@ impl<'a> WalIngest<'a> {
)?;
for xnode in &parsed.xnodes {
for forknum in MAIN_FORKNUM..=VISIBILITYMAP_FORKNUM {
for forknum in MAIN_FORKNUM..=INIT_FORKNUM {
let rel = RelTag {
forknum,
spcnode: xnode.spcnode,

View File

@@ -23,13 +23,11 @@ use bytes::{BufMut, Bytes, BytesMut};
use nix::poll::*;
use serde::Serialize;
use std::collections::VecDeque;
use std::fs::OpenOptions;
use std::io::prelude::*;
use std::io::{Error, ErrorKind};
use std::ops::{Deref, DerefMut};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::prelude::CommandExt;
use std::path::PathBuf;
use std::process::Stdio;
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use std::sync::{Mutex, MutexGuard};
@@ -256,52 +254,53 @@ impl PostgresRedoManager {
pg_version: u32,
) -> Result<Bytes, WalRedoError> {
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
const MAX_RETRY_ATTEMPTS: u32 = 1;
let start_time = Instant::now();
let mut n_attempts = 0u32;
loop {
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, &base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
let len = records.len();
let nbytes = records.iter().fold(0, |acumulator, record| {
acumulator
+ match &record.1 {
NeonWalRecord::Postgres { rec, .. } => rec.len(),
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
}
});
let len = records.len();
let nbytes = records.iter().fold(0, |acumulator, record| {
acumulator
+ match &record.1 {
NeonWalRecord::Postgres { rec, .. } => rec.len(),
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
}
});
WAL_REDO_TIME.observe(duration.as_secs_f64());
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
WAL_REDO_TIME.observe(duration.as_secs_f64());
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
debug!(
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
len,
nbytes,
duration.as_micros(),
lsn
);
debug!(
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
len,
nbytes,
duration.as_micros(),
lsn
);
// If something went wrong, don't try to reuse the process. Kill it, and
// next request will launch a new one.
if result.is_err() {
error!(
// If something went wrong, don't try to reuse the process. Kill it, and
// next request will launch a new one.
if result.is_err() {
error!(
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {}",
records.len(),
records.first().map(|p| p.0).unwrap_or(Lsn(0)),
@@ -310,24 +309,28 @@ impl PostgresRedoManager {
base_img_lsn,
lsn
);
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
}
}
n_attempts += 1;
if n_attempts > MAX_RETRY_ATTEMPTS || result.is_ok() {
return result;
}
}
result
}
///
@@ -634,26 +637,26 @@ impl PostgresRedoManager {
input: &mut MutexGuard<Option<ProcessInput>>,
pg_version: u32,
) -> Result<(), Error> {
// FIXME: We need a dummy Postgres cluster to run the process in. Currently, we
// just create one with constant name. That fails if you try to launch more than
// one WAL redo manager concurrently.
let datadir = path_with_suffix_extension(
// Previous versions of wal-redo required data directory and that directories
// occupied some space on disk. Remove it if we face it.
//
// This code could be dropped after one release cycle.
let legacy_datadir = path_with_suffix_extension(
self.conf
.tenant_path(&self.tenant_id)
.join("wal-redo-datadir"),
TEMP_FILE_SUFFIX,
);
// Create empty data directory for wal-redo postgres, deleting old one first.
if datadir.exists() {
info!("old temporary datadir {datadir:?} exists, removing");
fs::remove_dir_all(&datadir).map_err(|e| {
if legacy_datadir.exists() {
info!("legacy wal-redo datadir {legacy_datadir:?} exists, removing");
fs::remove_dir_all(&legacy_datadir).map_err(|e| {
Error::new(
e.kind(),
format!("Old temporary dir {datadir:?} removal failure: {e}"),
format!("legacy wal-redo datadir {legacy_datadir:?} removal failure: {e}"),
)
})?;
}
let pg_bin_dir_path = self
.conf
.pg_bin_dir(pg_version)
@@ -663,35 +666,6 @@ impl PostgresRedoManager {
.pg_lib_dir(pg_version)
.map_err(|e| Error::new(ErrorKind::Other, format!("incorrect pg_lib_dir path: {e}")))?;
info!("running initdb in {}", datadir.display());
let initdb = Command::new(pg_bin_dir_path.join("initdb"))
.args(["-D", &datadir.to_string_lossy()])
.arg("-N")
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path) // macOS
.close_fds()
.output()
.map_err(|e| Error::new(e.kind(), format!("failed to execute initdb: {e}")))?;
if !initdb.status.success() {
return Err(Error::new(
ErrorKind::Other,
format!(
"initdb failed\nstdout: {}\nstderr:\n{}",
String::from_utf8_lossy(&initdb.stdout),
String::from_utf8_lossy(&initdb.stderr)
),
));
} else {
// Limit shared cache for wal-redo-postgres
let mut config = OpenOptions::new()
.append(true)
.open(PathBuf::from(&datadir).join("postgresql.conf"))?;
config.write_all(b"shared_buffers=128kB\n")?;
config.write_all(b"fsync=off\n")?;
}
// Start postgres itself
let child = Command::new(pg_bin_dir_path.join("postgres"))
.arg("--wal-redo")
@@ -701,7 +675,6 @@ impl PostgresRedoManager {
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path)
.env("PGDATA", &datadir)
// The redo process is not trusted, and runs in seccomp mode that
// doesn't allow it to open any files. We have to also make sure it
// doesn't inherit any file descriptors from the pageserver, that
@@ -771,7 +744,7 @@ impl PostgresRedoManager {
&self,
mut input: MutexGuard<Option<ProcessInput>>,
tag: BufferTag,
base_img: Option<Bytes>,
base_img: &Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> Result<Bytes, std::io::Error> {
@@ -787,7 +760,7 @@ impl PostgresRedoManager {
let mut writebuf: Vec<u8> = Vec::with_capacity((BLCKSZ as usize) * 3);
build_begin_redo_for_block_msg(tag, &mut writebuf);
if let Some(img) = base_img {
build_push_page_msg(tag, &img, &mut writebuf);
build_push_page_msg(tag, img, &mut writebuf);
}
for (lsn, rec) in records.iter() {
if let NeonWalRecord::Postgres {

View File

@@ -32,6 +32,9 @@
#define PageStoreTrace DEBUG5
#define MAX_RECONNECT_ATTEMPTS 5
#define RECONNECT_INTERVAL_USEC 1000000
bool connected = false;
PGconn *pageserver_conn = NULL;
@@ -52,8 +55,8 @@ int readahead_buffer_size = 128;
static void pageserver_flush(void);
static void
pageserver_connect()
static bool
pageserver_connect(int elevel)
{
char *query;
int ret;
@@ -69,10 +72,11 @@ pageserver_connect()
PQfinish(pageserver_conn);
pageserver_conn = NULL;
ereport(ERROR,
ereport(elevel,
(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
errmsg(NEON_TAG "could not establish connection to pageserver"),
errdetail_internal("%s", msg)));
return false;
}
query = psprintf("pagestream %s %s", neon_tenant, neon_timeline);
@@ -81,7 +85,8 @@ pageserver_connect()
{
PQfinish(pageserver_conn);
pageserver_conn = NULL;
neon_log(ERROR, "could not send pagestream command to pageserver");
neon_log(elevel, "could not send pagestream command to pageserver");
return false;
}
pageserver_conn_wes = CreateWaitEventSet(TopMemoryContext, 3);
@@ -113,8 +118,9 @@ pageserver_connect()
FreeWaitEventSet(pageserver_conn_wes);
pageserver_conn_wes = NULL;
neon_log(ERROR, "could not complete handshake with pageserver: %s",
neon_log(elevel, "could not complete handshake with pageserver: %s",
msg);
return false;
}
}
}
@@ -122,6 +128,7 @@ pageserver_connect()
neon_log(LOG, "libpagestore: connected to '%s'", page_server_connstring_raw);
connected = true;
return true;
}
/*
@@ -149,8 +156,11 @@ retry:
if (event.events & WL_SOCKET_READABLE)
{
if (!PQconsumeInput(pageserver_conn))
neon_log(ERROR, "could not get response from pageserver: %s",
{
neon_log(LOG, "could not get response from pageserver: %s",
PQerrorMessage(pageserver_conn));
return -1;
}
}
goto retry;
@@ -190,31 +200,62 @@ static void
pageserver_send(NeonRequest * request)
{
StringInfoData req_buff;
int n_reconnect_attempts = 0;
/* If the connection was lost for some reason, reconnect */
if (connected && PQstatus(pageserver_conn) == CONNECTION_BAD)
pageserver_disconnect();
if (!connected)
pageserver_connect();
req_buff = nm_pack_request(request);
/*
* Send request.
*
* In principle, this could block if the output buffer is full, and we
* should use async mode and check for interrupts while waiting. In
* practice, our requests are small enough to always fit in the output and
* TCP buffer.
* If pageserver is stopped, the connections from compute node are broken.
* The compute node doesn't notice that immediately, but it will cause the next request to fail, usually on the next query.
* That causes user-visible errors if pageserver is restarted, or the tenant is moved from one pageserver to another.
* See https://github.com/neondatabase/neon/issues/1138
* So try to reestablish connection in case of failure.
*/
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
while (true)
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
if (!connected)
{
if (!pageserver_connect(n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS ? LOG : ERROR))
{
n_reconnect_attempts += 1;
pg_usleep(RECONNECT_INTERVAL_USEC);
continue;
}
}
pageserver_disconnect();
neon_log(ERROR, "failed to send page request: %s", msg);
/*
* Send request.
*
* In principle, this could block if the output buffer is full, and we
* should use async mode and check for interrupts while waiting. In
* practice, our requests are small enough to always fit in the output and
* TCP buffer.
*/
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
if (n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS)
{
neon_log(LOG, "failed to send page request (try to reconnect): %s", msg);
if (n_reconnect_attempts != 0) /* do not sleep before first reconnect attempt, assuming that pageserver is already restarted */
pg_usleep(RECONNECT_INTERVAL_USEC);
n_reconnect_attempts += 1;
continue;
}
else
{
pageserver_disconnect();
neon_log(ERROR, "failed to send page request: %s", msg);
}
}
break;
}
pfree(req_buff.data);
n_unflushed_requests++;

15
pgxn/neon_utils/Makefile Normal file
View File

@@ -0,0 +1,15 @@
# pgxs/neon_utils/Makefile
MODULE_big = neon_utils
OBJS = \
$(WIN32RES) \
neon_utils.o
EXTENSION = neon_utils
DATA = neon_utils--1.0.sql
PGFILEDESC = "neon_utils - small useful functions"
PG_CONFIG = pg_config
PGXS := $(shell $(PG_CONFIG) --pgxs)
include $(PGXS)

View File

@@ -0,0 +1,6 @@
CREATE FUNCTION num_cpus()
RETURNS int
AS 'MODULE_PATHNAME', 'num_cpus'
LANGUAGE C STRICT
PARALLEL UNSAFE
VOLATILE;

View File

@@ -0,0 +1,35 @@
/*-------------------------------------------------------------------------
*
* neon_utils.c
* neon_utils - small useful functions
*
* IDENTIFICATION
* contrib/neon_utils/neon_utils.c
*
*-------------------------------------------------------------------------
*/
#ifdef _WIN32
#include <windows.h>
#else
#include <unistd.h>
#endif
#include "postgres.h"
#include "fmgr.h"
PG_MODULE_MAGIC;
PG_FUNCTION_INFO_V1(num_cpus);
Datum
num_cpus(PG_FUNCTION_ARGS)
{
#ifdef _WIN32
SYSTEM_INFO sysinfo;
GetSystemInfo(&sysinfo);
uint32 num_cpus = (uint32) sysinfo.dwNumberOfProcessors;
#else
uint32 num_cpus = (uint32) sysconf(_SC_NPROCESSORS_ONLN);
#endif
PG_RETURN_UINT32(num_cpus);
}

View File

@@ -0,0 +1,6 @@
# neon_utils extension
comment = 'neon_utils - small useful functions'
default_version = '1.0'
module_pathname = '$libdir/neon_utils'
relocatable = true
trusted = true

View File

@@ -65,6 +65,14 @@
#include "rusagestub.h"
#endif
#include "access/clog.h"
#include "access/commit_ts.h"
#include "access/heapam.h"
#include "access/multixact.h"
#include "access/nbtree.h"
#include "access/subtrans.h"
#include "access/syncscan.h"
#include "access/twophase.h"
#include "access/xlog.h"
#include "access/xlog_internal.h"
#if PG_VERSION_NUM >= 150000
@@ -72,18 +80,36 @@
#endif
#include "access/xlogutils.h"
#include "catalog/pg_class.h"
#include "libpq/libpq.h"
#include "commands/async.h"
#include "libpq/pqformat.h"
#include "miscadmin.h"
#include "pgstat.h"
#include "postmaster/autovacuum.h"
#include "postmaster/bgworker_internals.h"
#include "postmaster/bgwriter.h"
#include "postmaster/postmaster.h"
#include "replication/logicallauncher.h"
#include "replication/origin.h"
#include "replication/slot.h"
#include "replication/walreceiver.h"
#include "replication/walsender.h"
#include "storage/buf_internals.h"
#include "storage/bufmgr.h"
#include "storage/dsm.h"
#include "storage/ipc.h"
#include "storage/pg_shmem.h"
#include "storage/pmsignal.h"
#include "storage/predicate.h"
#include "storage/proc.h"
#include "storage/procarray.h"
#include "storage/procsignal.h"
#include "storage/sinvaladt.h"
#include "storage/smgr.h"
#include "storage/spin.h"
#include "tcop/tcopprot.h"
#include "utils/memutils.h"
#include "utils/ps_status.h"
#include "utils/snapmgr.h"
#include "inmem_smgr.h"
@@ -101,6 +127,7 @@ static void apply_error_callback(void *arg);
static bool redo_block_filter(XLogReaderState *record, uint8 block_id);
static void GetPage(StringInfo input_message);
static ssize_t buffered_read(void *buf, size_t count);
static void CreateFakeSharedMemoryAndSemaphores();
static BufferTag target_redo_tag;
@@ -141,7 +168,7 @@ enter_seccomp_mode(void)
PG_SCMP_ALLOW(shmctl),
PG_SCMP_ALLOW(shmdt),
PG_SCMP_ALLOW(unlink), // shm_unlink
*/
*/
};
#ifdef MALLOC_NO_MMAP
@@ -177,6 +204,7 @@ WalRedoMain(int argc, char *argv[])
* buffers. So let's keep it small (default value is 1024)
*/
num_temp_buffers = 4;
NBuffers = 4;
/*
* install the simple in-memory smgr
@@ -184,49 +212,33 @@ WalRedoMain(int argc, char *argv[])
smgr_hook = smgr_inmem;
smgr_init_hook = smgr_init_inmem;
/*
* Validate we have been given a reasonable-looking DataDir and change into it.
*/
checkDataDir();
ChangeToDataDir();
/*
* Create lockfile for data directory.
*/
CreateDataDirLockFile(false);
/* read control file (error checking and contains config ) */
LocalProcessControlFile(false);
/*
* process any libraries that should be preloaded at postmaster start
*/
process_shared_preload_libraries();
/* Initialize MaxBackends (if under postmaster, was done already) */
MaxConnections = 1;
max_worker_processes = 0;
max_parallel_workers = 0;
max_wal_senders = 0;
InitializeMaxBackends();
#if PG_VERSION_NUM >= 150000
/*
* Give preloaded libraries a chance to request additional shared memory.
*/
process_shmem_requests();
/* Disable lastWrittenLsnCache */
lastWrittenLsnCacheSize = 0;
/*
* Now that loadable modules have had their chance to request additional
* shared memory, determine the value of any runtime-computed GUCs that
* depend on the amount of shared memory required.
*/
#if PG_VERSION_NUM >= 150000
process_shmem_requests();
InitializeShmemGUCs();
/*
* Now that modules have been loaded, we can process any custom resource
* managers specified in the wal_consistency_checking GUC.
* This will try to access data directory which we do not set.
* Seems to be pretty safe to disable.
*/
InitializeWalConsistencyChecking();
/* InitializeWalConsistencyChecking(); */
#endif
CreateSharedMemoryAndSemaphores();
/*
* We have our own version of CreateSharedMemoryAndSemaphores() that
* sets up local memory instead of shared one.
*/
CreateFakeSharedMemoryAndSemaphores();
/*
* Remember stand-alone backend startup time,roughly at the same point
@@ -354,6 +366,172 @@ WalRedoMain(int argc, char *argv[])
}
/*
* Initialize dummy shmem.
*
* This code follows CreateSharedMemoryAndSemaphores() but manually sets up
* the shmem header and skips few initialization steps that are not needed for
* WAL redo.
*
* I've also tried removing most of initialization functions that request some
* memory (like ApplyLauncherShmemInit and friends) but in reality it haven't had
* any sizeable effect on RSS, so probably such clean up not worth the risk of having
* half-initialized postgres.
*/
static void
CreateFakeSharedMemoryAndSemaphores()
{
PGShmemHeader *shim = NULL;
PGShmemHeader *hdr;
Size size;
int numSemas;
char cwd[MAXPGPATH];
#if PG_VERSION_NUM >= 150000
size = CalculateShmemSize(&numSemas);
#else
/*
* Postgres v14 doesn't have a separate CalculateShmemSize(). Use result of the
* corresponging calculation in CreateSharedMemoryAndSemaphores()
*/
size = 1409024;
numSemas = 10;
#endif
/* Dummy implementation of PGSharedMemoryCreate() */
{
hdr = (PGShmemHeader *) malloc(size);
if (!hdr)
ereport(FATAL,
(errcode(ERRCODE_OUT_OF_MEMORY),
errmsg("[neon-wal-redo] can not allocate (pseudo-) shared memory")));
hdr->creatorPID = getpid();
hdr->magic = PGShmemMagic;
hdr->dsm_control = 0;
hdr->device = 42; /* not relevant for non-shared memory */
hdr->inode = 43; /* not relevant for non-shared memory */
hdr->totalsize = size;
hdr->freeoffset = MAXALIGN(sizeof(PGShmemHeader));
shim = hdr;
UsedShmemSegAddr = hdr;
UsedShmemSegID = (unsigned long) 42; /* not relevant for non-shared memory */
}
InitShmemAccess(hdr);
/*
* Reserve semaphores uses dir name as a source of entropy. Set it to cwd(). Rest
* of the code does not need DataDir access so nullify DataDir after
* PGReserveSemaphores() to error out if something will try to access it.
*/
if (!getcwd(cwd, MAXPGPATH))
ereport(FATAL,
(errcode(ERRCODE_INTERNAL_ERROR),
errmsg("[neon-wal-redo] can not read current directory name")));
DataDir = cwd;
PGReserveSemaphores(numSemas);
DataDir = NULL;
/*
* The rest of function follows CreateSharedMemoryAndSemaphores() closely,
* skipped parts are marked with comments.
*/
InitShmemAllocation();
/*
* Now initialize LWLocks, which do shared memory allocation and are
* needed for InitShmemIndex.
*/
CreateLWLocks();
/*
* Set up shmem.c index hashtable
*/
InitShmemIndex();
dsm_shmem_init();
/*
* Set up xlog, clog, and buffers
*/
XLOGShmemInit();
CLOGShmemInit();
CommitTsShmemInit();
SUBTRANSShmemInit();
MultiXactShmemInit();
InitBufferPool();
/*
* Set up lock manager
*/
InitLocks();
/*
* Set up predicate lock manager
*/
InitPredicateLocks();
/*
* Set up process table
*/
if (!IsUnderPostmaster)
InitProcGlobal();
CreateSharedProcArray();
CreateSharedBackendStatus();
TwoPhaseShmemInit();
BackgroundWorkerShmemInit();
/*
* Set up shared-inval messaging
*/
CreateSharedInvalidationState();
/*
* Set up interprocess signaling mechanisms
*/
PMSignalShmemInit();
ProcSignalShmemInit();
CheckpointerShmemInit();
AutoVacuumShmemInit();
ReplicationSlotsShmemInit();
ReplicationOriginShmemInit();
WalSndShmemInit();
WalRcvShmemInit();
PgArchShmemInit();
ApplyLauncherShmemInit();
/*
* Set up other modules that need some shared memory space
*/
SnapMgrInit();
BTreeShmemInit();
SyncScanShmemInit();
/* Skip due to the 'pg_notify' directory check */
/* AsyncShmemInit(); */
#ifdef EXEC_BACKEND
/*
* Alloc the win32 shared backend array
*/
if (!IsUnderPostmaster)
ShmemBackendArrayAllocation();
#endif
/* Initialize dynamic shared memory facilities. */
if (!IsUnderPostmaster)
dsm_postmaster_startup(shim);
/*
* Now give loadable modules a chance to set up their shmem allocations
*/
if (shmem_startup_hook)
shmem_startup_hook();
}
/* Version compatility wrapper for ReadBufferWithoutRelcache */
static inline Buffer
NeonRedoReadBuffer(RelFileNode rnode,

View File

@@ -31,6 +31,7 @@ once_cell.workspace = true
opentelemetry.workspace = true
parking_lot.workspace = true
pin-project-lite.workspace = true
postgres_backend.workspace = true
pq_proto.workspace = true
prometheus.workspace = true
rand.workspace = true

View File

@@ -25,12 +25,11 @@ impl CancelMap {
cancel_closure.try_cancel_query().await
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
/// Create a new session, with a new client-facing random cancellation key.
///
/// Use `enable_query_cancellation` to register the Postgres backend's cancellation
/// key with it.
pub fn new_session<'a>(&'a self) -> anyhow::Result<Session<'a>> {
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
@@ -44,17 +43,9 @@ impl CancelMap {
.write()
.try_insert(key, None)
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.write().remove(&key);
info!("dropped query cancellation key {key}");
}
info!("registered new query cancellation key {key}");
let session = Session::new(key, self);
f(session).await
Ok(Session::new(key, self))
}
#[cfg(test)]
@@ -111,7 +102,7 @@ impl<'a> Session<'a> {
impl Session<'_> {
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map
.0
@@ -122,6 +113,14 @@ impl Session<'_> {
}
}
impl<'a> Drop for Session<'a> {
fn drop(&mut self) {
let key = &self.key;
self.cancel_map.0.write().remove(key);
info!("dropped query cancellation key {key}");
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -132,14 +131,14 @@ mod tests {
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel();
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
let session = CANCEL_MAP.new_session()?;
let task = tokio::spawn(async move {
assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever
Ok(())
}));
});
// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;

View File

@@ -4,13 +4,11 @@ use crate::{
};
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{self, AuthType, PostgresBackend, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::{net::TcpStream, thread};
use std::future;
use tokio::net::{TcpListener, TcpStream};
use tracing::{error, info, info_span};
use utils::{
postgres_backend::{self, AuthType, PostgresBackend},
postgres_backend_async::QueryError,
};
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
@@ -33,7 +31,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
/// Console management API listener task.
/// It spawns console response handlers needed for the link auth.
pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> {
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
scopeguard::defer! {
info!("mgmt has shut down");
}
@@ -42,18 +40,12 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()>
let (socket, peer_addr) = listener.accept().await?;
info!("accepted connection from {peer_addr}");
let socket = socket.into_std()?;
socket
.set_nodelay(true)
.context("failed to set client socket option")?;
socket
.set_nonblocking(false)
.context("failed to set client socket option")?;
// TODO: replace with async tasks.
thread::spawn(move || {
let tid = std::thread::current().id();
let span = info_span!("mgmt", thread = format_args!("{tid:?}"));
tokio::task::spawn(async move {
let span = info_span!("mgmt", peer = %peer_addr);
let _enter = span.enter();
info!("started a new console management API thread");
@@ -61,16 +53,16 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()>
info!("console management API thread is about to finish");
}
if let Err(e) = handle_connection(socket) {
if let Err(e) = handle_connection(socket).await {
error!("thread failed with an error: {e}");
}
});
}
}
fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
pgbackend.run(&mut MgmtHandler)
async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
}
/// A message received by `mgmt` when a compute node is ready.
@@ -78,16 +70,21 @@ pub type ComputeReady = Result<DatabaseInfo, String>;
// TODO: replace with an http-based protocol.
struct MgmtHandler;
#[async_trait::async_trait]
impl postgres_backend::Handler for MgmtHandler {
fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
try_process_query(pgb, query).map_err(|e| {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query: &str,
) -> Result<(), QueryError> {
try_process_query(pgb, query).await.map_err(|e| {
error!("failed to process response: {e:?}");
e
})
}
}
fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
let span = info_span!("event", session_id = resp.session_id);
@@ -98,11 +95,11 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), Query
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
Err(e) => {
error!("failed to deliver response to per-client task");
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?;
pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
}
}

View File

@@ -21,6 +21,7 @@ const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
#[derive(Eq, Hash, PartialEq, Serialize, Debug)]
pub struct Ids {
pub endpoint_id: String,
pub branch_id: String,
}
pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<()> {
@@ -74,12 +75,23 @@ fn gather_proxy_io_bytes_per_client() -> Vec<(Ids, (u64, DateTime<Utc>))> {
.find(|l| l.get_name() == "endpoint_id")
.unwrap()
.get_value();
let branch_id = ms
.get_label()
.iter()
.find(|l| l.get_name() == "branch_id")
.unwrap()
.get_value();
let value = ms.get_counter().get_value() as u64;
debug!("endpoint_id:val - {}: {}", endpoint_id, value);
debug!(
"branch_id {} endpoint_id {} val: {}",
branch_id, endpoint_id, value
);
current_metrics.push((
Ids {
endpoint_id: endpoint_id.to_string(),
branch_id: "".to_string(),
},
(value, Utc::now()),
));
@@ -131,6 +143,7 @@ async fn collect_metrics_iteration(
value,
extra: Ids {
endpoint_id: curr_key.endpoint_id.clone(),
branch_id: curr_key.branch_id.clone(),
},
})
})
@@ -172,6 +185,7 @@ async fn collect_metrics_iteration(
cached_metrics
.entry(Ids {
endpoint_id: send_metric.extra.endpoint_id.clone(),
branch_id: send_metric.extra.branch_id.clone(),
})
// update cached value (add delta) and time
.and_modify(|e| {

View File

@@ -133,10 +133,14 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(true).await
}
#[tracing::instrument(fields(session_id), skip_all)]
@@ -172,10 +176,14 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.await
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(false).await
}
/// Establish a (most probably, secure) connection with the client.
@@ -381,6 +389,8 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
}
impl<'a, S> Client<'a, S> {
@@ -390,28 +400,27 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
) -> Self {
Self {
stream,
creds,
params,
session_id,
session,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(
self,
session: cancellation::Session<'_>,
allow_cleartext: bool,
) -> anyhow::Result<()> {
async fn connect_to_db(self, allow_cleartext: bool) -> anyhow::Result<()> {
let Self {
mut stream,
mut creds,
params,
session_id,
session,
} = self;
let extra = console::ConsoleReqExtra {

View File

@@ -1,45 +1,40 @@
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
pin_project! {
/// Stream wrapper which implements libpq's protocol.
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
#[pin]
stream: S,
buffer: BytesMut,
}
/// Stream wrapper which implements libpq's protocol.
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
framed: Framed<S>,
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
Self {
stream,
buffer: Default::default(),
framed: Framed::new(stream),
}
}
/// Extract the underlying stream.
pub fn into_inner(self) -> S {
self.stream
self.framed.into_inner()
}
/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
self.framed.get_ref()
}
}
@@ -50,16 +45,19 @@ fn err_connection() -> io::Error {
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
let msg = FeStartupPacket::read_fut(&mut self.stream)
self.framed
.read_startup_message()
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)?;
.ok_or_else(err_connection)
}
match msg {
FeMessage::StartupPacket(packet) => Ok(packet),
_ => panic!("unreachable state"),
}
async fn read_message(&mut self) -> io::Result<FeMessage> {
self.framed
.read_message()
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
}
pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
@@ -71,19 +69,14 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
)),
}
}
async fn read_message(&mut self) -> io::Result<FeMessage> {
FeMessage::read_fut(&mut self.stream)
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
}
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buffer, message)?;
self.framed
.write_message(message)
.map_err(ProtocolError::into_io_error)?;
Ok(self)
}
@@ -96,9 +89,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Flush the output buffer into the underlying stream.
pub async fn flush(&mut self) -> io::Result<&mut Self> {
self.stream.write_all(&self.buffer).await?;
self.buffer.clear();
self.stream.flush().await?;
self.framed.flush().await?;
Ok(self)
}

View File

@@ -11,12 +11,18 @@
# Not every feature is supported in macOS builds. Avoid running regular linting
# script that checks every feature.
#
# manual-range-contains wants
# !(4..=MAX_STARTUP_PACKET_LENGTH).contains(&len)
# instead of
# len < 4 || len > MAX_STARTUP_PACKET_LENGTH
# , let's disagree.
if [[ "$OSTYPE" == "darwin"* ]]; then
# no extra features to test currently, add more here when needed
cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -D warnings
cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -A clippy::manual-range-contains -D warnings
else
# * `-A unknown_lints` do not warn about unknown lint suppressions
# that people with newer toolchains might use
# * `-D warnings` - fail on any warnings (`cargo` returns non-zero exit status)
cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -D warnings
cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -A clippy::manual-range-contains -D warnings
fi

View File

@@ -10,6 +10,7 @@ anyhow.workspace = true
async-trait.workspace = true
byteorder.workspace = true
bytes.workspace = true
chrono.workspace = true
clap = { workspace = true, features = ["derive"] }
const_format.workspace = true
crc32c.workspace = true
@@ -18,7 +19,6 @@ git-version.workspace = true
hex.workspace = true
humantime.workspace = true
hyper.workspace = true
nix.workspace = true
once_cell.workspace = true
parking_lot.workspace = true
postgres.workspace = true
@@ -35,6 +35,7 @@ toml_edit.workspace = true
tracing.workspace = true
url.workspace = true
metrics.workspace = true
postgres_backend.workspace = true
postgres_ffi.workspace = true
pq_proto.workspace = true
remote_storage.workspace = true

View File

@@ -236,7 +236,7 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
let conf_cloned = conf.clone();
let safekeeper_thread = thread::Builder::new()
.name("safekeeper thread".into())
.name("WAL service thread".into())
.spawn(|| wal_service::thread_main(conf_cloned, pg_listener))
.unwrap();

View File

@@ -0,0 +1,264 @@
//! Utils for dumping full state of the safekeeper.
use std::fs;
use std::fs::DirEntry;
use std::io::BufReader;
use std::io::Read;
use std::path::PathBuf;
use anyhow::Result;
use chrono::{DateTime, Utc};
use postgres_ffi::XLogSegNo;
use serde::Serialize;
use utils::http::json::display_serialize;
use utils::id::NodeId;
use utils::id::TenantTimelineId;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use crate::safekeeper::SafeKeeperState;
use crate::safekeeper::SafekeeperMemState;
use crate::safekeeper::TermHistory;
use crate::SafeKeeperConf;
use crate::timeline::ReplicaState;
use crate::GlobalTimelines;
/// Various filters that influence the resulting JSON output.
#[derive(Debug, Serialize)]
pub struct Args {
/// Dump all available safekeeper state. False by default.
pub dump_all: bool,
/// Dump control_file content. Uses value of `dump_all` by default.
pub dump_control_file: bool,
/// Dump in-memory state. Uses value of `dump_all` by default.
pub dump_memory: bool,
/// Dump all disk files in a timeline directory. Uses value of `dump_all` by default.
pub dump_disk_content: bool,
/// Dump full term history. True by default.
pub dump_term_history: bool,
/// Filter timelines by tenant_id.
pub tenant_id: Option<TenantId>,
/// Filter timelines by timeline_id.
pub timeline_id: Option<TimelineId>,
}
/// Response for debug dump request.
#[derive(Debug, Serialize)]
pub struct Response {
pub start_time: DateTime<Utc>,
pub finish_time: DateTime<Utc>,
pub timelines: Vec<Timeline>,
pub timelines_count: usize,
pub config: Config,
}
/// Safekeeper configuration.
#[derive(Debug, Serialize)]
pub struct Config {
pub id: NodeId,
pub workdir: PathBuf,
pub listen_pg_addr: String,
pub listen_http_addr: String,
pub no_sync: bool,
pub max_offloader_lag_bytes: u64,
pub wal_backup_enabled: bool,
}
#[derive(Debug, Serialize)]
pub struct Timeline {
#[serde(serialize_with = "display_serialize")]
pub tenant_id: TenantId,
#[serde(serialize_with = "display_serialize")]
pub timeline_id: TimelineId,
pub control_file: Option<SafeKeeperState>,
pub memory: Option<Memory>,
pub disk_content: Option<DiskContent>,
}
#[derive(Debug, Serialize)]
pub struct Memory {
pub is_cancelled: bool,
pub peers_info_len: usize,
pub replicas: Vec<Option<ReplicaState>>,
pub wal_backup_active: bool,
pub active: bool,
pub num_computes: u32,
pub last_removed_segno: XLogSegNo,
pub epoch_start_lsn: Lsn,
pub mem_state: SafekeeperMemState,
// PhysicalStorage state.
pub write_lsn: Lsn,
pub write_record_lsn: Lsn,
pub flush_lsn: Lsn,
pub file_open: bool,
}
#[derive(Debug, Serialize)]
pub struct DiskContent {
pub files: Vec<FileInfo>,
}
#[derive(Debug, Serialize)]
pub struct FileInfo {
pub name: String,
pub size: u64,
pub created: DateTime<Utc>,
pub modified: DateTime<Utc>,
pub start_zeroes: u64,
pub end_zeroes: u64,
// TODO: add sha256 checksum
}
/// Build debug dump response, using the provided [`Args`] filters.
pub fn build(args: Args) -> Result<Response> {
let start_time = Utc::now();
let timelines_count = GlobalTimelines::timelines_count();
let ptrs_snapshot = if args.tenant_id.is_some() && args.timeline_id.is_some() {
// If both tenant_id and timeline_id are specified, we can just get the
// timeline directly, without taking a snapshot of the whole list.
let ttid = TenantTimelineId::new(args.tenant_id.unwrap(), args.timeline_id.unwrap());
if let Ok(tli) = GlobalTimelines::get(ttid) {
vec![tli]
} else {
vec![]
}
} else {
// Otherwise, take a snapshot of the whole list.
GlobalTimelines::get_all()
};
// TODO: return Stream instead of Vec
let mut timelines = Vec::new();
for tli in ptrs_snapshot {
let ttid = tli.ttid;
if let Some(tenant_id) = args.tenant_id {
if tenant_id != ttid.tenant_id {
continue;
}
}
if let Some(timeline_id) = args.timeline_id {
if timeline_id != ttid.timeline_id {
continue;
}
}
let control_file = if args.dump_control_file {
let mut state = tli.get_state().1;
if !args.dump_term_history {
state.acceptor_state.term_history = TermHistory(vec![]);
}
Some(state)
} else {
None
};
let memory = if args.dump_memory {
Some(tli.memory_dump())
} else {
None
};
let disk_content = if args.dump_disk_content {
// build_disk_content can fail, but we don't want to fail the whole
// request because of that.
build_disk_content(&tli.timeline_dir).ok()
} else {
None
};
let timeline = Timeline {
tenant_id: ttid.tenant_id,
timeline_id: ttid.timeline_id,
control_file,
memory,
disk_content,
};
timelines.push(timeline);
}
let config = GlobalTimelines::get_global_config();
Ok(Response {
start_time,
finish_time: Utc::now(),
timelines,
timelines_count,
config: build_config(config),
})
}
/// Builds DiskContent from a directory path. It can fail if the directory
/// is deleted between the time we get the path and the time we try to open it.
fn build_disk_content(path: &std::path::Path) -> Result<DiskContent> {
let mut files = Vec::new();
for entry in fs::read_dir(path)? {
if entry.is_err() {
continue;
}
let file = build_file_info(entry?);
if file.is_err() {
continue;
}
files.push(file?);
}
Ok(DiskContent { files })
}
/// Builds FileInfo from DirEntry. Sometimes it can return an error
/// if the file is deleted between the time we get the DirEntry
/// and the time we try to open it.
fn build_file_info(entry: DirEntry) -> Result<FileInfo> {
let metadata = entry.metadata()?;
let path = entry.path();
let name = path
.file_name()
.and_then(|x| x.to_str())
.unwrap_or("")
.to_owned();
let mut file = fs::File::open(path)?;
let mut reader = BufReader::new(&mut file).bytes().filter_map(|x| x.ok());
let start_zeroes = reader.by_ref().take_while(|&x| x == 0).count() as u64;
let mut end_zeroes = 0;
for b in reader {
if b == 0 {
end_zeroes += 1;
} else {
end_zeroes = 0;
}
}
Ok(FileInfo {
name,
size: metadata.len(),
created: DateTime::from(metadata.created()?),
modified: DateTime::from(metadata.modified()?),
start_zeroes,
end_zeroes,
})
}
/// Converts SafeKeeperConf to Config, filtering out the fields that are not
/// supposed to be exposed.
fn build_config(config: SafeKeeperConf) -> Config {
Config {
id: config.my_id,
workdir: config.workdir,
listen_pg_addr: config.listen_pg_addr,
listen_http_addr: config.listen_http_addr,
no_sync: config.no_sync,
max_offloader_lag_bytes: config.max_offloader_lag_bytes,
wal_backup_enabled: config.wal_backup_enabled,
}
}

View File

@@ -1,27 +1,24 @@
//! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
//! protocol commands.
use anyhow::Context;
use std::str;
use tracing::{info, info_span, Instrument};
use crate::auth::check_permission;
use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
use crate::receive_wal::ReceiveWalConn;
use crate::send_wal::ReplicationConn;
use crate::wal_service::ConnectionId;
use crate::{GlobalTimelines, SafeKeeperConf};
use anyhow::Context;
use postgres_backend::QueryError;
use postgres_backend::{self, PostgresBackend};
use postgres_ffi::PG_TLI;
use regex::Regex;
use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
use std::str;
use tracing::info;
use regex::Regex;
use utils::auth::{Claims, Scope};
use utils::postgres_backend_async::QueryError;
use utils::{
id::{TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
postgres_backend::{self, PostgresBackend},
};
/// Safekeeper handler of postgres commands
@@ -32,6 +29,8 @@ pub struct SafekeeperPostgresHandler {
pub tenant_id: Option<TenantId>,
pub timeline_id: Option<TimelineId>,
pub ttid: TenantTimelineId,
/// Unique connection id is logged in spans for observability.
pub conn_id: ConnectionId,
claims: Option<Claims>,
}
@@ -53,7 +52,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
let start_lsn = caps
.next()
.map(|cap| cap[1].parse::<Lsn>())
.context("failed to parse start LSN from START_REPLICATION command")??;
.context("parse start LSN from START_REPLICATION command")??;
Ok(SafekeeperPostgresCommand::StartReplication { start_lsn })
} else if cmd.starts_with("IDENTIFY_SYSTEM") {
Ok(SafekeeperPostgresCommand::IdentifySystem)
@@ -67,6 +66,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
}
}
#[async_trait::async_trait]
impl postgres_backend::Handler for SafekeeperPostgresHandler {
// tenant_id and timeline_id are passed in connection string params
fn startup(
@@ -137,7 +137,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
Ok(())
}
fn process_query(
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
@@ -147,9 +147,10 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
.starts_with("set datestyle to ")
{
// important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
return Ok(());
}
let cmd = parse_cmd(query_string)?;
info!(
@@ -161,38 +162,36 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
let timeline_id = self.timeline_id.context("timelineid is required")?;
self.check_permission(Some(tenant_id))?;
self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
let span_ttid = self.ttid; // satisfy borrow checker
let res = match cmd {
SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
match cmd {
SafekeeperPostgresCommand::StartWalPush => {
self.handle_start_wal_push(pgb)
.instrument(info_span!("WAL receiver", ttid = %span_ttid))
.await
}
SafekeeperPostgresCommand::StartReplication { start_lsn } => {
ReplicationConn::new(pgb).run(self, pgb, start_lsn)
self.handle_start_replication(pgb, start_lsn)
.instrument(info_span!("WAL sender", ttid = %span_ttid))
.await
}
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb),
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd),
};
match res {
Ok(()) => Ok(()),
Err(QueryError::Disconnected(connection_error)) => {
info!("Timeline {tenant_id}/{timeline_id} query failed with connection error: {connection_error}");
Err(QueryError::Disconnected(connection_error))
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
handle_json_ctrl(self, pgb, cmd).await
}
Err(QueryError::Other(e)) => Err(QueryError::Other(e.context(format!(
"Failed to process query for timeline {}",
self.ttid
)))),
}
}
}
impl SafekeeperPostgresHandler {
pub fn new(conf: SafeKeeperConf) -> Self {
pub fn new(conf: SafeKeeperConf, conn_id: u32) -> Self {
SafekeeperPostgresHandler {
conf,
appname: None,
tenant_id: None,
timeline_id: None,
ttid: TenantTimelineId::empty(),
conn_id,
claims: None,
}
}
@@ -217,8 +216,11 @@ impl SafekeeperPostgresHandler {
///
/// Handle IDENTIFY_SYSTEM replication command
///
fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> {
let tli = GlobalTimelines::get(self.ttid)?;
async fn handle_identify_system(
&mut self,
pgb: &mut PostgresBackend,
) -> Result<(), QueryError> {
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
let lsn = if self.is_walproposer_recovery() {
// walproposer should get all local WAL until flush_lsn
@@ -267,7 +269,7 @@ impl SafekeeperPostgresHandler {
Some(lsn_bytes),
None,
]))?
.write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
.write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
Ok(())
}

View File

@@ -119,6 +119,12 @@ paths:
$ref: "#/components/responses/ForbiddenError"
default:
$ref: "#/components/responses/GenericError"
"404":
description: Timeline not found
content:
application/json:
schema:
$ref: "#/components/schemas/NotFoundError"
delete:
tags:

View File

@@ -1,18 +1,19 @@
use hyper::{Body, Request, Response, StatusCode, Uri};
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_ffi::WAL_SEGMENT_SIZE;
use safekeeper_api::models::SkTimelineInfo;
use serde::Serialize;
use serde::Serializer;
use std::collections::{HashMap, HashSet};
use std::fmt::Display;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
use storage_broker::proto::SafekeeperTimelineInfo;
use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId;
use tokio::task::JoinError;
use utils::http::json::display_serialize;
use crate::debug_dump;
use crate::safekeeper::ServerInfo;
use crate::safekeeper::Term;
@@ -54,15 +55,6 @@ fn get_conf(request: &Request<Body>) -> &SafeKeeperConf {
.as_ref()
}
/// Serialize through Display trait.
fn display_serialize<S, F>(z: &F, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
F: Display,
{
s.serialize_str(&format!("{}", z))
}
/// Same as TermSwitchEntry, but serializes LSN using display serializer
/// in Postgres format, i.e. 0/FFFFFFFF. Used only for the API response.
#[derive(Debug, Serialize)]
@@ -119,12 +111,7 @@ async fn timeline_status_handler(request: Request<Body>) -> Result<Response<Body
);
check_permission(&request, Some(ttid.tenant_id))?;
let tli = GlobalTimelines::get(ttid)
// FIXME: Currently, the only errors from `GlobalTimelines::get` will be client errors
// because the provided timeline isn't there. However, the method can in theory change and
// fail from internal errors later. Remove this comment once it the method returns
// something other than `anyhow::Result`.
.map_err(ApiError::InternalServerError)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let (inmem, state) = tli.get_state();
let flush_lsn = tli.get_flush_lsn();
@@ -181,12 +168,9 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
.commit_lsn
.segment_lsn(server_info.wal_seg_size as usize)
});
tokio::task::spawn_blocking(move || {
GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn)
})
.await
.map_err(|e| ApiError::InternalServerError(e.into()))?
.map_err(ApiError::InternalServerError)?;
GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -260,15 +244,7 @@ async fn record_safekeeper_info(mut request: Request<Body>) -> Result<Response<B
local_start_lsn: sk_info.local_start_lsn.0,
};
let tli = GlobalTimelines::get(ttid)
// `GlobalTimelines::get` returns an error when it can't find the timeline.
.with_context(|| {
format!(
"Couldn't get timeline {} for tenant {}",
ttid.timeline_id, ttid.tenant_id
)
})
.map_err(ApiError::NotFound)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
tli.record_safekeeper_info(&proto_sk_info)
.await
.map_err(ApiError::InternalServerError)?;
@@ -276,6 +252,69 @@ async fn record_safekeeper_info(mut request: Request<Body>) -> Result<Response<B
json_response(StatusCode::OK, ())
}
fn parse_kv_str<E: fmt::Display, T: FromStr<Err = E>>(k: &str, v: &str) -> Result<T, ApiError> {
v.parse()
.map_err(|e| ApiError::BadRequest(anyhow::anyhow!("cannot parse {k}: {e}")))
}
/// Dump debug info about all available safekeeper state.
async fn dump_debug_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
ensure_no_body(&mut request).await?;
let mut dump_all: Option<bool> = None;
let mut dump_control_file: Option<bool> = None;
let mut dump_memory: Option<bool> = None;
let mut dump_disk_content: Option<bool> = None;
let mut dump_term_history: Option<bool> = None;
let mut tenant_id: Option<TenantId> = None;
let mut timeline_id: Option<TimelineId> = None;
let query = request.uri().query().unwrap_or("");
let mut values = url::form_urlencoded::parse(query.as_bytes());
for (k, v) in &mut values {
match k.as_ref() {
"dump_all" => dump_all = Some(parse_kv_str(&k, &v)?),
"dump_control_file" => dump_control_file = Some(parse_kv_str(&k, &v)?),
"dump_memory" => dump_memory = Some(parse_kv_str(&k, &v)?),
"dump_disk_content" => dump_disk_content = Some(parse_kv_str(&k, &v)?),
"dump_term_history" => dump_term_history = Some(parse_kv_str(&k, &v)?),
"tenant_id" => tenant_id = Some(parse_kv_str(&k, &v)?),
"timeline_id" => timeline_id = Some(parse_kv_str(&k, &v)?),
_ => Err(ApiError::BadRequest(anyhow::anyhow!(
"Unknown query parameter: {}",
k
)))?,
}
}
let dump_all = dump_all.unwrap_or(false);
let dump_control_file = dump_control_file.unwrap_or(dump_all);
let dump_memory = dump_memory.unwrap_or(dump_all);
let dump_disk_content = dump_disk_content.unwrap_or(dump_all);
let dump_term_history = dump_term_history.unwrap_or(true);
let args = debug_dump::Args {
dump_all,
dump_control_file,
dump_memory,
dump_disk_content,
dump_term_history,
tenant_id,
timeline_id,
};
let resp = tokio::task::spawn_blocking(move || {
debug_dump::build(args).map_err(ApiError::InternalServerError)
})
.await
.map_err(|e: JoinError| ApiError::InternalServerError(e.into()))??;
// TODO: use streaming response
json_response(StatusCode::OK, resp)
}
/// Safekeeper http router.
pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError> {
let mut router = endpoint::make_router();
@@ -316,6 +355,7 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
"/v1/record_safekeeper_info/:tenant_id/:timeline_id",
record_safekeeper_info,
)
.get("/v1/debug_dump", dump_debug_handler)
}
#[cfg(test)]

View File

@@ -10,10 +10,10 @@ use std::sync::Arc;
use anyhow::Context;
use bytes::Bytes;
use postgres_backend::QueryError;
use serde::{Deserialize, Serialize};
use tracing::*;
use utils::id::TenantTimelineId;
use utils::postgres_backend_async::QueryError;
use crate::handler::SafekeeperPostgresHandler;
use crate::safekeeper::{AcceptorProposerMessage, AppendResponse, ServerInfo};
@@ -23,29 +23,30 @@ use crate::safekeeper::{
use crate::safekeeper::{SafeKeeperState, Term, TermHistory, TermSwitchEntry};
use crate::timeline::Timeline;
use crate::GlobalTimelines;
use postgres_backend::PostgresBackend;
use postgres_ffi::encode_logical_message;
use postgres_ffi::WAL_SEGMENT_SIZE;
use pq_proto::{BeMessage, RowDescriptor, TEXT_OID};
use utils::{lsn::Lsn, postgres_backend::PostgresBackend};
use utils::lsn::Lsn;
#[derive(Serialize, Deserialize, Debug)]
pub struct AppendLogicalMessage {
// prefix and message to build LogicalMessage
lm_prefix: String,
lm_message: String,
pub lm_prefix: String,
pub lm_message: String,
// if true, commit_lsn will match flush_lsn after append
set_commit_lsn: bool,
pub set_commit_lsn: bool,
// if true, ProposerElected will be sent before append
send_proposer_elected: bool,
pub send_proposer_elected: bool,
// fields from AppendRequestHeader
term: Term,
epoch_start_lsn: Lsn,
begin_lsn: Lsn,
truncate_lsn: Lsn,
pg_version: u32,
pub term: Term,
pub epoch_start_lsn: Lsn,
pub begin_lsn: Lsn,
pub truncate_lsn: Lsn,
pub pg_version: u32,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -59,7 +60,7 @@ struct AppendResult {
/// Handles command to craft logical message WAL record with given
/// content, and then append it with specified term and lsn. This
/// function is used to test safekeepers in different scenarios.
pub fn handle_json_ctrl(
pub async fn handle_json_ctrl(
spg: &SafekeeperPostgresHandler,
pgb: &mut PostgresBackend,
append_request: &AppendLogicalMessage,
@@ -67,7 +68,7 @@ pub fn handle_json_ctrl(
info!("JSON_CTRL request: {append_request:?}");
// need to init safekeeper state before AppendRequest
let tli = prepare_safekeeper(spg.ttid, append_request.pg_version)?;
let tli = prepare_safekeeper(spg.ttid, append_request.pg_version).await?;
// if send_proposer_elected is true, we need to update local history
if append_request.send_proposer_elected {
@@ -89,13 +90,16 @@ pub fn handle_json_ctrl(
..Default::default()
}]))?
.write_message_noflush(&BeMessage::DataRow(&[Some(&response_data)]))?
.write_message(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
.write_message_noflush(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
Ok(())
}
/// Prepare safekeeper to process append requests without crashes,
/// by sending ProposerGreeting with default server.wal_seg_size.
fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result<Arc<Timeline>> {
async fn prepare_safekeeper(
ttid: TenantTimelineId,
pg_version: u32,
) -> anyhow::Result<Arc<Timeline>> {
GlobalTimelines::create(
ttid,
ServerInfo {
@@ -106,6 +110,7 @@ fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result
Lsn::INVALID,
Lsn::INVALID,
)
.await
}
fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::Result<()> {
@@ -128,15 +133,15 @@ fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::R
}
#[derive(Debug, Serialize, Deserialize)]
struct InsertedWAL {
pub struct InsertedWAL {
begin_lsn: Lsn,
end_lsn: Lsn,
pub end_lsn: Lsn,
append_response: AppendResponse,
}
/// Extend local WAL with new LogicalMessage record. To do that,
/// create AppendRequest with new WAL and pass it to safekeeper.
fn append_logical_message(
pub fn append_logical_message(
tli: &Arc<Timeline>,
msg: &AppendLogicalMessage,
) -> anyhow::Result<InsertedWAL> {

View File

@@ -1,8 +1,8 @@
use storage_broker::Uri;
//
use remote_storage::RemoteStorageConfig;
use std::path::PathBuf;
use std::time::Duration;
use storage_broker::Uri;
use utils::id::{NodeId, TenantId, TenantTimelineId};
@@ -10,6 +10,7 @@ mod auth;
pub mod broker;
pub mod control_file;
pub mod control_file_upgrade;
pub mod debug_dump;
pub mod handler;
pub mod http;
pub mod json_ctrl;

View File

@@ -2,72 +2,134 @@
//! Gets messages from the network, passes them down to consensus module and
//! sends replies back.
use anyhow::anyhow;
use anyhow::Context;
use bytes::BytesMut;
use tracing::*;
use utils::lsn::Lsn;
use utils::postgres_backend_async::QueryError;
use crate::safekeeper::ServerInfo;
use crate::timeline::Timeline;
use crate::GlobalTimelines;
use std::net::SocketAddr;
use std::sync::mpsc::channel;
use std::sync::mpsc::Receiver;
use std::sync::Arc;
use std::thread;
use crate::handler::SafekeeperPostgresHandler;
use crate::safekeeper::AcceptorProposerMessage;
use crate::safekeeper::ProposerAcceptorMessage;
use crate::safekeeper::ServerInfo;
use crate::timeline::Timeline;
use crate::wal_service::ConnectionId;
use crate::GlobalTimelines;
use anyhow::{anyhow, Context};
use bytes::BytesMut;
use postgres_backend::CopyStreamHandlerEnd;
use postgres_backend::PostgresBackend;
use postgres_backend::PostgresBackendReader;
use postgres_backend::QueryError;
use pq_proto::BeMessage;
use std::net::SocketAddr;
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tokio::task::spawn_blocking;
use tracing::*;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use crate::handler::SafekeeperPostgresHandler;
use pq_proto::{BeMessage, FeMessage};
use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream};
const MSG_QUEUE_SIZE: usize = 256;
const REPLY_QUEUE_SIZE: usize = 16;
pub struct ReceiveWalConn<'pg> {
/// Postgres connection
pg_backend: &'pg mut PostgresBackend,
/// The cached result of `pg_backend.socket().peer_addr()` (roughly)
peer_addr: SocketAddr,
}
impl<'pg> ReceiveWalConn<'pg> {
pub fn new(pg: &'pg mut PostgresBackend) -> ReceiveWalConn<'pg> {
let peer_addr = *pg.get_peer_addr();
ReceiveWalConn {
pg_backend: pg,
peer_addr,
impl SafekeeperPostgresHandler {
/// Wrapper around handle_start_wal_push_guts handling result. Error is
/// handled here while we're still in walreceiver ttid span; with API
/// extension, this can probably be moved into postgres_backend.
pub async fn handle_start_wal_push(
&mut self,
pgb: &mut PostgresBackend,
) -> Result<(), QueryError> {
if let Err(end) = self.handle_start_wal_push_guts(pgb).await {
// Log the result and probably send it to the client, closing the stream.
pgb.handle_copy_stream_end(end).await;
}
}
// Send message to the postgres
fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> anyhow::Result<()> {
let mut buf = BytesMut::with_capacity(128);
msg.serialize(&mut buf)?;
self.pg_backend.write_message(&BeMessage::CopyData(&buf))?;
Ok(())
}
/// Receive WAL from wal_proposer
pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<(), QueryError> {
let _enter = info_span!("WAL acceptor", ttid = %spg.ttid).entered();
pub async fn handle_start_wal_push_guts(
&mut self,
pgb: &mut PostgresBackend,
) -> Result<(), CopyStreamHandlerEnd> {
// Notify the libpq client that it's allowed to send `CopyData` messages
self.pg_backend
.write_message(&BeMessage::CopyBothResponse)?;
pgb.write_message(&BeMessage::CopyBothResponse).await?;
let r = self
.pg_backend
.take_stream_in()
.ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
let mut poll_reader = ProposerPollStream::new(r)?;
// Experiments [1] confirm that doing network IO in one (this) thread and
// processing with disc IO in another significantly improves
// performance; we spawn off WalAcceptor thread for message processing
// to this end.
//
// [1] https://github.com/neondatabase/neon/pull/1318
let (msg_tx, msg_rx) = channel(MSG_QUEUE_SIZE);
let (reply_tx, reply_rx) = channel(REPLY_QUEUE_SIZE);
let mut acceptor_handle: Option<JoinHandle<anyhow::Result<()>>> = None;
// Receive information about server
let next_msg = poll_reader.recv_msg()?;
// Concurrently receive and send data; replies are not synchronized with
// sends, so this avoids deadlocks.
let mut pgb_reader = pgb.split().context("START_WAL_PUSH split")?;
let peer_addr = *pgb.get_peer_addr();
let network_reader = NetworkReader {
ttid: self.ttid,
conn_id: self.conn_id,
pgb_reader: &mut pgb_reader,
peer_addr,
acceptor_handle: &mut acceptor_handle,
};
let res = tokio::select! {
// todo: add read|write .context to these errors
r = network_reader.run(msg_tx, msg_rx, reply_tx) => r,
r = network_write(pgb, reply_rx) => r,
};
// Join pg backend back.
pgb.unsplit(pgb_reader)?;
// Join the spawned WalAcceptor. At this point chans to/from it passed
// to network routines are dropped, so it will exit as soon as it
// touches them.
match acceptor_handle {
None => {
// failed even before spawning; read_network should have error
Err(res.expect_err("no error with WalAcceptor not spawn"))
}
Some(handle) => {
let wal_acceptor_res = handle.join();
// If there was any network error, return it.
res?;
// Otherwise, WalAcceptor thread must have errored.
match wal_acceptor_res {
Ok(Ok(_)) => Ok(()), // can't happen currently; would be if we add graceful termination
Ok(Err(e)) => Err(CopyStreamHandlerEnd::Other(e.context("WAL acceptor"))),
Err(_) => Err(CopyStreamHandlerEnd::Other(anyhow!(
"WalAcceptor thread panicked",
))),
}
}
}
}
}
struct NetworkReader<'a> {
ttid: TenantTimelineId,
conn_id: ConnectionId,
pgb_reader: &'a mut PostgresBackendReader,
peer_addr: SocketAddr,
// WalAcceptor is spawned when we learn server info from walproposer and
// create timeline; handle is put here.
acceptor_handle: &'a mut Option<JoinHandle<anyhow::Result<()>>>,
}
impl<'a> NetworkReader<'a> {
async fn run(
self,
msg_tx: Sender<ProposerAcceptorMessage>,
msg_rx: Receiver<ProposerAcceptorMessage>,
reply_tx: Sender<AcceptorProposerMessage>,
) -> Result<(), CopyStreamHandlerEnd> {
// Receive information about server to create timeline, if not yet.
let next_msg = read_message(self.pgb_reader).await?;
let tli = match next_msg {
ProposerAcceptorMessage::Greeting(ref greeting) => {
info!(
@@ -79,127 +141,158 @@ impl<'pg> ReceiveWalConn<'pg> {
system_id: greeting.system_id,
wal_seg_size: greeting.wal_seg_size,
};
GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
GlobalTimelines::create(self.ttid, server_info, Lsn::INVALID, Lsn::INVALID).await?
}
_ => {
return Err(QueryError::Other(anyhow::anyhow!(
return Err(CopyStreamHandlerEnd::Other(anyhow::anyhow!(
"unexpected message {next_msg:?} instead of greeting"
)))
}
};
let mut next_msg = Some(next_msg);
*self.acceptor_handle = Some(
WalAcceptor::spawn(tli.clone(), msg_rx, reply_tx, self.conn_id)
.context("spawn WalAcceptor thread")?,
);
let mut first_time_through = true;
let mut _guard: Option<ComputeConnectionGuard> = None;
loop {
if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
// poll AppendRequest's without blocking and write WAL to disk without flushing,
// while it's readily available
while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
// Forward all messages to WalAcceptor
read_network_loop(self.pgb_reader, msg_tx, next_msg).await
}
}
let reply = tli.process_msg(&msg)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
/// Read next message from walproposer.
/// TODO: Return Ok(None) on graceful termination.
async fn read_message(
pgb_reader: &mut PostgresBackendReader,
) -> Result<ProposerAcceptorMessage, CopyStreamHandlerEnd> {
let copy_data = pgb_reader.read_copy_message().await?;
let msg = ProposerAcceptorMessage::parse(copy_data)?;
Ok(msg)
}
next_msg = poll_reader.poll_msg();
}
async fn read_network_loop(
pgb_reader: &mut PostgresBackendReader,
msg_tx: Sender<ProposerAcceptorMessage>,
mut next_msg: ProposerAcceptorMessage,
) -> Result<(), CopyStreamHandlerEnd> {
loop {
if msg_tx.send(next_msg).await.is_err() {
return Ok(()); // chan closed, WalAcceptor terminated
}
next_msg = read_message(pgb_reader).await?;
}
}
// flush all written WAL to the disk
let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
} else if let Some(msg) = next_msg.take() {
// process other message
let reply = tli.process_msg(&msg)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
}
if first_time_through {
// Register the connection and defer unregister. Do that only
// after processing first message, as it sets wal_seg_size,
// wanted by many.
tli.on_compute_connect()?;
_guard = Some(ComputeConnectionGuard {
timeline: Arc::clone(&tli),
});
first_time_through = false;
}
// blocking wait for the next message
if next_msg.is_none() {
next_msg = Some(poll_reader.recv_msg()?);
/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(())
/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should
/// tell the error.
async fn network_write(
pgb_writer: &mut PostgresBackend,
mut reply_rx: Receiver<AcceptorProposerMessage>,
) -> Result<(), CopyStreamHandlerEnd> {
let mut buf = BytesMut::with_capacity(128);
loop {
match reply_rx.recv().await {
Some(msg) => {
buf.clear();
msg.serialize(&mut buf)?;
pgb_writer.write_message(&BeMessage::CopyData(&buf)).await?;
}
None => return Ok(()), // chan closed, WalAcceptor terminated
}
}
}
struct ProposerPollStream {
/// Takes messages from msg_rx, processes and pushes replies to reply_tx.
struct WalAcceptor {
tli: Arc<Timeline>,
msg_rx: Receiver<ProposerAcceptorMessage>,
read_thread: Option<thread::JoinHandle<Result<(), QueryError>>>,
reply_tx: Sender<AcceptorProposerMessage>,
}
impl ProposerPollStream {
fn new(mut r: ReadStream) -> anyhow::Result<Self> {
let (msg_tx, msg_rx) = channel();
impl WalAcceptor {
/// Spawn thread with WalAcceptor running, return handle to it.
fn spawn(
tli: Arc<Timeline>,
msg_rx: Receiver<ProposerAcceptorMessage>,
reply_tx: Sender<AcceptorProposerMessage>,
conn_id: ConnectionId,
) -> anyhow::Result<JoinHandle<anyhow::Result<()>>> {
let thread_name = format!("WAL acceptor {}", tli.ttid);
thread::Builder::new()
.name(thread_name)
.spawn(move || -> anyhow::Result<()> {
let mut wa = WalAcceptor {
tli,
msg_rx,
reply_tx,
};
let read_thread = thread::Builder::new()
.name("Read WAL thread".into())
.spawn(move || -> Result<(), QueryError> {
loop {
let copy_data = match FeMessage::read(&mut r)? {
Some(FeMessage::CopyData(bytes)) => Ok(bytes),
Some(msg) => Err(QueryError::Other(anyhow::anyhow!(
"expected `CopyData` message, found {msg:?}"
))),
None => Err(QueryError::from(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"walproposer closed the connection",
))),
}?;
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let msg = ProposerAcceptorMessage::parse(copy_data)?;
msg_tx
.send(msg)
.context("Failed to send the proposer message")?;
}
// msg_tx will be dropped here, this will also close msg_rx
})?;
Ok(Self {
msg_rx,
read_thread: Some(read_thread),
})
let span_ttid = wa.tli.ttid; // satisfy borrow checker
runtime.block_on(
wa.run()
.instrument(info_span!("WAL acceptor", cid = %conn_id, ttid = %span_ttid)),
)
})
.map_err(anyhow::Error::from)
}
fn recv_msg(&mut self) -> Result<ProposerAcceptorMessage, QueryError> {
self.msg_rx.recv().map_err(|_| {
// return error from the read thread
let res = match self.read_thread.take() {
Some(thread) => thread.join(),
None => return QueryError::Other(anyhow::anyhow!("read thread is gone")),
};
/// The main loop. Returns Ok(()) if either msg_rx or reply_tx got closed;
/// it must mean that network thread terminated.
async fn run(&mut self) -> anyhow::Result<()> {
// Register the connection and defer unregister.
self.tli.on_compute_connect().await?;
let _guard = ComputeConnectionGuard {
timeline: Arc::clone(&self.tli),
};
match res {
Ok(Ok(())) => {
QueryError::Other(anyhow::anyhow!("unexpected result from read thread"))
}
Err(err) => QueryError::Other(anyhow::anyhow!("read thread panicked: {err:?}")),
Ok(Err(err)) => err,
let mut next_msg: ProposerAcceptorMessage;
loop {
let opt_msg = self.msg_rx.recv().await;
if opt_msg.is_none() {
return Ok(()); // chan closed, streaming terminated
}
})
}
next_msg = opt_msg.unwrap();
fn poll_msg(&mut self) -> Option<ProposerAcceptorMessage> {
let res = self.msg_rx.try_recv();
if matches!(next_msg, ProposerAcceptorMessage::AppendRequest(_)) {
// loop through AppendRequest's while it's readily available to
// write as many WAL as possible without fsyncing
while let ProposerAcceptorMessage::AppendRequest(append_request) = next_msg {
let noflush_msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
match res {
Err(_) => None,
Ok(msg) => Some(msg),
if let Some(reply) = self.tli.process_msg(&noflush_msg)? {
if self.reply_tx.send(reply).await.is_err() {
return Ok(()); // chan closed, streaming terminated
}
}
match self.msg_rx.try_recv() {
Ok(msg) => next_msg = msg,
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => return Ok(()), // chan closed, streaming terminated
}
}
// flush all written WAL to the disk
if let Some(reply) = self.tli.process_msg(&ProposerAcceptorMessage::FlushWAL)? {
if self.reply_tx.send(reply).await.is_err() {
return Ok(()); // chan closed, streaming terminated
}
}
} else {
// process message other than AppendRequest
if let Some(reply) = self.tli.process_msg(&next_msg)? {
if self.reply_tx.send(reply).await.is_err() {
return Ok(()); // chan closed, streaming terminated
}
}
}
}
}
}
@@ -210,8 +303,13 @@ struct ComputeConnectionGuard {
impl Drop for ComputeConnectionGuard {
fn drop(&mut self) {
if let Err(e) = self.timeline.on_compute_disconnect() {
error!("failed to unregister compute connection: {}", e);
}
let tli = self.timeline.clone();
// tokio forbids to call blocking_send inside the runtime, and see
// comments in on_compute_disconnect why we call blocking_send.
spawn_blocking(move || {
if let Err(e) = tli.on_compute_disconnect() {
error!("failed to unregister compute connection: {}", e);
}
});
}
}

View File

@@ -191,7 +191,8 @@ pub struct SafeKeeperState {
/// Minimal LSN which may be needed for recovery of some safekeeper (end_lsn
/// of last record streamed to everyone). Persisting it helps skipping
/// recovery in walproposer, generally we compute it from peers. In
/// walproposer proto called 'truncate_lsn'.
/// walproposer proto called 'truncate_lsn'. Updates are currently drived
/// only by walproposer.
pub peer_horizon_lsn: Lsn,
/// LSN of the oldest known checkpoint made by pageserver and successfully
/// pushed to s3. We don't remove WAL beyond it. Persisted only for
@@ -204,7 +205,7 @@ pub struct SafeKeeperState {
pub peers: PersistedPeers,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize)]
// In memory safekeeper state. Fields mirror ones in `SafeKeeperState`; values
// are not flushed yet.
pub struct SafekeeperMemState {
@@ -212,6 +213,7 @@ pub struct SafekeeperMemState {
pub backup_lsn: Lsn,
pub peer_horizon_lsn: Lsn,
pub remote_consistent_lsn: Lsn,
#[serde(with = "hex")]
pub proposer_uuid: PgUuid,
}
@@ -486,7 +488,7 @@ impl AcceptorProposerMessage {
buf.put_u64_le(msg.hs_feedback.xmin);
buf.put_u64_le(msg.hs_feedback.catalog_xmin);
msg.pageserver_feedback.serialize(buf)?
msg.pageserver_feedback.serialize(buf);
}
}
@@ -681,7 +683,7 @@ where
term: self.state.acceptor_state.term,
vote_given: false as u64,
flush_lsn: self.flush_lsn(),
truncate_lsn: self.state.peer_horizon_lsn,
truncate_lsn: self.inmem.peer_horizon_lsn,
term_history: self.get_term_history(),
timeline_start_lsn: self.state.timeline_start_lsn,
};
@@ -877,7 +879,13 @@ where
if msg.h.commit_lsn != Lsn(0) {
self.update_commit_lsn(msg.h.commit_lsn)?;
}
self.inmem.peer_horizon_lsn = msg.h.truncate_lsn;
// Value calculated by walproposer can always lag:
// - safekeepers can forget inmem value and send to proposer lower
// persisted one on restart;
// - if we make safekeepers always send persistent value,
// any compute restart would pull it down.
// Thus, take max before adopting.
self.inmem.peer_horizon_lsn = max(self.inmem.peer_horizon_lsn, msg.h.truncate_lsn);
// Update truncate and commit LSN in control file.
// To avoid negative impact on performance of extra fsync, do it only

View File

@@ -5,24 +5,22 @@ use crate::handler::SafekeeperPostgresHandler;
use crate::timeline::{ReplicaState, Timeline};
use crate::wal_storage::WalReader;
use crate::GlobalTimelines;
use anyhow::Context;
use anyhow::Context as AnyhowContext;
use bytes::Bytes;
use postgres_backend::PostgresBackend;
use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError};
use postgres_ffi::get_current_timestamp;
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::net::Shutdown;
use std::str;
use std::sync::Arc;
use std::time::Duration;
use std::{io, str, thread};
use utils::postgres_backend_async::QueryError;
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
use tokio::sync::watch::Receiver;
use tokio::time::timeout;
use tracing::*;
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream};
use utils::{bin_ser::BeSer, lsn::Lsn};
// See: https://www.postgresql.org/docs/13/protocol-replication.html
const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h';
@@ -60,13 +58,6 @@ pub struct StandbyReply {
pub reply_requested: bool,
}
/// A network connection that's speaking the replication protocol.
pub struct ReplicationConn {
/// This is an `Option` because we will spawn a background thread that will
/// `take` it from us.
stream_in: Option<ReadStream>,
}
/// Scope guard to unregister replication connection from timeline
struct ReplicationConnGuard {
replica: usize, // replica internal ID assigned by timeline
@@ -79,230 +70,275 @@ impl Drop for ReplicationConnGuard {
}
}
impl ReplicationConn {
/// Create a new `ReplicationConn`
pub fn new(pgb: &mut PostgresBackend) -> Self {
Self {
stream_in: pgb.take_stream_in(),
impl SafekeeperPostgresHandler {
/// Wrapper around handle_start_replication_guts handling result. Error is
/// handled here while we're still in walsender ttid span; with API
/// extension, this can probably be moved into postgres_backend.
pub async fn handle_start_replication(
&mut self,
pgb: &mut PostgresBackend,
start_pos: Lsn,
) -> Result<(), QueryError> {
if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await {
// Log the result and probably send it to the client, closing the stream.
pgb.handle_copy_stream_end(end).await;
}
}
/// Handle incoming messages from the network.
/// This is spawned into the background by `handle_start_replication`.
fn background_thread(
mut stream_in: ReadStream,
replica_guard: Arc<ReplicationConnGuard>,
) -> anyhow::Result<()> {
let replica_id = replica_guard.replica;
let timeline = &replica_guard.timeline;
let mut state = ReplicaState::new();
// Wait for replica's feedback.
while let Some(msg) = FeMessage::read(&mut stream_in)? {
match &msg {
FeMessage::CopyData(m) => {
// There's three possible data messages that the client is supposed to send here:
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
match m.first().cloned() {
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
// Note: deserializing is on m[1..] because we skip the tag byte.
state.hs_feedback = HotStandbyFeedback::des(&m[1..])
.context("failed to deserialize HotStandbyFeedback")?;
timeline.update_replica_state(replica_id, state);
}
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
let _reply = StandbyReply::des(&m[1..])
.context("failed to deserialize StandbyReply")?;
// This must be a regular postgres replica,
// because pageserver doesn't send this type of messages to safekeeper.
// Currently this is not implemented, so this message is ignored.
warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet.");
// timeline.update_replica_state(replica_id, Some(state));
}
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
let buf = Bytes::copy_from_slice(&m[9..]);
let reply = ReplicationFeedback::parse(buf);
trace!("ReplicationFeedback is {:?}", reply);
// Only pageserver sends ReplicationFeedback, so set the flag.
// This replica is the source of information to resend to compute.
state.pageserver_feedback = Some(reply);
timeline.update_replica_state(replica_id, state);
}
_ => warn!("unexpected message {:?}", msg),
}
}
FeMessage::Sync => {}
FeMessage::CopyFail => {
// Shutdown the connection, because rust-postgres client cannot be dropped
// when connection is alive.
let _ = stream_in.shutdown(Shutdown::Both);
anyhow::bail!("Copy failed");
}
_ => {
// We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored.
info!("unexpected message {:?}", msg);
}
}
}
Ok(())
}
///
/// Handle START_REPLICATION replication command
///
pub fn run(
pub async fn handle_start_replication_guts(
&mut self,
spg: &mut SafekeeperPostgresHandler,
pgb: &mut PostgresBackend,
mut start_pos: Lsn,
) -> Result<(), QueryError> {
let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered();
let tli = GlobalTimelines::get(spg.ttid)?;
// spawn the background thread which receives HotStandbyFeedback messages.
let bg_timeline = Arc::clone(&tli);
let bg_stream_in = self.stream_in.take().unwrap();
let bg_timeline_id = spg.timeline_id.unwrap();
start_pos: Lsn,
) -> Result<(), CopyStreamHandlerEnd> {
let appname = self.appname.clone();
let tli =
GlobalTimelines::get(self.ttid).map_err(|e| CopyStreamHandlerEnd::Other(e.into()))?;
let state = ReplicaState::new();
// This replica_id is used below to check if it's time to stop replication.
let replica_id = bg_timeline.add_replica(state);
let replica_id = tli.add_replica(state);
// Use a guard object to remove our entry from the timeline, when the background
// thread and us have both finished using it.
let replica_guard = Arc::new(ReplicationConnGuard {
let _guard = Arc::new(ReplicationConnGuard {
replica: replica_id,
timeline: bg_timeline,
timeline: tli.clone(),
});
let bg_replica_guard = Arc::clone(&replica_guard);
// TODO: here we got two threads, one for writing WAL and one for receiving
// feedback. If one of them fails, we should shutdown the other one too.
let _ = thread::Builder::new()
.name("HotStandbyFeedback thread".into())
.spawn(move || {
let _enter =
info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered();
if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) {
error!("Replication background thread failed: {}", err);
// Walproposer gets special handling: safekeeper must give proposer all
// local WAL till the end, whether committed or not (walproposer will
// hang otherwise). That's because walproposer runs the consensus and
// synchronizes safekeepers on the most advanced one.
//
// There is a small risk of this WAL getting concurrently garbaged if
// another compute rises which collects majority and starts fixing log
// on this safekeeper itself. That's ok as (old) proposer will never be
// able to commit such WAL.
let stop_pos: Option<Lsn> = if self.is_walproposer_recovery() {
let wal_end = tli.get_flush_lsn();
Some(wal_end)
} else {
None
};
let end_pos = stop_pos.unwrap_or(Lsn::INVALID);
info!(
"starting streaming from {:?} till {:?}",
start_pos, stop_pos
);
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse).await?;
let (_, persisted_state) = tli.get_state();
let wal_reader = WalReader::new(
self.conf.workdir.clone(),
self.conf.timeline_dir(&tli.ttid),
&persisted_state,
start_pos,
self.conf.wal_backup_enabled,
)?;
// Split to concurrently receive and send data; replies are generally
// not synchronized with sends, so this avoids deadlocks.
let reader = pgb.split().context("START_REPLICATION split")?;
let mut sender = WalSender {
pgb,
tli: tli.clone(),
appname,
start_pos,
end_pos,
stop_pos,
commit_lsn_watch_rx: tli.get_commit_lsn_watch_rx(),
replica_id,
wal_reader,
send_buf: [0; MAX_SEND_SIZE],
};
let mut reply_reader = ReplyReader {
reader,
tli,
replica_id,
feedback: ReplicaState::new(),
};
let res = tokio::select! {
// todo: add read|write .context to these errors
r = sender.run() => r,
r = reply_reader.run() => r,
};
// Join pg backend back.
pgb.unsplit(reply_reader.reader)?;
res
}
}
/// A half driving sending WAL.
struct WalSender<'a> {
pgb: &'a mut PostgresBackend,
tli: Arc<Timeline>,
appname: Option<String>,
// Position since which we are sending next chunk.
start_pos: Lsn,
// WAL up to this position is known to be locally available.
end_pos: Lsn,
// If present, terminate after reaching this position; used by walproposer
// in recovery.
stop_pos: Option<Lsn>,
commit_lsn_watch_rx: Receiver<Lsn>,
replica_id: usize,
wal_reader: WalReader,
// buffer for readling WAL into to send it
send_buf: [u8; MAX_SEND_SIZE],
}
impl WalSender<'_> {
/// Send WAL until
/// - an error occurs
/// - if we are streaming to walproposer, we've streamed until stop_pos
/// (recovery finished)
/// - receiver is caughtup and there is no computes
///
/// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ?
/// convenience.
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
loop {
// If we are streaming to walproposer, check it is time to stop.
if let Some(stop_pos) = self.stop_pos {
if self.start_pos >= stop_pos {
// recovery finished
return Err(CopyStreamHandlerEnd::ServerInitiated(format!(
"ending streaming to walproposer at {}, recovery finished",
self.start_pos
)));
}
})?;
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
runtime.block_on(async move {
let (inmem_state, persisted_state) = tli.get_state();
// add persisted_state.timeline_start_lsn == Lsn(0) check
// Walproposer gets special handling: safekeeper must give proposer all
// local WAL till the end, whether committed or not (walproposer will
// hang otherwise). That's because walproposer runs the consensus and
// synchronizes safekeepers on the most advanced one.
//
// There is a small risk of this WAL getting concurrently garbaged if
// another compute rises which collects majority and starts fixing log
// on this safekeeper itself. That's ok as (old) proposer will never be
// able to commit such WAL.
let stop_pos: Option<Lsn> = if spg.is_walproposer_recovery() {
let wal_end = tli.get_flush_lsn();
Some(wal_end)
} else {
None
};
// Wait for the next portion if it is not there yet, or just
// update our end of WAL available for sending value, we
// communicate it to the receiver.
self.wait_wal().await?;
}
info!("Start replication from {:?} till {:?}", start_pos, stop_pos);
// try to send as much as available, capped by MAX_SEND_SIZE
let mut send_size = self
.end_pos
.checked_sub(self.start_pos)
.context("reading wal without waiting for it first")?
.0 as usize;
send_size = min(send_size, self.send_buf.len());
let send_buf = &mut self.send_buf[..send_size];
// read wal into buffer
send_size = self.wal_reader.read(send_buf).await?;
let send_buf = &send_buf[..send_size];
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse)?;
let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn);
let mut wal_reader = WalReader::new(
spg.conf.workdir.clone(),
spg.conf.timeline_dir(&tli.ttid),
&persisted_state,
start_pos,
spg.conf.wal_backup_enabled,
)?;
// buffer for wal sending, limited by MAX_SEND_SIZE
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
// watcher for commit_lsn updates
let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx();
loop {
if let Some(stop_pos) = stop_pos {
if start_pos >= stop_pos {
break; /* recovery finished */
}
end_pos = stop_pos;
} else {
/* Wait until we have some data to stream */
let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?;
if let Some(lsn) = lsn {
end_pos = lsn;
} else {
// TODO: also check once in a while whether we are walsender
// to right pageserver.
if tli.should_walsender_stop(replica_id) {
// Shut down, timeline is suspended.
return Err(QueryError::from(io::Error::new(
io::ErrorKind::ConnectionAborted,
format!("end streaming to {:?}", spg.appname),
)));
}
// timeout expired: request pageserver status
pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
sent_ptr: end_pos.0,
timestamp: get_current_timestamp(),
request_reply: true,
}))?;
continue;
}
}
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
let send_size = min(send_size, send_buf.len());
let send_buf = &mut send_buf[..send_size];
// read wal into buffer
let send_size = wal_reader.read(send_buf).await?;
let send_buf = &send_buf[..send_size];
// Write some data to the network socket.
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: start_pos.0,
wal_end: end_pos.0,
// and send it
self.pgb
.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: self.start_pos.0,
wal_end: self.end_pos.0,
timestamp: get_current_timestamp(),
data: send_buf,
}))
.context("Failed to send XLogData")?;
.await?;
start_pos += send_size as u64;
trace!("sent WAL up to {}", start_pos);
trace!(
"sent {} bytes of WAL {}-{}",
send_size,
self.start_pos,
self.start_pos + send_size as u64
);
self.start_pos += send_size as u64;
}
}
/// wait until we have WAL to stream, sending keepalives and checking for
/// exit in the meanwhile
async fn wait_wal(&mut self) -> Result<(), CopyStreamHandlerEnd> {
loop {
if let Some(lsn) = wait_for_lsn(&mut self.commit_lsn_watch_rx, self.start_pos).await? {
self.end_pos = lsn;
return Ok(());
}
// Timed out waiting for WAL, check for termination and send KA
if self.tli.should_walsender_stop(self.replica_id) {
// Terminate if there is nothing more to send.
// TODO close the stream properly
return Err(CopyStreamHandlerEnd::ServerInitiated(format!(
"ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
self.appname, self.start_pos,
)));
}
self.pgb
.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
sent_ptr: self.end_pos.0,
timestamp: get_current_timestamp(),
request_reply: true,
}))
.await?;
}
}
}
Ok(())
})
/// A half driving receiving replies.
struct ReplyReader {
reader: PostgresBackendReader,
tli: Arc<Timeline>,
replica_id: usize,
feedback: ReplicaState,
}
impl ReplyReader {
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
loop {
let msg = self.reader.read_copy_message().await?;
self.handle_feedback(&msg)?
}
}
fn handle_feedback(&mut self, msg: &Bytes) -> anyhow::Result<()> {
match msg.first().cloned() {
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
// Note: deserializing is on m[1..] because we skip the tag byte.
self.feedback.hs_feedback = HotStandbyFeedback::des(&msg[1..])
.context("failed to deserialize HotStandbyFeedback")?;
self.tli
.update_replica_state(self.replica_id, self.feedback);
}
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
let _reply =
StandbyReply::des(&msg[1..]).context("failed to deserialize StandbyReply")?;
// This must be a regular postgres replica,
// because pageserver doesn't send this type of messages to safekeeper.
// Currently we just ignore this, tracking progress for them is not supported.
}
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
// pageserver sends this.
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
let buf = Bytes::copy_from_slice(&msg[9..]);
let reply = ReplicationFeedback::parse(buf);
trace!("ReplicationFeedback is {:?}", reply);
// Only pageserver sends ReplicationFeedback, so set the flag.
// This replica is the source of information to resend to compute.
self.feedback.pageserver_feedback = Some(reply);
self.tli
.update_replica_state(self.replica_id, self.feedback);
}
_ => warn!("unexpected message {:?}", msg),
}
Ok(())
}
}
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn.
/// Wait until we have commit_lsn > lsn or timeout expires. Returns
/// - Ok(Some(commit_lsn)) if needed lsn is successfully observed;
/// - Ok(None) if timeout expired;
/// - Err in case of error (if watch channel is in trouble, shouldn't happen).
async fn wait_for_lsn(rx: &mut Receiver<Lsn>, lsn: Lsn) -> anyhow::Result<Option<Lsn>> {
let commit_lsn: Lsn = *rx.borrow();
if commit_lsn > lsn {

View File

@@ -1,10 +1,11 @@
//! This module implements Timeline lifecycle management and has all neccessary code
//! This module implements Timeline lifecycle management and has all necessary code
//! to glue together SafeKeeper and all other background services.
use anyhow::{bail, Result};
use anyhow::{anyhow, bail, Result};
use parking_lot::{Mutex, MutexGuard};
use postgres_ffi::XLogSegNo;
use pq_proto::ReplicationFeedback;
use serde::Serialize;
use std::cmp::{max, min};
use std::path::PathBuf;
use tokio::{
@@ -12,6 +13,7 @@ use tokio::{
time::Instant,
};
use tracing::*;
use utils::http::error::ApiError;
use utils::{
id::{NodeId, TenantTimelineId},
lsn::Lsn,
@@ -28,9 +30,9 @@ use crate::send_wal::HotStandbyFeedback;
use crate::{control_file, safekeeper::UNKNOWN_SERVER_VERSION};
use crate::metrics::FullTimelineInfo;
use crate::wal_storage;
use crate::wal_storage::Storage as wal_storage_iface;
use crate::SafeKeeperConf;
use crate::{debug_dump, wal_storage};
/// Things safekeeper should know about timeline state on peers.
#[derive(Debug, Clone)]
@@ -80,7 +82,7 @@ impl PeersInfo {
}
/// Replica status update + hot standby feedback
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, Serialize)]
pub struct ReplicaState {
/// last known lsn received by replica
pub last_received_lsn: Lsn, // None means we don't know
@@ -355,6 +357,18 @@ pub enum TimelineError {
UninitialinzedPgVersion(TenantTimelineId),
}
// Convert to HTTP API error.
impl From<TimelineError> for ApiError {
fn from(te: TimelineError) -> ApiError {
match te {
TimelineError::NotFound(ttid) => {
ApiError::NotFound(anyhow!("timeline {} not found", ttid))
}
_ => ApiError::InternalServerError(anyhow!("{}", te)),
}
}
}
/// Timeline struct manages lifecycle (creation, deletion, restore) of a safekeeper timeline.
/// It also holds SharedState and provides mutually exclusive access to it.
pub struct Timeline {
@@ -381,7 +395,7 @@ pub struct Timeline {
cancellation_rx: watch::Receiver<bool>,
/// Directory where timeline state is stored.
timeline_dir: PathBuf,
pub timeline_dir: PathBuf,
}
impl Timeline {
@@ -518,7 +532,7 @@ impl Timeline {
/// Register compute connection, starting timeline-related activity if it is
/// not running yet.
pub fn on_compute_connect(&self) -> Result<()> {
pub async fn on_compute_connect(&self) -> Result<()> {
if self.is_cancelled() {
bail!(TimelineError::Cancelled(self.ttid));
}
@@ -532,7 +546,7 @@ impl Timeline {
// Wake up wal backup launcher, if offloading not started yet.
if is_wal_backup_action_pending {
// Can fail only if channel to a static thread got closed, which is not normal at all.
self.wal_backup_launcher_tx.blocking_send(self.ttid)?;
self.wal_backup_launcher_tx.send(self.ttid).await?;
}
Ok(())
}
@@ -549,6 +563,11 @@ impl Timeline {
// Wake up wal backup launcher, if it is time to stop the offloading.
if is_wal_backup_action_pending {
// Can fail only if channel to a static thread got closed, which is not normal at all.
//
// Note: this is blocking_send because on_compute_disconnect is called in Drop, there is
// no async Drop and we use current thread runtimes. With current thread rt spawning
// task in drop impl is racy, as thread along with runtime might finish before the task.
// This should be switched send.await when/if we go to full async.
self.wal_backup_launcher_tx.blocking_send(self.ttid)?;
}
Ok(())
@@ -588,38 +607,6 @@ impl Timeline {
self.write_shared_state().wal_backup_attend()
}
/// Returns full timeline info, required for the metrics. If the timeline is
/// not active, returns None instead.
pub fn info_for_metrics(&self) -> Option<FullTimelineInfo> {
if self.is_cancelled() {
return None;
}
let state = self.write_shared_state();
if state.active {
Some(FullTimelineInfo {
ttid: self.ttid,
replicas: state
.replicas
.iter()
.filter_map(|r| r.as_ref())
.copied()
.collect(),
wal_backup_active: state.wal_backup_active,
timeline_is_active: state.active,
num_computes: state.num_computes,
last_removed_segno: state.last_removed_segno,
epoch_start_lsn: state.sk.epoch_start_lsn,
mem_state: state.sk.inmem.clone(),
persisted_state: state.sk.state.clone(),
flush_lsn: state.sk.wal_store.flush_lsn(),
wal_storage: state.sk.wal_store.get_metrics(),
})
} else {
None
}
}
/// Returns commit_lsn watch channel.
pub fn get_commit_lsn_watch_rx(&self) -> watch::Receiver<Lsn> {
self.commit_lsn_watch_rx.clone()
@@ -784,6 +771,62 @@ impl Timeline {
shared_state.last_removed_segno = horizon_segno;
Ok(())
}
/// Returns full timeline info, required for the metrics. If the timeline is
/// not active, returns None instead.
pub fn info_for_metrics(&self) -> Option<FullTimelineInfo> {
if self.is_cancelled() {
return None;
}
let state = self.write_shared_state();
if state.active {
Some(FullTimelineInfo {
ttid: self.ttid,
replicas: state
.replicas
.iter()
.filter_map(|r| r.as_ref())
.copied()
.collect(),
wal_backup_active: state.wal_backup_active,
timeline_is_active: state.active,
num_computes: state.num_computes,
last_removed_segno: state.last_removed_segno,
epoch_start_lsn: state.sk.epoch_start_lsn,
mem_state: state.sk.inmem.clone(),
persisted_state: state.sk.state.clone(),
flush_lsn: state.sk.wal_store.flush_lsn(),
wal_storage: state.sk.wal_store.get_metrics(),
})
} else {
None
}
}
/// Returns in-memory timeline state to build a full debug dump.
pub fn memory_dump(&self) -> debug_dump::Memory {
let state = self.write_shared_state();
let (write_lsn, write_record_lsn, flush_lsn, file_open) =
state.sk.wal_store.internal_state();
debug_dump::Memory {
is_cancelled: self.is_cancelled(),
peers_info_len: state.peers_info.0.len(),
replicas: state.replicas.clone(),
wal_backup_active: state.wal_backup_active,
active: state.active,
num_computes: state.num_computes,
last_removed_segno: state.last_removed_segno,
epoch_start_lsn: state.sk.epoch_start_lsn,
mem_state: state.sk.inmem.clone(),
write_lsn,
write_record_lsn,
flush_lsn,
file_open,
}
}
}
/// Deletes directory and it's contents. Returns false if directory does not exist.

View File

@@ -5,7 +5,7 @@
use crate::safekeeper::ServerInfo;
use crate::timeline::{Timeline, TimelineError};
use crate::SafeKeeperConf;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{bail, Context, Result};
use once_cell::sync::Lazy;
use serde::Serialize;
use std::collections::HashMap;
@@ -50,11 +50,11 @@ impl GlobalTimelinesState {
}
/// Get timeline from the map. Returns error if timeline doesn't exist.
fn get(&self, ttid: &TenantTimelineId) -> Result<Arc<Timeline>> {
fn get(&self, ttid: &TenantTimelineId) -> Result<Arc<Timeline>, TimelineError> {
self.timelines
.get(ttid)
.cloned()
.ok_or_else(|| anyhow!(TimelineError::NotFound(*ttid)))
.ok_or(TimelineError::NotFound(*ttid))
}
}
@@ -159,9 +159,19 @@ impl GlobalTimelines {
Ok(())
}
/// Get the number of timelines in the map.
pub fn timelines_count() -> usize {
TIMELINES_STATE.lock().unwrap().timelines.len()
}
/// Get the global safekeeper config.
pub fn get_global_config() -> SafeKeeperConf {
TIMELINES_STATE.lock().unwrap().get_conf().clone()
}
/// Create a new timeline with the given id. If the timeline already exists, returns
/// an existing timeline.
pub fn create(
pub async fn create(
ttid: TenantTimelineId,
server_info: ServerInfo,
commit_lsn: Lsn,
@@ -189,28 +199,20 @@ impl GlobalTimelines {
// Take a lock and finish the initialization holding this mutex. No other threads
// can interfere with creation after we will insert timeline into the map.
let mut shared_state = timeline.write_shared_state();
{
let mut shared_state = timeline.write_shared_state();
// We can get a race condition here in case of concurrent create calls, but only
// in theory. create() will return valid timeline on the next try.
TIMELINES_STATE
.lock()
.unwrap()
.try_insert(timeline.clone())?;
// We can get a race condition here in case of concurrent create calls, but only
// in theory. create() will return valid timeline on the next try.
TIMELINES_STATE
.lock()
.unwrap()
.try_insert(timeline.clone())?;
// Write the new timeline to the disk and start background workers.
// Bootstrap is transactional, so if it fails, the timeline will be deleted,
// and the state on disk should remain unchanged.
match timeline.bootstrap(&mut shared_state) {
Ok(_) => {
// We are done with bootstrap, release the lock, return the timeline.
drop(shared_state);
timeline
.wal_backup_launcher_tx
.blocking_send(timeline.ttid)?;
Ok(timeline)
}
Err(e) => {
// Write the new timeline to the disk and start background workers.
// Bootstrap is transactional, so if it fails, the timeline will be deleted,
// and the state on disk should remain unchanged.
if let Err(e) = timeline.bootstrap(&mut shared_state) {
// Note: the most likely reason for bootstrap failure is that the timeline
// directory already exists on disk. This happens when timeline is corrupted
// and wasn't loaded from disk on startup because of that. We want to preserve
@@ -222,29 +224,33 @@ impl GlobalTimelines {
// Timeline failed to bootstrap, it cannot be used. Remove it from the map.
TIMELINES_STATE.lock().unwrap().timelines.remove(&ttid);
Err(e)
return Err(e);
}
// We are done with bootstrap, release the lock, return the timeline.
// {} block forces release before .await
}
timeline.wal_backup_launcher_tx.send(timeline.ttid).await?;
Ok(timeline)
}
/// Get a timeline from the global map. If it's not present, it doesn't exist on disk,
/// or was corrupted and couldn't be loaded on startup. Returned timeline is always valid,
/// i.e. loaded in memory and not cancelled.
pub fn get(ttid: TenantTimelineId) -> Result<Arc<Timeline>> {
pub fn get(ttid: TenantTimelineId) -> Result<Arc<Timeline>, TimelineError> {
let res = TIMELINES_STATE.lock().unwrap().get(&ttid);
match res {
Ok(tli) => {
if tli.is_cancelled() {
anyhow::bail!(TimelineError::Cancelled(ttid));
return Err(TimelineError::Cancelled(ttid));
}
Ok(tli)
}
Err(e) => Err(e),
_ => res,
}
}
/// Returns all timelines. This is used for background timeline proccesses.
/// Returns all timelines. This is used for background timeline processes.
pub fn get_all() -> Vec<Arc<Timeline>> {
let global_lock = TIMELINES_STATE.lock().unwrap();
global_lock

View File

@@ -191,7 +191,7 @@ async fn wal_backup_launcher_main_loop(
.map(|c| GenericRemoteStorage::from_config(c).expect("failed to create remote storage"))
});
// Presense in this map means launcher is aware s3 offloading is needed for
// Presence in this map means launcher is aware s3 offloading is needed for
// the timeline, but task is started only if it makes sense for to offload
// from this safekeeper.
let mut tasks: HashMap<TenantTimelineId, WalBackupTimelineEntry> = HashMap::new();
@@ -467,7 +467,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize
pub async fn read_object(
file_path: &RemotePath,
offset: u64,
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead>>> {
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>> {
let storage = REMOTE_STORAGE
.get()
.context("Failed to get remote storage")?

View File

@@ -2,50 +2,70 @@
//! WAL service listens for client connections and
//! receive WAL from wal_proposer and send it to WAL receivers
//!
use regex::Regex;
use std::net::{TcpListener, TcpStream};
use std::thread;
use anyhow::{Context, Result};
use postgres_backend::QueryError;
use std::{future, thread};
use tokio::net::TcpStream;
use tracing::*;
use utils::postgres_backend_async::QueryError;
use crate::handler::SafekeeperPostgresHandler;
use crate::SafeKeeperConf;
use utils::postgres_backend::{AuthType, PostgresBackend};
use postgres_backend::{AuthType, PostgresBackend};
/// Accept incoming TCP connections and spawn them into a background thread.
pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! {
loop {
match listener.accept() {
Ok((socket, peer_addr)) => {
debug!("accepted connection from {}", peer_addr);
let conf = conf.clone();
pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.context("create runtime")
// todo catch error in main thread
.expect("failed to create runtime");
let _ = thread::Builder::new()
.name("WAL service thread".into())
.spawn(move || {
if let Err(err) = handle_socket(socket, conf) {
error!("connection handler exited: {}", err);
}
})
.unwrap();
runtime
.block_on(async move {
// Tokio's from_std won't do this for us, per its comment.
pg_listener.set_nonblocking(true)?;
let listener = tokio::net::TcpListener::from_std(pg_listener)?;
let mut connection_count: ConnectionCount = 0;
loop {
match listener.accept().await {
Ok((socket, peer_addr)) => {
debug!("accepted connection from {}", peer_addr);
let conf = conf.clone();
let conn_id = issue_connection_id(&mut connection_count);
let _ = thread::Builder::new()
.name("WAL service thread".into())
.spawn(move || {
if let Err(err) = handle_socket(socket, conf, conn_id) {
error!("connection handler exited: {}", err);
}
})
.unwrap();
}
Err(e) => error!("Failed to accept connection: {}", e),
}
}
Err(e) => error!("Failed to accept connection: {}", e),
}
}
}
// Get unique thread id (Rust internal), with ThreadId removed for shorter printing
fn get_tid() -> u64 {
let tids = format!("{:?}", thread::current().id());
let r = Regex::new(r"ThreadId\((\d+)\)").unwrap();
let caps = r.captures(&tids).unwrap();
caps.get(1).unwrap().as_str().parse().unwrap()
#[allow(unreachable_code)] // hint compiler the closure return type
Ok::<(), anyhow::Error>(())
})
.expect("listener failed")
}
/// This is run by `thread_main` above, inside a background thread.
///
fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> {
let _enter = info_span!("", tid = ?get_tid()).entered();
fn handle_socket(
socket: TcpStream,
conf: SafeKeeperConf,
conn_id: ConnectionId,
) -> Result<(), QueryError> {
let _enter = info_span!("", cid = %conn_id).entered();
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let local = tokio::task::LocalSet::new();
socket.set_nodelay(true)?;
@@ -53,10 +73,23 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryErr
None => AuthType::Trust,
Some(_) => AuthType::NeonJWT,
};
let mut conn_handler = SafekeeperPostgresHandler::new(conf);
let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?;
// libpq replication protocol between safekeeper and replicas/pagers
pgbackend.run(&mut conn_handler)?;
let mut conn_handler = SafekeeperPostgresHandler::new(conf, conn_id);
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
// libpq protocol between safekeeper and walproposer / pageserver
// We don't use shutdown.
local.block_on(
&runtime,
pgbackend.run(&mut conn_handler, future::pending::<()>),
)?;
Ok(())
}
/// Unique WAL service connection ids are logged in spans for observability.
pub type ConnectionId = u32;
pub type ConnectionCount = u32;
pub fn issue_connection_id(count: &mut ConnectionCount) -> ConnectionId {
*count = count.wrapping_add(1);
*count
}

View File

@@ -165,6 +165,16 @@ impl PhysicalStorage {
})
}
/// Get all known state of the storage.
pub fn internal_state(&self) -> (Lsn, Lsn, Lsn, bool) {
(
self.write_lsn,
self.write_record_lsn,
self.flush_record_lsn,
self.file.is_some(),
)
}
/// Call fdatasync if config requires so.
fn fdatasync_file(&mut self, file: &mut File) -> Result<()> {
if !self.conf.no_sync {
@@ -461,7 +471,7 @@ pub struct WalReader {
timeline_dir: PathBuf,
wal_seg_size: usize,
pos: Lsn,
wal_segment: Option<Pin<Box<dyn AsyncRead>>>,
wal_segment: Option<Pin<Box<dyn AsyncRead + Send + Sync>>>,
// S3 will be used to read WAL if LSN is not available locally
enable_remote_read: bool,
@@ -528,7 +538,7 @@ impl WalReader {
}
/// Open WAL segment at the current position of the reader.
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead>>> {
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead + Send + Sync>>> {
let xlogoff = self.pos.segment_offset(self.wal_seg_size);
let segno = self.pos.segment_number(self.wal_seg_size);
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);

View File

@@ -1,126 +0,0 @@
import csv
import logging
import sys
import textwrap
import requests
import argparse
import json
class Client:
def __init__(self, endpoint) -> None:
self.endpoint = endpoint
def get(self, rel_url, **kwargs):
resp = requests.get(self.endpoint + rel_url, **kwargs)
try:
resp.raise_for_status()
except requests.exceptions.HTTPError:
print("API ERROR: " + resp.text)
raise
return resp.json()
def put(self, rel_url, **kwargs):
resp = requests.put(self.endpoint + rel_url, **kwargs)
try:
resp.raise_for_status()
except requests.exceptions.HTTPError:
print("API ERROR: " + resp.text)
raise
return resp.json()
class AppException(RuntimeError):
pass
def do_one(tenant, endpoint, merge_existing_with, check_tenant_exists=True):
global verbose
client = Client(endpoint)
if check_tenant_exists:
tenants = client.get(f"/v1/tenant")
matching_tenant = [ t for t in tenants if t['id'] == tenant ]
if len(matching_tenant) == 0:
raise AppException(f"no tenant {tenant} on pageserver {endpoint}")
elif len(matching_tenant) > 1:
raise AppException(f"multiple ({len(matching_tenant)}) tenants with id {tenant} on pageserver {endpoint}")
else:
pass
config = client.get(f"/v1/tenant/{tenant}/config")
def comparable_json(obj):
j = json.dumps(obj, indent=' ', sort_keys=True)
return textwrap.indent(j, ' ')
if verbose:
before = comparable_json(config)
print(f"BEFORE:\n{before}")
overrides = config['tenant_specific_overrides']
updated = {**overrides, **merge_existing_with}
client.put("/v1/tenant/config", json={**updated, "tenant_id": tenant})
if verbose:
new_config = client.get(f"/v1/tenant/{tenant}/config")
after = comparable_json(new_config)
print(f"AFTER:\n{after}")
def do_csv(csv_file, merge_existing_with):
succeeded = []
failed = []
for n, line in enumerate(csv.reader(csv_file)):
if n == 0:
# skip header row
continue
if len(line) != 2:
logging.warn(f"skipping line {n+1}: {line}")
continue
tenant_id = line[0]
pageserver = line[1]
try:
do_one(tenant_id, f"http://{pageserver}:9898", merge_existing_with, check_tenant_exists=False)
logging.info(f"succeeded to configure tenant {tenant_id}")
succeeded += [tenant_id]
except Exception as e:
logging.exception(f"failed to configure tenant {tenant_id}")
failed += [tenant_id]
print(json.dumps({
"succeeded": succeeded,
"failed": failed,
}, indent=' ', sort_keys=True))
verbose = False
def main():
global verbose
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
p = argparse.ArgumentParser()
p.add_argument("--merge-existing-with", type=str)
p.add_argument("--verbose", action='store_true')
subcommands = p.add_subparsers(dest="subcommand")
one_tenant_parser = subcommands.add_parser("one", help='change config of one tenant, specified via CLI flags')
one_tenant_parser.add_argument("--tenant", required=True)
one_tenant_parser.add_argument("--endpoint", type=str, default='http://localhost:9898')
csv_parser = subcommands.add_parser("csv", help='batch reconfigure tenants specified in a csv file')
csv_parser.add_argument('csv_file', type=argparse.FileType())
args = p.parse_args()
verbose = args.verbose
merge_existing_with = {}
if args.merge_existing_with is not None:
merge_existing_with = json.loads(args.merge_existing_with)
assert isinstance(merge_existing_with, dict)
({
'one': lambda: do_one(args.tenant, args.endpoint, merge_existing_with),
'csv': lambda: do_csv(args.csv_file, merge_existing_with),
}[args.subcommand])()
if __name__ == '__main__':
main()

View File

@@ -354,29 +354,26 @@ class NeonBenchmarker:
"""
Fetch the "cumulative # of bytes written" metric from the pageserver
"""
metric_name = r'libmetrics_disk_io_bytes_total{io_operation="write"}'
return self.get_int_counter_value(pageserver, metric_name)
return self.get_int_counter_value(
pageserver, "libmetrics_disk_io_bytes_total", {"io_operation": "write"}
)
def get_peak_mem(self, pageserver: NeonPageserver) -> int:
"""
Fetch the "maxrss" metric from the pageserver
"""
metric_name = r"libmetrics_maxrss_kb"
return self.get_int_counter_value(pageserver, metric_name)
return self.get_int_counter_value(pageserver, "libmetrics_maxrss_kb")
def get_int_counter_value(self, pageserver: NeonPageserver, metric_name: str) -> int:
def get_int_counter_value(
self,
pageserver: NeonPageserver,
metric_name: str,
label_filters: Optional[Dict[str, str]] = None,
) -> int:
"""Fetch the value of given int counter from pageserver metrics."""
# TODO: If we start to collect more of the prometheus metrics in the
# performance test suite like this, we should refactor this to load and
# parse all the metrics into a more convenient structure in one go.
#
# The metric should be an integer, as it's a number of bytes. But in general
# all prometheus metrics are floats. So to be pedantic, read it as a float
# and round to integer.
all_metrics = pageserver.http_client().get_metrics()
matches = re.search(rf"^{metric_name} (\S+)$", all_metrics, re.MULTILINE)
assert matches, f"metric {metric_name} not found"
return int(round(float(matches.group(1))))
sample = all_metrics.query_one(metric_name, label_filters)
return int(round(sample.value))
def get_timeline_size(
self, repo_dir: Path, tenant_id: TenantId, timeline_id: TimelineId

View File

@@ -144,12 +144,12 @@ class NeonCompare(PgCompare):
"size", timeline_size / (1024 * 1024), "MB", report=MetricReport.LOWER_IS_BETTER
)
params = f'{{tenant_id="{self.tenant}",timeline_id="{self.timeline}"}}'
metric_filters = {"tenant_id": str(self.tenant), "timeline_id": str(self.timeline)}
total_files = self.zenbenchmark.get_int_counter_value(
self.env.pageserver, "pageserver_created_persistent_files_total" + params
self.env.pageserver, "pageserver_created_persistent_files_total", metric_filters
)
total_bytes = self.zenbenchmark.get_int_counter_value(
self.env.pageserver, "pageserver_written_persistent_bytes_total" + params
self.env.pageserver, "pageserver_written_persistent_bytes_total", metric_filters
)
self.zenbenchmark.record(
"data_uploaded", total_bytes / (1024 * 1024), "MB", report=MetricReport.LOWER_IS_BETTER

View File

@@ -13,7 +13,8 @@ class Metrics:
self.metrics = defaultdict(list)
self.name = name
def query_all(self, name: str, filter: Dict[str, str]) -> List[Sample]:
def query_all(self, name: str, filter: Optional[Dict[str, str]] = None) -> List[Sample]:
filter = filter or {}
res = []
for sample in self.metrics[name]:
try:

View File

@@ -14,6 +14,7 @@ import tempfile
import textwrap
import time
import uuid
from collections import defaultdict
from contextlib import closing, contextmanager
from dataclasses import dataclass, field
from enum import Flag, auto
@@ -28,7 +29,6 @@ import asyncpg
import backoff # type: ignore
import boto3
import jwt
import prometheus_client
import psycopg2
import pytest
import requests
@@ -36,7 +36,7 @@ from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from fixtures.log_helper import log
from fixtures.metrics import parse_metrics
from fixtures.metrics import Metrics, parse_metrics
from fixtures.types import Lsn, TenantId, TimelineId
from fixtures.utils import (
ATTACHMENT_NAME_REGEX,
@@ -45,7 +45,6 @@ from fixtures.utils import (
get_self_dir,
subprocess_capture,
)
from prometheus_client.parser import text_string_to_metric_families
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
@@ -1436,22 +1435,27 @@ class PageserverHttpClient(requests.Session):
assert completed["successful_download_count"] > 0
return completed
def get_metrics(self) -> str:
def get_metrics_str(self) -> str:
"""You probably want to use get_metrics() instead."""
res = self.get(f"http://localhost:{self.port}/metrics")
self.verbose_error(res)
return res.text
def get_timeline_metric(self, tenant_id: TenantId, timeline_id: TimelineId, metric_name: str):
raw = self.get_metrics()
family: List[prometheus_client.Metric] = list(text_string_to_metric_families(raw))
[metric] = [m for m in family if m.name == metric_name]
[sample] = [
s
for s in metric.samples
if s.labels["tenant_id"] == str(tenant_id)
and s.labels["timeline_id"] == str(timeline_id)
]
return sample.value
def get_metrics(self) -> Metrics:
res = self.get_metrics_str()
return parse_metrics(res)
def get_timeline_metric(
self, tenant_id: TenantId, timeline_id: TimelineId, metric_name: str
) -> float:
metrics = self.get_metrics()
return metrics.query_one(
metric_name,
filter={
"tenant_id": str(tenant_id),
"timeline_id": str(timeline_id),
},
).value
def get_remote_timeline_client_metric(
self,
@@ -1461,7 +1465,7 @@ class PageserverHttpClient(requests.Session):
file_kind: str,
op_kind: str,
) -> Optional[float]:
metrics = parse_metrics(self.get_metrics(), "pageserver")
metrics = self.get_metrics()
matches = metrics.query_all(
name=metric_name,
filter={
@@ -1480,14 +1484,16 @@ class PageserverHttpClient(requests.Session):
assert len(matches) < 2, "above filter should uniquely identify metric"
return value
def get_metric_value(self, name: str) -> Optional[str]:
def get_metric_value(
self, name: str, filter: Optional[Dict[str, str]] = None
) -> Optional[float]:
metrics = self.get_metrics()
relevant = [line for line in metrics.splitlines() if line.startswith(name)]
if len(relevant) == 0:
results = metrics.query_all(name, filter=filter)
if not results:
log.info(f'could not find metric "{name}"')
return None
assert len(relevant) == 1
return relevant[0].lstrip(name).strip()
assert len(results) == 1, f"metric {name} with given filters is not unique, got: {results}"
return results[0].value
def layer_map_info(
self,
@@ -1516,6 +1522,11 @@ class PageserverHttpClient(requests.Session):
assert res.status_code == 200
def evict_all_layers(self, tenant_id: TenantId, timeline_id: TimelineId):
info = self.layer_map_info(tenant_id, timeline_id)
for layer in info.historic_layers:
self.evict_layer(tenant_id, timeline_id, layer.layer_file_name)
@dataclass
class TenantConfig:
@@ -1551,6 +1562,14 @@ class LayerMapInfo:
return info
def kind_count(self) -> Dict[str, int]:
counts: Dict[str, int] = defaultdict(int)
for inmem_layer in self.in_memory_layers:
counts[inmem_layer.kind] += 1
for hist_layer in self.historic_layers:
counts[hist_layer.kind] += 1
return counts
@dataclass
class InMemoryLayerInfo:
@@ -1567,7 +1586,7 @@ class InMemoryLayerInfo:
)
@dataclass
@dataclass(frozen=True)
class HistoricLayerInfo:
kind: str
layer_file_name: str
@@ -1669,7 +1688,7 @@ class AbstractNeonCli(abc.ABC):
timeout=timeout,
)
if not res.returncode:
log.info(f"Run success: {res.stdout}")
log.info(f"Run {res.args} success: {res.stdout}")
elif check_return_code:
# this way command output will be in recorded and shown in CI in failure message
msg = f"""\
@@ -2049,8 +2068,10 @@ class NeonPageserver(PgProtocol):
".*Connection aborted: connection error: error communicating with the server: Broken pipe.*",
".*Connection aborted: connection error: error communicating with the server: Transport endpoint is not connected.*",
".*Connection aborted: connection error: error communicating with the server: Connection reset by peer.*",
# FIXME: replication patch for tokio_postgres regards any but CopyDone/CopyData message in CopyBoth stream as unexpected
".*Connection aborted: connection error: unexpected message from server*",
".*kill_and_wait_impl.*: wait successful.*",
".*Replication stream finished: db error: ERROR: Socket IO error: end streaming to Some.*",
".*Replication stream finished: db error:.*ending streaming to Some*",
".*query handler for 'pagestream.*failed: Broken pipe.*", # pageserver notices compute shut down
".*query handler for 'pagestream.*failed: Connection reset by peer.*", # pageserver notices compute shut down
# safekeeper connection can fail with this, in the window between timeline creation
@@ -2988,6 +3009,13 @@ class SafekeeperHttpClient(requests.Session):
def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
def debug_dump(self, params: Dict[str, str] = {}) -> Dict[str, Any]:
res = self.get(f"http://localhost:{self.port}/v1/debug_dump", params=params)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
def timeline_create(
self, tenant_id: TenantId, timeline_id: TimelineId, pg_version: int, commit_lsn: Lsn
):
@@ -3463,6 +3491,14 @@ def wait_for_last_flush_lsn(
return wait_for_last_record_lsn(env.pageserver.http_client(), tenant, timeline, last_flush_lsn)
def wait_for_wal_insert_lsn(
env: NeonEnv, pg: Postgres, tenant: TenantId, timeline: TimelineId
) -> Lsn:
"""Wait for pageserver to catch up the latest flush LSN, returns the last observed lsn."""
last_flush_lsn = Lsn(pg.safe_psql("SELECT pg_current_wal_insert_lsn()")[0][0])
return wait_for_last_record_lsn(env.pageserver.http_client(), tenant, timeline, last_flush_lsn)
def fork_at_current_lsn(
env: NeonEnv,
pg: Postgres,
@@ -3508,3 +3544,23 @@ def wait_for_sk_commit_lsn_to_reach_remote_storage(
ps_http.timeline_checkpoint(tenant_id, timeline_id)
wait_for_upload(ps_http, tenant_id, timeline_id, lsn)
return lsn
def wait_for_upload_queue_empty(
pageserver: NeonPageserver, tenant_id: TenantId, timeline_id: TimelineId
):
ps_http = pageserver.http_client()
while True:
all_metrics = ps_http.get_metrics()
tl = all_metrics.query_all(
"pageserver_remote_timeline_client_calls_unfinished",
{
"tenant_id": str(tenant_id),
"timeline_id": str(timeline_id),
},
)
assert len(tl) > 0
log.info(f"upload queue for {tenant_id}/{timeline_id}: {tl}")
if all(m.value == 0 for m in tl):
return
time.sleep(0.2)

Some files were not shown because too many files have changed in this diff Show More