Compare commits

..

5 Commits

Author SHA1 Message Date
Conrad Ludgate
c0e1e1dd74 standardise logging 2025-06-25 15:54:26 +01:00
Conrad Ludgate
010cd34635 simplify error handling 2025-06-25 15:54:26 +01:00
Conrad Ludgate
11ffe4c86c remove unused error retries 2025-06-25 15:54:26 +01:00
Conrad Ludgate
d6a5085664 move h2::handshake outside of hypermechanism 2025-06-25 15:54:26 +01:00
Conrad Ludgate
16d9889a51 move authenticate outside of tokiomechanism 2025-06-25 15:54:26 +01:00
62 changed files with 816 additions and 1877 deletions

View File

@@ -32,7 +32,7 @@ permissions:
contents: read
jobs:
build-postgres:
build-pgxn:
if: |
inputs.pg_versions != '[]' || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
@@ -64,8 +64,8 @@ jobs:
id: cache_pg
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: pg-build-${{ matrix.postgres-version }}.tar
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-build-${{ matrix.postgres-version }}-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
path: pg_install/${{ matrix.postgres-version }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-${{ matrix.postgres-version }}-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Checkout submodule vendor/postgres-${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
@@ -89,24 +89,96 @@ jobs:
run: |
make postgres-${{ matrix.postgres-version }} -j$(sysctl -n hw.ncpu)
# `actions/download-artifact` doesn't preserve permissions:
# https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss
# It also doesn't support symbolic links.
- name: Build artifact tarball
- name: Build Neon Pg Ext ${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
tar cf pg-build-${{ matrix.postgres-version }}.tar --exclude="*.o" pg_install build
make "neon-pg-ext-${{ matrix.postgres-version }}" -j$(sysctl -n hw.ncpu)
- name: Upload "pg-build--${{ matrix.postgres-version }}" artifact
- name: Get postgres headers ${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
make postgres-headers-${{ matrix.postgres-version }} -j$(sysctl -n hw.ncpu)
- name: Upload "pg_install/${{ matrix.postgres-version }}" artifact
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: pg-build--${{ matrix.postgres-version }}
path: pg-build-${{ matrix.postgres-version }}.tar
name: pg_install--${{ matrix.postgres-version }}
path: pg_install/${{ matrix.postgres-version }}
# The artifact is supposed to be used by the next job in the same workflow,
# so theres no need to store it for too long.
retention-days: 1
make-all:
build-walproposer-lib:
if: |
contains(inputs.pg_versions, 'v17') || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
github.ref_name == 'main'
timeout-minutes: 30
runs-on: macos-15
needs: [build-pgxn]
env:
# Use release build only, to have less debug info around
# Hence keeping target/ (and general cache size) smaller
BUILD_TYPE: release
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- name: Checkout main repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set pg v17 for caching
id: pg_rev
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v17) | tee -a "${GITHUB_OUTPUT}"
- name: Download "pg_install/v17" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v17
path: pg_install/v17
- name: Cache walproposer-lib
id: cache_walproposer_lib
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: build/walproposer-lib
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-walproposer_lib-v17-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Checkout submodule vendor/postgres-v17
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
git submodule init vendor/postgres-v17
git submodule update --depth 1 --recursive
- name: Install build dependencies
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Build walproposer-lib (only for v17)
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run:
make walproposer-lib -j$(sysctl -n hw.ncpu)
- name: Upload "build/walproposer-lib" artifact
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: build--walproposer-lib
path: build/walproposer-lib
# The artifact is supposed to be used by the next job in the same workflow,
# so theres no need to store it for too long.
retention-days: 1
cargo-build:
if: |
inputs.pg_versions != '[]' || inputs.rebuild_rust_code || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
@@ -114,7 +186,7 @@ jobs:
github.ref_name == 'main'
timeout-minutes: 30
runs-on: macos-15
needs: [build-postgres]
needs: [build-pgxn, build-walproposer-lib]
env:
# Use release build only, to have less debug info around
# Hence keeping target/ (and general cache size) smaller
@@ -130,25 +202,41 @@ jobs:
with:
submodules: true
- name: Download "pg-build" artifacts
- name: Download "pg_install/v14" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
pattern: pg-build--*
merge-multiple: true
path: pg-build-tarballs
name: pg_install--v14
path: pg_install/v14
- name: Extract pg-build tarballs
run: find pg-build-tarballs -name "*.tar" -print0 | xargs -0 -n1 tar xvf
- name: Download "pg_install/v15" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v15
path: pg_install/v15
# Explicitly update the rust toolchain before running 'make'. The parallel make build can
# invoke 'cargo build' more than once in parallel, for different crates. That's OK, 'cargo'
# does its own locking to prevent concurrent builds from stepping on each other's
# toes. However, it will first try to update the toolchain, and that step is not locked the
# same way. To avoid two toolchain updates running in parallel and stepping on each other's
# toes, ensure that the toolchain is up-to-date beforehand.
- name: Update rust toolchain
- name: Download "pg_install/v16" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v16
path: pg_install/v16
- name: Download "pg_install/v17" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v17
path: pg_install/v17
- name: Download "build/walproposer-lib" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: build--walproposer-lib
path: build/walproposer-lib
# `actions/download-artifact` doesn't preserve permissions:
# https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss
- name: Make pg_install/v*/bin/* executable
run: |
rustup update
chmod +x pg_install/v*/bin/*
- name: Cache cargo deps
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
@@ -169,12 +257,8 @@ jobs:
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
# Build the neon-specific postgres extensions, and all the Rust bits.
#
# PostgreSQL was already built and cached separately. If the caching
# didn't work for some reason, this would build Postgres too.
- name: Build all
run: BUILD_TYPE=release make -j$(sysctl -n hw.ncpu) all
- name: Run cargo build
run: cargo build --all --release -j$(sysctl -n hw.ncpu)
- name: Check that no warnings are produced
run: ./run_clippy.sh

3
Cargo.lock generated
View File

@@ -4408,7 +4408,6 @@ dependencies = [
"postgres_backend",
"postgres_ffi_types",
"postgres_versioninfo",
"posthog_client_lite",
"rand 0.8.5",
"remote_storage",
"reqwest",
@@ -6816,7 +6815,6 @@ dependencies = [
"hex",
"http-utils",
"humantime",
"humantime-serde",
"hyper 0.14.30",
"itertools 0.10.5",
"json-structural-diff",
@@ -6827,7 +6825,6 @@ dependencies = [
"pageserver_api",
"pageserver_client",
"postgres_connection",
"posthog_client_lite",
"rand 0.8.5",
"regex",
"reqwest",

View File

@@ -7,8 +7,8 @@ POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install/
# CARGO_BUILD_FLAGS: Extra flags to pass to `cargo build`. `--locked`
# and `--features testing` are popular examples.
#
# CARGO_PROFILE: Set to override the cargo profile to use. By default,
# it is derived from BUILD_TYPE.
# CARGO_PROFILE: You can also set to override the cargo profile to
# use. By default, it is derived from BUILD_TYPE.
# All intermediate build artifacts are stored here.
BUILD_DIR := build
@@ -147,11 +147,7 @@ postgres-configure-v15: $(BUILD_DIR)/v15/config.status
.PHONY: postgres-configure-v14
postgres-configure-v14: $(BUILD_DIR)/v14/config.status
# Install just the PostgreSQL header files into $(POSTGRES_INSTALL_DIR)/<version>/include
#
# This is implicitly included in the 'postgres-%' rule, but this can be handy if you
# want to just install the headers without building PostgreSQL, e.g. for building
# extensions.
# Install the PostgreSQL header files into $(POSTGRES_INSTALL_DIR)/<version>/include
.PHONY: postgres-headers-%
postgres-headers-%: postgres-configure-%
+@echo "Installing PostgreSQL $* headers"

View File

@@ -12,7 +12,6 @@ use std::{env, fs};
use anyhow::{Context, bail};
use clap::ValueEnum;
use pageserver_api::config::PostHogConfig;
use pem::Pem;
use postgres_backend::AuthType;
use reqwest::{Certificate, Url};
@@ -214,8 +213,6 @@ pub struct NeonStorageControllerConf {
pub timeline_safekeeper_count: Option<i64>,
pub posthog_config: Option<PostHogConfig>,
pub kick_secondary_downloads: Option<bool>,
}
@@ -248,7 +245,6 @@ impl Default for NeonStorageControllerConf {
use_https_safekeeper_api: false,
use_local_compute_notifications: true,
timeline_safekeeper_count: None,
posthog_config: None,
kick_secondary_downloads: None,
}
}

View File

@@ -642,18 +642,6 @@ impl StorageController {
args.push(format!("--timeline-safekeeper-count={sk_cnt}"));
}
let mut envs = vec![
("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
];
if let Some(posthog_config) = &self.config.posthog_config {
envs.push((
"POSTHOG_CONFIG".to_string(),
serde_json::to_string(posthog_config)?,
));
}
println!("Starting storage controller");
background_process::start_process(
@@ -661,7 +649,10 @@ impl StorageController {
&instance_dir,
&self.env.storage_controller_bin(),
args,
envs,
vec![
("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
],
background_process::InitialPidFile::Create(self.pid_file(start_args.instance_id)),
&start_args.start_timeout,
|| async {

View File

@@ -1,396 +0,0 @@
# Memo: Endpoint Persistent Unlogged Files Storage
Created on 2024-11-05
Implemented on N/A
## Summary
A design for a storage system that allows storage of files required to make
Neon's Endpoints have a better experience at or after a reboot.
## Motivation
Several systems inside PostgreSQL (and Neon) need some persistent storage for
optimal workings across reboots and restarts, but still work without.
Examples are the query-level statistics files of `pg_stat_statements` in
`pg_stat/pg_stat_statements.stat`, and `pg_prewarm`'s `autoprewarm.blocks`.
We need a storage system that can store and manage these files for each
Endpoint, without necessarily granting users access to an unlimited storage
device.
## Goals
- Store known files for Endpoints with reasonable persistence.
_Data loss in this service, while annoying and bad for UX, won't lose any
customer's data._
## Non Goals (if relevant)
- This storage system does not need branching, file versioning, or other such
features. The files are as ephemeral to the timeline of the data as the
Endpoints that host the data.
- This storage system does not need to store _all_ user files, only 'known'
user files.
- This storage system does not need to be hosted fully inside Computes.
_Instead, this will be a separate component similar to Pageserver,
SafeKeeper, the S3 proxy used for dynamically loaded extensions, etc._
## Impacted components
- Compute needs new code to load and store these files in its lifetime.
- Control Plane needs to consider this new storage system when signalling
the deletion of an Endpoint, Timeline, or Tenant.
- Control Plane needs to consider this new storage system when it resets
or re-assigns an endpoint's timeline/branch state.
A new service is created: the Endpoint Persistent Unlogged Files Storage
service. This could be integrated in e.g. Pageserver or Control Plane, or a
separately hosted service.
## Proposed implementation
Endpoint-related data files are managed by a newly designed service (which
optionally is integrated in an existing service like Pageserver or Control
Plane), which stores data directly into S3 or any blob storage of choice.
Upon deletion of the Endpoint, or reassignment of the endpoint to a different
branch, this ephemeral data is dropped: the data stored may not match the
state of the branch's data after reassignment, and on endpoint deletion the
data won't have any use to the user.
Compute gets credentials (JWT token with Tenant, Timeline & Endpoint claims)
which it can use to authenticate to this new service and retrieve and store
data associated with this endpoint. This limited scope reduces leaks of data
across endpoints and timeline resets, and limits the ability of endpoints to
mess with other endpoints' data.
The path of this endpoint data in S3 is initially as follows:
s3://<regional-epufs-bucket>/
tenants/
<hex-tenant-id>/
tenants/
<hex-timeline-id>/
endpoints/
<endpoint-id>/
pgdata/
<file_path_in_pgdatadir>
For other blob storages an equivalent or similar path can be constructed.
### Reliability, failure modes and corner cases (if relevant)
Reliability is important, but not critical to the workings of Neon. The data
stored in this service will, when lost, reduce performance, but won't be a
cause of permanent data loss - only operational metadata is stored.
Most, if not all, blob storage services have sufficiently high persistence
guarantees to cater our need for persistence and uptime. The only concern with
blob storages is that the access latency is generally higher than local disk,
but for the object types stored (cache state, ...) I don't think this will be
much of an issue.
### Interaction/Sequence diagram (if relevant)
In these diagrams you can replace S3 with any persistent storage device of
choice, but S3 is chosen as representative name: The well-known and short name
of AWS' blob storage. Azure Blob Storage should work too, but it has a much
longer name making it less practical for the diagrams.
Write data:
```http
POST /tenants/<tenant-id>/timelines/<tl-id>/endpoints/<endpoint-id>/pgdata/<the-pgdata-path>
Host: epufs.svc.neon.local
<<<
200 OK
{
"version": "<opaque>", # opaque file version token, changes when the file contents change
"size": <bytes>,
}
```
```mermaid
sequenceDiagram
autonumber
participant co as Compute
participant ep as EPUFS
participant s3 as Blob Storage
co-->ep: Connect with credentials
co->>+ep: Store Unlogged Persistent File
opt is authenticated
ep->>s3: Write UPF to S3
end
ep->>-co: OK / Failure / Auth Failure
co-->ep: Cancel connection
```
Read data: (optional with cache-relevant request parameters, e.g. If-Modified-Since)
```http
GET /tenants/<tenant-id>/timelines/<tl-id>/endpoints/<endpoint-id>/pgdata/<the-pgdata-path>
Host: epufs.svc.neon.local
<<<
200 OK
<file data>
```
```mermaid
sequenceDiagram
autonumber
participant co as Compute
participant ep as EPUFS
participant s3 as Blob Storage
co->>+ep: Read Unlogged Persistent File
opt is authenticated
ep->>+s3: Request UPF from storage
s3->>-ep: Receive UPF from storage
end
ep->>-co: OK(response) / Failure(storage, auth, ...)
```
Compute Startup:
```mermaid
sequenceDiagram
autonumber
participant co as Compute
participant ps as Pageserver
participant ep as EPUFS
participant es as Extension server
note over co: Bind endpoint ep-xxx
par Get basebackup
co->>+ps: Request basebackup @ LSN
ps-)ps: Construct basebackup
ps->>-co: Receive basebackup TAR @ LSN
and Get startup-critical Unlogged Persistent Files
co->>+ep: Get all UPFs of endpoint ep-xxx
ep-)ep: Retrieve and gather all UPFs
ep->>-co: TAR of UPFs
and Get startup-critical extensions
loop For every startup-critical extension
co->>es: Get critical extension
es->>co: Receive critical extension
end
end
note over co: Start compute
```
CPlane ops:
```http
DELETE /tenants/<tenant-id>/timelines/<timeline-id>/endpoints/<endpoint-id>
Host: epufs.svc.neon.local
<<<
200 OK
{
"tenant": "<tenant-id>",
"timeline": "<timeline-id>",
"endpoint": "<endpoint-id>",
"deleted": {
"files": <count>,
"bytes": <count>,
},
}
```
```http
DELETE /tenants/<tenant-id>/timelines/<timeline-id>
Host: epufs.svc.neon.local
<<<
200 OK
{
"tenant": "<tenant-id>",
"timeline": "<timeline-id>",
"deleted": {
"files": <count>,
"bytes": <count>,
},
}
```
```http
DELETE /tenants/<tenant-id>
Host: epufs.svc.neon.local
<<<
200 OK
{
"tenant": "<tenant-id>",
"deleted": {
"files": <count>,
"bytes": <count>,
},
}
```
```mermaid
sequenceDiagram
autonumber
participant cp as Control Plane
participant ep as EPUFS
participant s3 as Blob Storage
alt Tenant deleted
cp-)ep: Tenant deleted
loop For every object associated with removed tenant
ep->>s3: Remove data of deleted tenant from Storage
end
opt
ep-)cp: Tenant cleanup complete
end
alt Timeline deleted
cp-)ep: Timeline deleted
loop For every object associated with removed timeline
ep->>s3: Remove data of deleted timeline from Storage
end
opt
ep-)cp: Timeline cleanup complete
end
else Endpoint reassigned or removed
cp->>+ep: Endpoint reassigned
loop For every object associated with reassigned/removed endpoint
ep->>s3: Remove data from Storage
end
ep->>-cp: Cleanup complete
end
```
### Scalability (if relevant)
Provisionally: As this service is going to be part of compute startup, this
service should be able to quickly respond to all requests. Therefore this
service is deployed to every AZ we host Computes in, and Computes communicate
(generally) only to the EPUFS endpoint of the AZ they're hosted in.
Local caching of frequently restarted endpoints' data or metadata may be
needed for best performance. However, due to the regional nature of stored
data but zonal nature of the service deployment, we should be careful when we
implement any local caching, as it is possible that computes in AZ 1 will
update data originally written and thus cached by AZ 2. Cache version tests
and invalidation is therefore required if we want to roll out caching to this
service, which is too broad a scope for an MVC. This is why caching is left
out of scope for this RFC, and should be considered separately after this RFC
is implemented.
### Security implications (if relevant)
This service must be able to authenticate users at least by Tenant ID,
Timeline ID and Endpoint ID. This will use the existing JWT infrastructure of
Compute, which will be upgraded to the extent needed to support Timeline- and
Endpoint-based claims.
The service requires unlimited access to (a prefix of) a blob storage bucket,
and thus must be hosted outside the Compute VM sandbox.
A service that generates pre-signed request URLs for Compute to download the
data from that URL is likely problematic, too: Compute would be able to write
unlimited data to the bucket, or exfiltrate this signed URL to get read/write
access to specific objects in this bucket, which would still effectively give
users access to the S3 bucket (but with improved access logging).
There may be a use case for transferring data associated with one endpoint to
another endpoint (e.g. to make one endpoint warm its caches with the state of
another endpoint), but that's not currently in scope, and specific needs may
be solved through out-of-line communication of data or pre-signed URLs.
### Unresolved questions (if relevant)
Caching of files is not in the implementation scope of the document, but
should at some future point be considered to maximize performance.
## Alternative implementation (if relevant)
Several ideas have come up to solve this issue:
### Use AUXfile
One prevalent idea was to WAL-log the files using our AUXfile mechanism.
Benefits:
+ We already have this storage mechanism
Demerits:
- It isn't available on read replicas
- Additional WAL will be consumed during shutdown and after the shutdown
checkpoint, which needs PG modifications to work without panics.
- It increases the data we need to manage in our versioned storage, thus
causing higher storage costs with higher retention due to duplication at
the storage layer.
### Sign URLs for read/write operations, instead of proxying them
Benefits:
+ The service can be implemented with a much reduced IO budget
Demerits:
- Users could get access to these signed credentials
- Not all blob storage services may implement URL signing
### Give endpoints each their own directly accessed block volume
Benefits:
+ Easier to integrate for PostgreSQL
Demerits:
- Little control on data size and contents
- Potentially problematic as we'd need to store data all across the pgdata
directory.
- EBS is not a good candidate
- Attaches in 10s of seconds, if not more; i.e. too cold to start
- Shared EBS volumes are a no-go, as you'd have to schedule the endpoint
with users of the same EBS volumes, which can't work with VM migration
- EBS storage costs are very high (>80$/kilotenant when using a
volume/tenant)
- EBS volumes can't be mounted across AZ boundaries
- Bucket per endpoint is unfeasible
- S3 buckets are priced at $20/month per 1k, which we could better spend
on developers.
- Allocating service accounts takes time (100s of ms), and service accounts
are a limited resource, too; so they're not a good candidate to allocate
on a per-endpoint basis.
- Giving credentials limited to prefix has similar issues as the pre-signed
URL approach.
- Bucket DNS lookup will fill DNS caches and put pressure on DNS lookup
much more than our current systems would.
- Volumes bound by hypervisor are unlikely
- This requires significant investment and increased software on the
hypervisor.
- It is unclear if we can attach volumes after boot, i.e. for pooled
instances.
### Put the files into a table
Benefits:
+ Mostly already available in PostgreSQL
Demerits:
- Uses WAL
- Can't be used after shutdown checkpoint
- Needs a RW endpoint, and table & catalog access to write to this data
- Gets hit with DB size limitations
- Depending on user acces:
- Inaccessible:
The user doesn't have control over database size caused by
these systems.
- Accessible:
The user can corrupt these files and cause the system to crash while
user-corrupted files are present, thus increasing on-call overhead.
## Definition of Done (if relevant)
This project is done if we have:
- One S3 bucket equivalent per region, which stores this per-endpoint data.
- A new service endpoint in at least every AZ, which indirectly grants
endpoints access to the data stored for these endpoints in these buckets.
- Compute writes & reads temp-data at shutdown and startup, respectively, for
at least the pg_prewarm or lfc_prewarm state files.
- Cleanup of endpoint data is triggered when the endpoint is deleted or is
detached from its current timeline.

View File

@@ -19,7 +19,6 @@ byteorder.workspace = true
utils.workspace = true
postgres_ffi_types.workspace = true
postgres_versioninfo.workspace = true
posthog_client_lite.workspace = true
enum-map.workspace = true
strum.workspace = true
strum_macros.workspace = true

View File

@@ -4,7 +4,6 @@ use camino::Utf8PathBuf;
mod tests;
use const_format::formatcp;
use posthog_client_lite::PostHogClientConfig;
pub const DEFAULT_PG_LISTEN_PORT: u16 = 64000;
pub const DEFAULT_PG_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_PG_LISTEN_PORT}");
pub const DEFAULT_HTTP_LISTEN_PORT: u16 = 9898;
@@ -64,66 +63,25 @@ impl Display for NodeMetadata {
}
}
/// PostHog integration config. This is used in pageserver, storcon, and neon_local.
/// Ensure backward compatibility when adding new fields.
/// PostHog integration config.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PostHogConfig {
/// PostHog project ID
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub project_id: Option<String>,
pub project_id: String,
/// Server-side (private) API key
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub server_api_key: Option<String>,
pub server_api_key: String,
/// Client-side (public) API key
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub client_api_key: Option<String>,
pub client_api_key: String,
/// Private API URL
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub private_api_url: Option<String>,
pub private_api_url: String,
/// Public API URL
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub public_api_url: Option<String>,
/// Refresh interval for the feature flag spec.
/// The storcon will push the feature flag spec to the pageserver. If the pageserver does not receive
/// the spec for `refresh_interval`, it will fetch the spec from the PostHog API.
#[serde(default)]
pub public_api_url: String,
/// Refresh interval for the feature flag spec
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "humantime_serde")]
pub refresh_interval: Option<Duration>,
}
impl PostHogConfig {
pub fn try_into_posthog_config(self) -> Result<PostHogClientConfig, &'static str> {
let Some(project_id) = self.project_id else {
return Err("project_id is required");
};
let Some(server_api_key) = self.server_api_key else {
return Err("server_api_key is required");
};
let Some(client_api_key) = self.client_api_key else {
return Err("client_api_key is required");
};
let Some(private_api_url) = self.private_api_url else {
return Err("private_api_url is required");
};
let Some(public_api_url) = self.public_api_url else {
return Err("public_api_url is required");
};
Ok(PostHogClientConfig {
project_id,
server_api_key,
client_api_key,
private_api_url,
public_api_url,
})
}
}
/// `pageserver.toml`
///
/// We use serde derive with `#[serde(default)]` to generate a deserializer
@@ -409,9 +367,6 @@ pub struct BasebackupCacheConfig {
// TODO(diko): support max_entry_size_bytes.
// pub max_entry_size_bytes: u64,
pub max_size_entries: usize,
/// Size of the channel used to send prepare requests to the basebackup cache worker.
/// If exceeded, new prepare requests will be dropped.
pub prepare_channel_size: usize,
}
impl Default for BasebackupCacheConfig {
@@ -421,7 +376,6 @@ impl Default for BasebackupCacheConfig {
max_total_size_bytes: 1024 * 1024 * 1024, // 1 GiB
// max_entry_size_bytes: 16 * 1024 * 1024, // 16 MiB
max_size_entries: 1000,
prepare_channel_size: 100,
}
}
}

View File

@@ -1,22 +1,17 @@
//! A background loop that fetches feature flags from PostHog and updates the feature store.
use std::{
sync::Arc,
time::{Duration, SystemTime},
};
use std::{sync::Arc, time::Duration};
use arc_swap::ArcSwap;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info_span};
use crate::{
CaptureEvent, FeatureStore, LocalEvaluationResponse, PostHogClient, PostHogClientConfig,
};
use crate::{CaptureEvent, FeatureStore, PostHogClient, PostHogClientConfig};
/// A background loop that fetches feature flags from PostHog and updates the feature store.
pub struct FeatureResolverBackgroundLoop {
posthog_client: PostHogClient,
feature_store: ArcSwap<(SystemTime, Arc<FeatureStore>)>,
feature_store: ArcSwap<FeatureStore>,
cancel: CancellationToken,
}
@@ -24,35 +19,11 @@ impl FeatureResolverBackgroundLoop {
pub fn new(config: PostHogClientConfig, shutdown_pageserver: CancellationToken) -> Self {
Self {
posthog_client: PostHogClient::new(config),
feature_store: ArcSwap::new(Arc::new((
SystemTime::UNIX_EPOCH,
Arc::new(FeatureStore::new()),
))),
feature_store: ArcSwap::new(Arc::new(FeatureStore::new())),
cancel: shutdown_pageserver,
}
}
/// Update the feature store with a new feature flag spec bypassing the normal refresh loop.
pub fn update(&self, spec: String) -> anyhow::Result<()> {
let resp: LocalEvaluationResponse = serde_json::from_str(&spec)?;
self.update_feature_store_nofail(resp, "http_propagate");
Ok(())
}
fn update_feature_store_nofail(&self, resp: LocalEvaluationResponse, source: &'static str) {
let project_id = self.posthog_client.config.project_id.parse::<u64>().ok();
match FeatureStore::new_with_flags(resp.flags, project_id) {
Ok(feature_store) => {
self.feature_store
.store(Arc::new((SystemTime::now(), Arc::new(feature_store))));
tracing::info!("Feature flag updated from {}", source);
}
Err(e) => {
tracing::warn!("Cannot process feature flag spec from {}: {}", source, e);
}
}
}
pub fn spawn(
self: Arc<Self>,
handle: &tokio::runtime::Handle,
@@ -76,17 +47,6 @@ impl FeatureResolverBackgroundLoop {
_ = ticker.tick() => {}
_ = cancel.cancelled() => break
}
{
let last_update = this.feature_store.load().0;
if let Ok(elapsed) = last_update.elapsed() {
if elapsed < refresh_period {
tracing::debug!(
"Skipping feature flag refresh because it's too soon"
);
continue;
}
}
}
let resp = match this
.posthog_client
.get_feature_flags_local_evaluation()
@@ -98,7 +58,16 @@ impl FeatureResolverBackgroundLoop {
continue;
}
};
this.update_feature_store_nofail(resp, "refresh_loop");
let project_id = this.posthog_client.config.project_id.parse::<u64>().ok();
match FeatureStore::new_with_flags(resp.flags, project_id) {
Ok(feature_store) => {
this.feature_store.store(Arc::new(feature_store));
tracing::info!("Feature flag updated");
}
Err(e) => {
tracing::warn!("Cannot process feature flag spec: {}", e);
}
}
}
tracing::info!("PostHog feature resolver stopped");
}
@@ -123,6 +92,6 @@ impl FeatureResolverBackgroundLoop {
}
pub fn feature_store(&self) -> Arc<FeatureStore> {
self.feature_store.load().1.clone()
self.feature_store.load_full()
}
}

View File

@@ -544,8 +544,17 @@ impl PostHogClient {
self.config.server_api_key.starts_with("phs_")
}
/// Get the raw JSON spec, same as `get_feature_flags_local_evaluation` but without parsing.
pub async fn get_feature_flags_local_evaluation_raw(&self) -> anyhow::Result<String> {
/// Fetch the feature flag specs from the server.
///
/// This is unfortunately an undocumented API at:
/// - <https://posthog.com/docs/api/feature-flags#get-api-projects-project_id-feature_flags-local_evaluation>
/// - <https://posthog.com/docs/feature-flags/local-evaluation>
///
/// The handling logic in [`FeatureStore`] mostly follows the Python API implementation.
/// See `_compute_flag_locally` in <https://github.com/PostHog/posthog-python/blob/master/posthog/client.py>
pub async fn get_feature_flags_local_evaluation(
&self,
) -> anyhow::Result<LocalEvaluationResponse> {
// BASE_URL/api/projects/:project_id/feature_flags/local_evaluation
// with bearer token of self.server_api_key
// OR
@@ -579,22 +588,7 @@ impl PostHogClient {
body
));
}
Ok(body)
}
/// Fetch the feature flag specs from the server.
///
/// This is unfortunately an undocumented API at:
/// - <https://posthog.com/docs/api/feature-flags#get-api-projects-project_id-feature_flags-local_evaluation>
/// - <https://posthog.com/docs/feature-flags/local-evaluation>
///
/// The handling logic in [`FeatureStore`] mostly follows the Python API implementation.
/// See `_compute_flag_locally` in <https://github.com/PostHog/posthog-python/blob/master/posthog/client.py>
pub async fn get_feature_flags_local_evaluation(
&self,
) -> Result<LocalEvaluationResponse, anyhow::Error> {
let raw = self.get_feature_flags_local_evaluation_raw().await?;
Ok(serde_json::from_str(&raw)?)
Ok(serde_json::from_str(&body)?)
}
/// Capture an event. This will only be used to report the feature flag usage back to PostHog, though

View File

@@ -1,12 +1,13 @@
use std::net::IpAddr;
use postgres_protocol2::message::backend::Message;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::config::{Host, SslMode};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::connect_tls::connect_tls;
@@ -46,13 +47,7 @@ where
{
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
let raw = connect_raw(stream, config).await?;
let socket_config = SocketConfig {
host_addr,
@@ -61,24 +56,46 @@ where
connect_timeout: config.connect_timeout,
};
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
client_tx,
client_rx,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
Ok((client, connection))
Ok(raw.into_managed_conn(socket_config, config.ssl_mode))
}
impl<S, T> RawConnection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
pub fn into_managed_conn(
self,
socket_config: SocketConfig,
ssl_mode: SslMode,
) -> (Client, Connection<S, T>) {
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = self;
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
client_tx,
client_rx,
socket_config,
ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
(client, connection)
}
}

View File

@@ -844,13 +844,4 @@ impl Client {
.await
.map_err(Error::ReceiveBody)
}
pub async fn update_feature_flag_spec(&self, spec: String) -> Result<()> {
let uri = format!("{}/v1/feature_flag_spec", self.mgmt_api_endpoint);
self.request(Method::POST, uri, spec)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
}

View File

@@ -110,19 +110,6 @@ message GetBaseBackupRequest {
bool replica = 2;
// If true, include relation files in the base backup. Mainly for debugging and tests.
bool full = 3;
// Compression algorithm to use. Base backups send a compressed payload instead of using gRPC
// compression, so that we can cache compressed backups on the server.
BaseBackupCompression compression = 4;
}
// Base backup compression algorithms.
enum BaseBackupCompression {
// Unknown algorithm. Used when clients send an unsupported algorithm.
BASE_BACKUP_COMPRESSION_UNKNOWN = 0;
// No compression.
BASE_BACKUP_COMPRESSION_NONE = 1;
// GZIP compression.
BASE_BACKUP_COMPRESSION_GZIP = 2;
}
// Base backup response chunk, returned as an ordered stream.

View File

@@ -95,6 +95,7 @@ impl Client {
if let Some(compression) = compression {
// TODO: benchmark this (including network latency).
// TODO: consider enabling compression by default.
client = client
.accept_compressed(compression)
.send_compressed(compression);

View File

@@ -191,21 +191,15 @@ pub struct GetBaseBackupRequest {
pub replica: bool,
/// If true, include relation files in the base backup. Mainly for debugging and tests.
pub full: bool,
/// Compression algorithm to use. Base backups send a compressed payload instead of using gRPC
/// compression, so that we can cache compressed backups on the server.
pub compression: BaseBackupCompression,
}
impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
type Error = ProtocolError;
fn try_from(pb: proto::GetBaseBackupRequest) -> Result<Self, Self::Error> {
Ok(Self {
impl From<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
fn from(pb: proto::GetBaseBackupRequest) -> Self {
Self {
lsn: (pb.lsn != 0).then_some(Lsn(pb.lsn)),
replica: pb.replica,
full: pb.full,
compression: pb.compression.try_into()?,
})
}
}
}
@@ -215,55 +209,10 @@ impl From<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
lsn: request.lsn.unwrap_or_default().0,
replica: request.replica,
full: request.full,
compression: request.compression.into(),
}
}
}
/// Base backup compression algorithm.
#[derive(Clone, Copy, Debug)]
pub enum BaseBackupCompression {
None,
Gzip,
}
impl TryFrom<proto::BaseBackupCompression> for BaseBackupCompression {
type Error = ProtocolError;
fn try_from(pb: proto::BaseBackupCompression) -> Result<Self, Self::Error> {
match pb {
proto::BaseBackupCompression::Unknown => Err(ProtocolError::invalid("compression", pb)),
proto::BaseBackupCompression::None => Ok(Self::None),
proto::BaseBackupCompression::Gzip => Ok(Self::Gzip),
}
}
}
impl TryFrom<i32> for BaseBackupCompression {
type Error = ProtocolError;
fn try_from(compression: i32) -> Result<Self, Self::Error> {
proto::BaseBackupCompression::try_from(compression)
.map_err(|_| ProtocolError::invalid("compression", compression))
.and_then(Self::try_from)
}
}
impl From<BaseBackupCompression> for proto::BaseBackupCompression {
fn from(compression: BaseBackupCompression) -> Self {
match compression {
BaseBackupCompression::None => Self::None,
BaseBackupCompression::Gzip => Self::Gzip,
}
}
}
impl From<BaseBackupCompression> for i32 {
fn from(compression: BaseBackupCompression) -> Self {
proto::BaseBackupCompression::from(compression).into()
}
}
pub type GetBaseBackupResponseChunk = Bytes;
impl TryFrom<proto::GetBaseBackupResponseChunk> for GetBaseBackupResponseChunk {

View File

@@ -317,7 +317,6 @@ impl Client for LibpqClient {
/// A gRPC Pageserver client.
struct GrpcClient {
inner: page_api::Client,
compression: page_api::BaseBackupCompression,
}
impl GrpcClient {
@@ -332,14 +331,10 @@ impl GrpcClient {
ttid.timeline_id,
ShardIndex::unsharded(),
None,
None, // NB: uses payload compression
compression.then_some(tonic::codec::CompressionEncoding::Zstd),
)
.await?;
let compression = match compression {
true => page_api::BaseBackupCompression::Gzip,
false => page_api::BaseBackupCompression::None,
};
Ok(Self { inner, compression })
Ok(Self { inner })
}
}
@@ -353,7 +348,6 @@ impl Client for GrpcClient {
lsn,
replica: false,
full: false,
compression: self.compression,
};
let stream = self.inner.get_base_backup(req).await?;
Ok(Box::pin(StreamReader::new(

View File

@@ -14,7 +14,6 @@ use std::fmt::Write as FmtWrite;
use std::time::{Instant, SystemTime};
use anyhow::{Context, anyhow};
use async_compression::tokio::write::GzipEncoder;
use bytes::{BufMut, Bytes, BytesMut};
use fail::fail_point;
use pageserver_api::key::{Key, rel_block_to_key};
@@ -26,7 +25,8 @@ use postgres_ffi::{
};
use postgres_ffi_types::constants::{DEFAULTTABLESPACE_OID, GLOBALTABLESPACE_OID};
use postgres_ffi_types::forknum::{INIT_FORKNUM, MAIN_FORKNUM};
use tokio::io::{self, AsyncWrite, AsyncWriteExt as _};
use tokio::io;
use tokio::io::AsyncWrite;
use tokio_tar::{Builder, EntryType, Header};
use tracing::*;
use utils::lsn::Lsn;
@@ -97,7 +97,6 @@ impl From<BasebackupError> for tonic::Status {
/// * When working without safekeepers. In this situation it is important to match the lsn
/// we are taking basebackup on with the lsn that is used in pageserver's walreceiver
/// to start the replication.
#[allow(clippy::too_many_arguments)]
pub async fn send_basebackup_tarball<'a, W>(
write: &'a mut W,
timeline: &'a Timeline,
@@ -105,7 +104,6 @@ pub async fn send_basebackup_tarball<'a, W>(
prev_lsn: Option<Lsn>,
full_backup: bool,
replica: bool,
gzip_level: Option<async_compression::Level>,
ctx: &'a RequestContext,
) -> Result<(), BasebackupError>
where
@@ -124,7 +122,7 @@ where
// prev_lsn value; that happens if the timeline was just branched from
// an old LSN and it doesn't have any WAL of its own yet. We will set
// prev_lsn to Lsn(0) if we cannot provide the correct value.
let (backup_prev, lsn) = if let Some(req_lsn) = req_lsn {
let (backup_prev, backup_lsn) = if let Some(req_lsn) = req_lsn {
// Backup was requested at a particular LSN. The caller should've
// already checked that it's a valid LSN.
@@ -145,7 +143,7 @@ where
};
// Consolidate the derived and the provided prev_lsn values
let prev_record_lsn = if let Some(provided_prev_lsn) = prev_lsn {
let prev_lsn = if let Some(provided_prev_lsn) = prev_lsn {
if backup_prev != Lsn(0) && backup_prev != provided_prev_lsn {
return Err(BasebackupError::Server(anyhow!(
"backup_prev {backup_prev} != provided_prev_lsn {provided_prev_lsn}"
@@ -157,55 +155,30 @@ where
};
info!(
"taking basebackup lsn={lsn}, prev_lsn={prev_record_lsn} \
(full_backup={full_backup}, replica={replica}, gzip={gzip_level:?})",
);
let span = info_span!("send_tarball", backup_lsn=%lsn);
let io_concurrency = IoConcurrency::spawn_from_conf(
timeline.conf.get_vectored_concurrent_io,
timeline
.gate
.enter()
.map_err(|_| BasebackupError::Shutdown)?,
"taking basebackup lsn={}, prev_lsn={} (full_backup={}, replica={})",
backup_lsn, prev_lsn, full_backup, replica
);
if let Some(gzip_level) = gzip_level {
let mut encoder = GzipEncoder::with_quality(write, gzip_level);
Basebackup {
ar: Builder::new_non_terminated(&mut encoder),
timeline,
lsn,
prev_record_lsn,
full_backup,
replica,
ctx,
io_concurrency,
}
let basebackup = Basebackup {
ar: Builder::new_non_terminated(write),
timeline,
lsn: backup_lsn,
prev_record_lsn: prev_lsn,
full_backup,
replica,
ctx,
io_concurrency: IoConcurrency::spawn_from_conf(
timeline.conf.get_vectored_concurrent_io,
timeline
.gate
.enter()
.map_err(|_| BasebackupError::Shutdown)?,
),
};
basebackup
.send_tarball()
.instrument(span)
.await?;
encoder
.shutdown()
.await
.map_err(|err| BasebackupError::Client(err, "gzip"))?;
} else {
Basebackup {
ar: Builder::new_non_terminated(write),
timeline,
lsn,
prev_record_lsn,
full_backup,
replica,
ctx,
io_concurrency,
}
.send_tarball()
.instrument(span)
.await?;
}
Ok(())
.instrument(info_span!("send_tarball", backup_lsn=%backup_lsn))
.await
}
/// This is short-living object only for the time of tarball creation,

View File

@@ -1,12 +1,13 @@
use std::{collections::HashMap, sync::Arc};
use anyhow::Context;
use async_compression::tokio::write::GzipEncoder;
use camino::{Utf8Path, Utf8PathBuf};
use metrics::core::{AtomicU64, GenericCounter};
use pageserver_api::{config::BasebackupCacheConfig, models::TenantState};
use tokio::{
io::{AsyncWriteExt, BufWriter},
sync::mpsc::{Receiver, Sender, error::TrySendError},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
};
use tokio_util::sync::CancellationToken;
use utils::{
@@ -19,8 +20,8 @@ use crate::{
basebackup::send_basebackup_tarball,
context::{DownloadBehavior, RequestContext},
metrics::{
BASEBACKUP_CACHE_ENTRIES, BASEBACKUP_CACHE_PREPARE, BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE,
BASEBACKUP_CACHE_READ, BASEBACKUP_CACHE_SIZE,
BASEBACKUP_CACHE_ENTRIES, BASEBACKUP_CACHE_PREPARE, BASEBACKUP_CACHE_READ,
BASEBACKUP_CACHE_SIZE,
},
task_mgr::TaskKind,
tenant::{
@@ -35,8 +36,8 @@ pub struct BasebackupPrepareRequest {
pub lsn: Lsn,
}
pub type BasebackupPrepareSender = Sender<BasebackupPrepareRequest>;
pub type BasebackupPrepareReceiver = Receiver<BasebackupPrepareRequest>;
pub type BasebackupPrepareSender = UnboundedSender<BasebackupPrepareRequest>;
pub type BasebackupPrepareReceiver = UnboundedReceiver<BasebackupPrepareRequest>;
#[derive(Clone)]
struct CacheEntry {
@@ -60,65 +61,40 @@ struct CacheEntry {
/// and ~1 RPS for get requests.
pub struct BasebackupCache {
data_dir: Utf8PathBuf,
config: Option<BasebackupCacheConfig>,
entries: std::sync::Mutex<HashMap<TenantTimelineId, CacheEntry>>,
prepare_sender: BasebackupPrepareSender,
read_hit_count: GenericCounter<AtomicU64>,
read_miss_count: GenericCounter<AtomicU64>,
read_err_count: GenericCounter<AtomicU64>,
prepare_skip_count: GenericCounter<AtomicU64>,
}
impl BasebackupCache {
/// Create a new BasebackupCache instance.
/// Also returns a BasebackupPrepareReceiver which is needed to start
/// the background task.
/// The cache is initialized from the data_dir in the background task.
/// The cache will return `None` for any get requests until the initialization is complete.
/// The background task is spawned separately using [`Self::spawn_background_task`]
/// to avoid a circular dependency between the cache and the tenant manager.
pub fn new(
/// Creates a BasebackupCache and spawns the background task.
/// The initialization of the cache is performed in the background and does not
/// block the caller. The cache will return `None` for any get requests until
/// initialization is complete.
pub fn spawn(
runtime_handle: &tokio::runtime::Handle,
data_dir: Utf8PathBuf,
config: Option<BasebackupCacheConfig>,
) -> (Arc<Self>, BasebackupPrepareReceiver) {
let chan_size = config.as_ref().map(|c| c.max_size_entries).unwrap_or(1);
let (prepare_sender, prepare_receiver) = tokio::sync::mpsc::channel(chan_size);
prepare_receiver: BasebackupPrepareReceiver,
tenant_manager: Arc<TenantManager>,
cancel: CancellationToken,
) -> Arc<Self> {
let cache = Arc::new(BasebackupCache {
data_dir,
config,
entries: std::sync::Mutex::new(HashMap::new()),
prepare_sender,
read_hit_count: BASEBACKUP_CACHE_READ.with_label_values(&["hit"]),
read_miss_count: BASEBACKUP_CACHE_READ.with_label_values(&["miss"]),
read_err_count: BASEBACKUP_CACHE_READ.with_label_values(&["error"]),
prepare_skip_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["skip"]),
});
(cache, prepare_receiver)
}
/// Spawns the background task.
/// The background task initializes the cache from the disk,
/// processes prepare requests, and cleans up outdated cache entries.
/// Noop if the cache is disabled (config is None).
pub fn spawn_background_task(
self: Arc<Self>,
runtime_handle: &tokio::runtime::Handle,
prepare_receiver: BasebackupPrepareReceiver,
tenant_manager: Arc<TenantManager>,
cancel: CancellationToken,
) {
if let Some(config) = self.config.clone() {
if let Some(config) = config {
let background = BackgroundTask {
c: self,
c: cache.clone(),
config,
tenant_manager,
@@ -133,45 +109,8 @@ impl BasebackupCache {
};
runtime_handle.spawn(background.run(prepare_receiver));
}
}
/// Send a basebackup prepare request to the background task.
/// The basebackup will be prepared asynchronously, it does not block the caller.
/// The request will be skipped if any cache limits are exceeded.
pub fn send_prepare(&self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, lsn: Lsn) {
let req = BasebackupPrepareRequest {
tenant_shard_id,
timeline_id,
lsn,
};
BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.inc();
let res = self.prepare_sender.try_send(req);
if let Err(e) = res {
BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.dec();
self.prepare_skip_count.inc();
match e {
TrySendError::Full(_) => {
// Basebackup prepares are pretty rare, normally we should not hit this.
tracing::info!(
tenant_id = %tenant_shard_id.tenant_id,
%timeline_id,
%lsn,
"Basebackup prepare channel is full, skipping the request"
);
}
TrySendError::Closed(_) => {
// Normal during shutdown, not critical.
tracing::info!(
tenant_id = %tenant_shard_id.tenant_id,
%timeline_id,
%lsn,
"Basebackup prepare channel is closed, skipping the request"
);
}
}
}
cache
}
/// Gets a basebackup entry from the cache.
@@ -184,10 +123,6 @@ impl BasebackupCache {
timeline_id: TimelineId,
lsn: Lsn,
) -> Option<tokio::fs::File> {
if !self.is_enabled() {
return None;
}
// Fast path. Check if the entry exists using the in-memory state.
let tti = TenantTimelineId::new(tenant_id, timeline_id);
if self.entries.lock().unwrap().get(&tti).map(|e| e.lsn) != Some(lsn) {
@@ -215,10 +150,6 @@ impl BasebackupCache {
}
}
pub fn is_enabled(&self) -> bool {
self.config.is_some()
}
// Private methods.
fn entry_filename(tenant_id: TenantId, timeline_id: TimelineId, lsn: Lsn) -> String {
@@ -436,7 +367,6 @@ impl BackgroundTask {
loop {
tokio::select! {
Some(req) = prepare_receiver.recv() => {
BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.dec();
if let Err(err) = self.prepare_basebackup(
req.tenant_shard_id,
req.timeline_id,
@@ -664,6 +594,13 @@ impl BackgroundTask {
let file = tokio::fs::File::create(entry_tmp_path).await?;
let mut writer = BufWriter::new(file);
let mut encoder = GzipEncoder::with_quality(
&mut writer,
// Level::Best because compression is not on the hot path of basebackup requests.
// The decompression is almost not affected by the compression level.
async_compression::Level::Best,
);
// We may receive a request before the WAL record is applied to the timeline.
// Wait for the requested LSN to be applied.
timeline
@@ -676,19 +613,17 @@ impl BackgroundTask {
.await?;
send_basebackup_tarball(
&mut writer,
&mut encoder,
timeline,
Some(req_lsn),
None,
false,
false,
// Level::Best because compression is not on the hot path of basebackup requests.
// The decompression is almost not affected by the compression level.
Some(async_compression::Level::Best),
&ctx,
)
.await?;
encoder.shutdown().await?;
writer.flush().await?;
writer.into_inner().sync_all().await?;

View File

@@ -569,10 +569,8 @@ fn start_pageserver(
pageserver::l0_flush::L0FlushGlobalState::new(conf.l0_flush.clone());
// Scan the local 'tenants/' directory and start loading the tenants
let (basebackup_cache, basebackup_prepare_receiver) = BasebackupCache::new(
conf.basebackup_cache_dir(),
conf.basebackup_cache_config.clone(),
);
let (basebackup_prepare_sender, basebackup_prepare_receiver) =
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
@@ -584,7 +582,7 @@ fn start_pageserver(
remote_storage: remote_storage.clone(),
deletion_queue_client,
l0_flush_global_state,
basebackup_cache: Arc::clone(&basebackup_cache),
basebackup_prepare_sender,
feature_resolver: feature_resolver.clone(),
},
shutdown_pageserver.clone(),
@@ -592,8 +590,10 @@ fn start_pageserver(
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
basebackup_cache.spawn_background_task(
let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),
conf.basebackup_cache_dir(),
conf.basebackup_cache_config.clone(),
basebackup_prepare_receiver,
Arc::clone(&tenant_manager),
shutdown_pageserver.child_token(),
@@ -806,6 +806,7 @@ fn start_pageserver(
} else {
None
},
basebackup_cache,
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for

View File

@@ -762,40 +762,4 @@ mod tests {
let result = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir);
assert_eq!(result.is_ok(), is_valid);
}
#[test]
fn test_config_posthog_config_is_valid() {
let input = r#"
control_plane_api = "http://localhost:6666"
[posthog_config]
server_api_key = "phs_AAA"
client_api_key = "phc_BBB"
project_id = "000"
private_api_url = "https://us.posthog.com"
public_api_url = "https://us.i.posthog.com"
"#;
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(input)
.expect("posthogconfig is valid");
let workdir = Utf8PathBuf::from("/nonexistent");
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect("parse_and_validate");
}
#[test]
fn test_config_posthog_incomplete_config_is_valid() {
let input = r#"
control_plane_api = "http://localhost:6666"
[posthog_config]
server_api_key = "phs_AAA"
private_api_url = "https://us.posthog.com"
public_api_url = "https://us.i.posthog.com"
"#;
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(input)
.expect("posthogconfig is valid");
let workdir = Utf8PathBuf::from("/nonexistent");
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect("parse_and_validate");
}
}

View File

@@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
use arc_swap::ArcSwap;
use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogEvaluationError,
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
};
use remote_storage::RemoteStorageKind;
@@ -31,13 +31,6 @@ impl FeatureResolver {
}
}
pub fn update(&self, spec: String) -> anyhow::Result<()> {
if let Some(inner) = &self.inner {
inner.update(spec)?;
}
Ok(())
}
pub fn spawn(
conf: &PageServerConf,
shutdown_pageserver: CancellationToken,
@@ -45,24 +38,16 @@ impl FeatureResolver {
) -> anyhow::Result<Self> {
// DO NOT block in this function: make it return as fast as possible to avoid startup delays.
if let Some(posthog_config) = &conf.posthog_config {
let posthog_client_config = match posthog_config.clone().try_into_posthog_config() {
Ok(config) => config,
Err(e) => {
tracing::warn!(
"invalid posthog config, skipping posthog integration: {}",
e
);
return Ok(FeatureResolver {
inner: None,
internal_properties: None,
force_overrides_for_testing: Arc::new(ArcSwap::new(Arc::new(
HashMap::new(),
))),
});
}
};
let inner =
FeatureResolverBackgroundLoop::new(posthog_client_config, shutdown_pageserver);
let inner = FeatureResolverBackgroundLoop::new(
PostHogClientConfig {
server_api_key: posthog_config.server_api_key.clone(),
client_api_key: posthog_config.client_api_key.clone(),
project_id: posthog_config.project_id.clone(),
private_api_url: posthog_config.private_api_url.clone(),
public_api_url: posthog_config.public_api_url.clone(),
},
shutdown_pageserver,
);
let inner = Arc::new(inner);
// The properties shared by all tenants on this pageserver.

View File

@@ -3743,20 +3743,6 @@ async fn force_override_feature_flag_for_testing_delete(
json_response(StatusCode::OK, ())
}
async fn update_feature_flag_spec(
mut request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
let body = json_request(&mut request).await?;
let state = get_state(&request);
state
.feature_resolver
.update(body)
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
/// Common functionality of all the HTTP API handlers.
///
/// - Adds a tracing span to each request (by `request_span`)
@@ -4142,8 +4128,5 @@ pub fn make_router(
.delete("/v1/feature_flag/:flag_key", |r| {
testing_api_handler("force override feature flag - delete", r, force_override_feature_flag_for_testing_delete)
})
.post("/v1/feature_flag_spec", |r| {
api_handler(r, update_feature_flag_spec)
})
.any(handler_404))
}

View File

@@ -4439,14 +4439,6 @@ pub(crate) static BASEBACKUP_CACHE_SIZE: Lazy<UIntGauge> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub(crate) static BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE: Lazy<UIntGauge> = Lazy::new(|| {
register_uint_gauge!(
"pageserver_basebackup_cache_prepare_queue_size",
"Number of requests in the basebackup prepare channel"
)
.expect("failed to define a metric")
});
static PAGESERVER_CONFIG_IGNORED_ITEMS: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_config_ignored_items",

View File

@@ -12,7 +12,8 @@ use std::task::{Context, Poll};
use std::time::{Duration, Instant, SystemTime};
use std::{io, str};
use anyhow::{Context as _, bail};
use anyhow::{Context as _, anyhow, bail};
use async_compression::tokio::write::GzipEncoder;
use bytes::{Buf as _, BufMut as _, BytesMut};
use futures::future::BoxFuture;
use futures::{FutureExt, Stream};
@@ -62,6 +63,7 @@ use utils::{failpoint_support, span_record};
use crate::auth::check_permission;
use crate::basebackup::{self, BasebackupError};
use crate::basebackup_cache::BasebackupCache;
use crate::config::PageServerConf;
use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
@@ -136,6 +138,7 @@ pub fn spawn(
perf_trace_dispatch: Option<Dispatch>,
tcp_listener: tokio::net::TcpListener,
tls_config: Option<Arc<rustls::ServerConfig>>,
basebackup_cache: Arc<BasebackupCache>,
) -> Listener {
let cancel = CancellationToken::new();
let libpq_ctx = RequestContext::todo_child(
@@ -157,6 +160,7 @@ pub fn spawn(
conf.pg_auth_type,
tls_config,
conf.page_service_pipelining.clone(),
basebackup_cache,
libpq_ctx,
cancel.clone(),
)
@@ -215,6 +219,7 @@ pub async fn libpq_listener_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
basebackup_cache: Arc<BasebackupCache>,
listener_ctx: RequestContext,
listener_cancel: CancellationToken,
) -> Connections {
@@ -258,6 +263,7 @@ pub async fn libpq_listener_main(
auth_type,
tls_config.clone(),
pipelining_config.clone(),
Arc::clone(&basebackup_cache),
connection_ctx,
connections_cancel.child_token(),
gate_guard,
@@ -300,6 +306,7 @@ async fn page_service_conn_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
basebackup_cache: Arc<BasebackupCache>,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate_guard: GateGuard,
@@ -365,6 +372,7 @@ async fn page_service_conn_main(
pipelining_config,
conf.get_vectored_concurrent_io,
perf_span_fields,
basebackup_cache,
connection_ctx,
cancel.clone(),
gate_guard,
@@ -418,6 +426,8 @@ struct PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
basebackup_cache: Arc<BasebackupCache>,
gate_guard: GateGuard,
}
@@ -903,6 +913,7 @@ impl PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
perf_span_fields: ConnectionPerfSpanFields,
basebackup_cache: Arc<BasebackupCache>,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate_guard: GateGuard,
@@ -916,6 +927,7 @@ impl PageServerHandler {
cancel,
pipelining_config,
get_vectored_concurrent_io,
basebackup_cache,
gate_guard,
}
}
@@ -2601,16 +2613,26 @@ impl PageServerHandler {
prev_lsn,
full_backup,
replica,
None,
&ctx,
)
.await?;
} else {
let mut writer = BufWriter::new(pgb.copyout_writer());
let cached = timeline
.get_cached_basebackup_if_enabled(lsn, prev_lsn, full_backup, replica, gzip)
.await;
let cached = {
// Basebackup is cached only for this combination of parameters.
if timeline.is_basebackup_cache_enabled()
&& gzip
&& lsn.is_some()
&& prev_lsn.is_none()
{
self.basebackup_cache
.get(tenant_id, timeline_id, lsn.unwrap())
.await
} else {
None
}
};
if let Some(mut cached) = cached {
from_cache = true;
@@ -2619,6 +2641,31 @@ impl PageServerHandler {
.map_err(|err| {
BasebackupError::Client(err, "handle_basebackup_request,cached,copy")
})?;
} else if gzip {
let mut encoder = GzipEncoder::with_quality(
&mut writer,
// NOTE using fast compression because it's on the critical path
// for compute startup. For an empty database, we get
// <100KB with this method. The Level::Best compression method
// gives us <20KB, but maybe we should add basebackup caching
// on compute shutdown first.
async_compression::Level::Fastest,
);
basebackup::send_basebackup_tarball(
&mut encoder,
&timeline,
lsn,
prev_lsn,
full_backup,
replica,
&ctx,
)
.await?;
// shutdown the encoder to ensure the gzip footer is written
encoder
.shutdown()
.await
.map_err(|e| QueryError::Disconnected(ConnectionError::Io(e)))?;
} else {
basebackup::send_basebackup_tarball(
&mut writer,
@@ -2627,11 +2674,6 @@ impl PageServerHandler {
prev_lsn,
full_backup,
replica,
// NB: using fast compression because it's on the critical path for compute
// startup. For an empty database, we get <100KB with this method. The
// Level::Best compression method gives us <20KB, but maybe we should add
// basebackup caching on compute shutdown first.
gzip.then_some(async_compression::Level::Fastest),
&ctx,
)
.await?;
@@ -3511,7 +3553,7 @@ impl proto::PageService for GrpcPageServiceHandler {
if timeline.is_archived() == Some(true) {
return Err(tonic::Status::failed_precondition("timeline is archived"));
}
let req: page_api::GetBaseBackupRequest = req.into_inner().try_into()?;
let req: page_api::GetBaseBackupRequest = req.into_inner().into();
span_record!(lsn=?req.lsn);
@@ -3537,50 +3579,20 @@ impl proto::PageService for GrpcPageServiceHandler {
let span = Span::current();
let (mut simplex_read, mut simplex_write) = tokio::io::simplex(CHUNK_SIZE);
let jh = tokio::spawn(async move {
let gzip_level = match req.compression {
page_api::BaseBackupCompression::None => None,
// NB: using fast compression because it's on the critical path for compute
// startup. For an empty database, we get <100KB with this method. The
// Level::Best compression method gives us <20KB, but maybe we should add
// basebackup caching on compute shutdown first.
page_api::BaseBackupCompression::Gzip => Some(async_compression::Level::Fastest),
};
// Check for a cached basebackup.
let cached = timeline
.get_cached_basebackup_if_enabled(
req.lsn,
None,
req.full,
req.replica,
gzip_level.is_some(),
)
.await;
let result = if let Some(mut cached) = cached {
// If we have a cached basebackup, send it.
tokio::io::copy(&mut cached, &mut simplex_write)
.await
.map(|_| ())
.map_err(|err| BasebackupError::Client(err, "cached,copy"))
} else {
basebackup::send_basebackup_tarball(
&mut simplex_write,
&timeline,
req.lsn,
None,
req.full,
req.replica,
gzip_level,
&ctx,
)
.instrument(span) // propagate request span
.await
};
simplex_write
.shutdown()
.await
.map_err(|err| BasebackupError::Client(err, "simplex_write"))?;
let result = basebackup::send_basebackup_tarball(
&mut simplex_write,
&timeline,
req.lsn,
None,
req.full,
req.replica,
&ctx,
)
.instrument(span) // propagate request span
.await;
simplex_write.shutdown().await.map_err(|err| {
BasebackupError::Server(anyhow!("simplex shutdown failed: {err}"))
})?;
result
});

View File

@@ -80,7 +80,7 @@ use self::timeline::uninit::{TimelineCreateGuard, TimelineExclusionError, Uninit
use self::timeline::{
EvictionTaskTenantState, GcCutoffs, TimelineDeleteProgress, TimelineResources, WaitLsnError,
};
use crate::basebackup_cache::BasebackupCache;
use crate::basebackup_cache::BasebackupPrepareSender;
use crate::config::PageServerConf;
use crate::context;
use crate::context::RequestContextBuilder;
@@ -162,7 +162,7 @@ pub struct TenantSharedResources {
pub remote_storage: GenericRemoteStorage,
pub deletion_queue_client: DeletionQueueClient,
pub l0_flush_global_state: L0FlushGlobalState,
pub basebackup_cache: Arc<BasebackupCache>,
pub basebackup_prepare_sender: BasebackupPrepareSender,
pub feature_resolver: FeatureResolver,
}
@@ -331,7 +331,7 @@ pub struct TenantShard {
deletion_queue_client: DeletionQueueClient,
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_cache: Arc<BasebackupCache>,
basebackup_prepare_sender: BasebackupPrepareSender,
/// Cached logical sizes updated updated on each [`TenantShard::gather_size_inputs`].
cached_logical_sizes: tokio::sync::Mutex<HashMap<(TimelineId, Lsn), u64>>,
@@ -1363,7 +1363,7 @@ impl TenantShard {
remote_storage,
deletion_queue_client,
l0_flush_global_state,
basebackup_cache,
basebackup_prepare_sender,
feature_resolver,
} = resources;
@@ -1380,7 +1380,7 @@ impl TenantShard {
remote_storage.clone(),
deletion_queue_client,
l0_flush_global_state,
basebackup_cache,
basebackup_prepare_sender,
feature_resolver,
));
@@ -4380,7 +4380,7 @@ impl TenantShard {
remote_storage: GenericRemoteStorage,
deletion_queue_client: DeletionQueueClient,
l0_flush_global_state: L0FlushGlobalState,
basebackup_cache: Arc<BasebackupCache>,
basebackup_prepare_sender: BasebackupPrepareSender,
feature_resolver: FeatureResolver,
) -> TenantShard {
assert!(!attached_conf.location.generation.is_none());
@@ -4485,7 +4485,7 @@ impl TenantShard {
ongoing_timeline_detach: std::sync::Mutex::default(),
gc_block: Default::default(),
l0_flush_global_state,
basebackup_cache,
basebackup_prepare_sender,
feature_resolver,
}
}
@@ -5414,7 +5414,7 @@ impl TenantShard {
pagestream_throttle_metrics: self.pagestream_throttle_metrics.clone(),
l0_compaction_trigger: self.l0_compaction_trigger.clone(),
l0_flush_global_state: self.l0_flush_global_state.clone(),
basebackup_cache: self.basebackup_cache.clone(),
basebackup_prepare_sender: self.basebackup_prepare_sender.clone(),
feature_resolver: self.feature_resolver.clone(),
}
}
@@ -6000,7 +6000,7 @@ pub(crate) mod harness {
) -> anyhow::Result<Arc<TenantShard>> {
let walredo_mgr = Arc::new(WalRedoManager::from(TestRedoManager));
let (basebackup_cache, _) = BasebackupCache::new(Utf8PathBuf::new(), None);
let (basebackup_requst_sender, _) = tokio::sync::mpsc::unbounded_channel();
let tenant = Arc::new(TenantShard::new(
TenantState::Attaching,
@@ -6018,7 +6018,7 @@ pub(crate) mod harness {
self.deletion_queue.new_client(),
// TODO: ideally we should run all unit tests with both configs
L0FlushGlobalState::new(L0FlushConfig::default()),
basebackup_cache,
basebackup_requst_sender,
FeatureResolver::new_disabled(),
));

View File

@@ -2891,18 +2891,14 @@ mod tests {
use std::collections::BTreeMap;
use std::sync::Arc;
use camino::Utf8PathBuf;
use storage_broker::BrokerClientChannel;
use tracing::Instrument;
use super::super::harness::TenantHarness;
use super::TenantsMap;
use crate::{
basebackup_cache::BasebackupCache,
tenant::{
TenantSharedResources,
mgr::{BackgroundPurges, TenantManager, TenantSlot},
},
use crate::tenant::{
TenantSharedResources,
mgr::{BackgroundPurges, TenantManager, TenantSlot},
};
#[tokio::test(start_paused = true)]
@@ -2928,7 +2924,9 @@ mod tests {
// Invoke remove_tenant_from_memory with a cleanup hook that blocks until we manually
// permit it to proceed: that will stick the tenant in InProgress
let (basebackup_cache, _) = BasebackupCache::new(Utf8PathBuf::new(), None);
let (basebackup_prepare_sender, _) = tokio::sync::mpsc::unbounded_channel::<
crate::basebackup_cache::BasebackupPrepareRequest,
>();
let tenant_manager = TenantManager {
tenants: std::sync::RwLock::new(TenantsMap::Open(tenants)),
@@ -2942,7 +2940,7 @@ mod tests {
l0_flush_global_state: crate::l0_flush::L0FlushGlobalState::new(
h.conf.l0_flush.clone(),
),
basebackup_cache,
basebackup_prepare_sender,
feature_resolver: crate::feature_resolver::FeatureResolver::new_disabled(),
},
cancel: tokio_util::sync::CancellationToken::new(),

View File

@@ -95,12 +95,12 @@ use super::storage_layer::{LayerFringe, LayerVisibilityHint, ReadableLayer};
use super::tasks::log_compaction_error;
use super::upload_queue::NotInitialized;
use super::{
AttachedTenantConf, GcError, HeatMapTimeline, MaybeOffloaded,
AttachedTenantConf, BasebackupPrepareSender, GcError, HeatMapTimeline, MaybeOffloaded,
debug_assert_current_span_has_tenant_and_timeline_id,
};
use crate::PERF_TRACE_TARGET;
use crate::aux_file::AuxFileSizeEstimator;
use crate::basebackup_cache::BasebackupCache;
use crate::basebackup_cache::BasebackupPrepareRequest;
use crate::config::PageServerConf;
use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
@@ -201,7 +201,7 @@ pub struct TimelineResources {
pub pagestream_throttle_metrics: Arc<crate::metrics::tenant_throttling::Pagestream>,
pub l0_compaction_trigger: Arc<Notify>,
pub l0_flush_global_state: l0_flush::L0FlushGlobalState,
pub basebackup_cache: Arc<BasebackupCache>,
pub basebackup_prepare_sender: BasebackupPrepareSender,
pub feature_resolver: FeatureResolver,
}
@@ -448,7 +448,7 @@ pub struct Timeline {
wait_lsn_log_slow: tokio::sync::Semaphore,
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_cache: Arc<BasebackupCache>,
basebackup_prepare_sender: BasebackupPrepareSender,
feature_resolver: FeatureResolver,
}
@@ -2500,37 +2500,6 @@ impl Timeline {
.unwrap_or(self.conf.default_tenant_conf.basebackup_cache_enabled)
}
/// Try to get a basebackup from the on-disk cache.
pub(crate) async fn get_cached_basebackup(&self, lsn: Lsn) -> Option<tokio::fs::File> {
self.basebackup_cache
.get(self.tenant_shard_id.tenant_id, self.timeline_id, lsn)
.await
}
/// Convenience method to attempt fetching a basebackup for the timeline if enabled and safe for
/// the given request parameters.
///
/// TODO: consider moving this onto GrpcPageServiceHandler once the libpq handler is gone.
pub async fn get_cached_basebackup_if_enabled(
&self,
lsn: Option<Lsn>,
prev_lsn: Option<Lsn>,
full: bool,
replica: bool,
gzip: bool,
) -> Option<tokio::fs::File> {
if !self.is_basebackup_cache_enabled() || !self.basebackup_cache.is_enabled() {
return None;
}
// We have to know which LSN to fetch the basebackup for.
let lsn = lsn?;
// We only cache gzipped, non-full basebackups for primary computes with automatic prev_lsn.
if prev_lsn.is_some() || full || replica || !gzip {
return None;
}
self.get_cached_basebackup(lsn).await
}
/// Prepare basebackup for the given LSN and store it in the basebackup cache.
/// The method is asynchronous and returns immediately.
/// The actual basebackup preparation is performed in the background
@@ -2552,8 +2521,17 @@ impl Timeline {
return;
}
self.basebackup_cache
.send_prepare(self.tenant_shard_id, self.timeline_id, lsn);
let res = self
.basebackup_prepare_sender
.send(BasebackupPrepareRequest {
tenant_shard_id: self.tenant_shard_id,
timeline_id: self.timeline_id,
lsn,
});
if let Err(e) = res {
// May happen during shutdown, it's not critical.
info!("Failed to send shutdown checkpoint: {e:#}");
}
}
}
@@ -3110,7 +3088,7 @@ impl Timeline {
wait_lsn_log_slow: tokio::sync::Semaphore::new(1),
basebackup_cache: resources.basebackup_cache,
basebackup_prepare_sender: resources.basebackup_prepare_sender,
feature_resolver: resources.feature_resolver,
};
@@ -4680,16 +4658,6 @@ impl Timeline {
mut layer_flush_start_rx: tokio::sync::watch::Receiver<(u64, Lsn)>,
ctx: &RequestContext,
) {
// Always notify waiters about the flush loop exiting since the loop might stop
// when the timeline hasn't been cancelled.
let scopeguard_rx = layer_flush_start_rx.clone();
scopeguard::defer! {
let (flush_counter, _) = *scopeguard_rx.borrow();
let _ = self
.layer_flush_done_tx
.send_replace((flush_counter, Err(FlushLayerError::Cancelled)));
}
// Subscribe to L0 delta layer updates, for compaction backpressure.
let mut watch_l0 = match self
.layers
@@ -4719,6 +4687,9 @@ impl Timeline {
let result = loop {
if self.cancel.is_cancelled() {
info!("dropping out of flush loop for timeline shutdown");
// Note: we do not bother transmitting into [`layer_flush_done_tx`], because
// anyone waiting on that will respect self.cancel as well: they will stop
// waiting at the same time we as drop out of this loop.
return;
}

View File

@@ -1295,8 +1295,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
if (iteration_hits != 0)
{
/* chunk offset (#
of pages) into the LFC file */
/* chunk offset (# of pages) into the LFC file */
off_t first_read_offset = (off_t) entry_offset * lfc_blocks_per_chunk;
int nwrite = iov_last_used - first_block_in_chunk_read;
/* offset of first IOV */
@@ -1314,6 +1313,16 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
lfc_disable("read");
return -1;
}
/*
* We successfully read the pages we know were valid when we
* started reading; now mark those pages as read
*/
for (int i = first_block_in_chunk_read; i < iov_last_used; i++)
{
if (BITMAP_ISSET(chunk_mask, i))
BITMAP_SET(mask, buf_offset + i);
}
}
/* Place entry to the head of LRU list */
@@ -1331,15 +1340,6 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
{
lfc_ctl->time_read += io_time_us;
inc_page_cache_read_wait(io_time_us);
/*
* We successfully read the pages we know were valid when we
* started reading; now mark those pages as read
*/
for (int i = first_block_in_chunk_read; i < iov_last_used; i++)
{
if (BITMAP_ISSET(chunk_mask, i))
BITMAP_SET(mask, buf_offset + i);
}
}
CriticalAssert(entry->access_count > 0);

View File

@@ -279,6 +279,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,

View File

@@ -236,6 +236,7 @@ pub(super) async fn task_main(
extra: None,
},
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}

View File

@@ -123,6 +123,12 @@ struct ProxyCliArgs {
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String,
@@ -149,31 +155,40 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Cancellation channel size (max queue size for redis kv client)
#[clap(long, default_value_t = 1024)]
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
#[clap(long, default_value_t = 8)]
cancellation_batch_size: usize,
/// redis url for plain authentication
#[clap(long, alias("redis-notifications"))]
redis_plain: Option<String>,
/// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
#[clap(long)]
redis_notifications: Option<String>,
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
#[clap(long, default_value = "irsa")]
redis_auth_type: String,
/// redis host for irsa authentication
/// redis host for streaming connections (might be different from the notifications host)
#[clap(long)]
redis_host: Option<String>,
/// redis port for irsa authentication
/// redis port for streaming connections (might be different from the notifications host)
#[clap(long)]
redis_port: Option<u16>,
/// redis cluster name for irsa authentication
/// redis cluster name, used in aws elasticache
#[clap(long)]
redis_cluster_name: Option<String>,
/// redis user_id for irsa authentication
/// redis user_id, used in aws elasticache
#[clap(long)]
redis_user_id: Option<String>,
/// aws region for irsa authentication
/// aws region to retrieve credentials
#[clap(long, default_value_t = String::new())]
aws_region: String,
/// cache for `project_info` (use `size=0` to disable)
@@ -185,12 +200,6 @@ struct ProxyCliArgs {
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// interval for backup metric collection
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
metric_backup_collection_interval: std::time::Duration,
@@ -203,7 +212,6 @@ struct ProxyCliArgs {
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
#[clap(long, default_value = "4194304")]
metric_backup_collection_chunk_size: usize,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -323,7 +331,7 @@ pub async fn run() -> anyhow::Result<()> {
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
}
info!("Using region: {}", args.aws_region);
let redis_client = configure_redis(&args).await?;
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
// Check that we can bind to address before further initialization
info!("Starting http on {}", args.http);
@@ -378,6 +386,13 @@ pub async fn run() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
RateBucketInfo::validate(redis_rps_limit)?;
let redis_kv_client = regional_redis_client
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
@@ -457,7 +472,6 @@ pub async fn run() -> anyhow::Result<()> {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
// maintenance tasks. these never return unless there's an error
@@ -481,17 +495,32 @@ pub async fn run() -> anyhow::Result<()> {
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
if let Some(client) = redis_client {
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
if let Some(mut redis_kv_client) = redis_kv_client {
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
@@ -516,12 +545,14 @@ pub async fn run() -> anyhow::Result<()> {
}
}
}
}
// listen for notifications of new projects/endpoints/branches
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
}
@@ -650,6 +681,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
@@ -811,18 +843,21 @@ fn build_auth_backend(
async fn configure_redis(
args: &ProxyCliArgs,
) -> anyhow::Result<Option<ConnectionWithCredentialsProvider>> {
) -> anyhow::Result<(
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
// TODO: untangle the config args
let redis_client = match &*args.redis_auth_type {
"plain" => match &args.redis_plain {
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
None => {
bail!("plain auth requires redis_plain to be set");
bail!("plain auth requires redis_notifications to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
"irsa" => match (&args.redis_host, args.redis_port) {
("irsa", _) => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.clone(),
@@ -846,12 +881,18 @@ async fn configure_redis(
bail!("redis-host and redis-port must be specified together");
}
},
auth_type => {
bail!("unknown auth type {auth_type:?} given")
_ => {
bail!("unknown auth type given");
}
};
Ok(redis_client)
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
};
Ok((regional_redis_client, redis_notifications_client))
}
#[cfg(test)]

View File

@@ -350,7 +350,7 @@ impl CancellationHandler {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: RawCancelToken,
pub cancel_token: RawCancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
}

View File

@@ -86,6 +86,14 @@ pub(crate) enum ConnectionError {
#[error("error acquiring resource permit: {0}")]
TooManyConnectionAttempts(#[from] ApiLockError),
#[cfg(test)]
#[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")]
TestError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
},
}
impl UserFacingError for ConnectionError {
@@ -96,6 +104,8 @@ impl UserFacingError for ConnectionError {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
#[cfg(test)]
ConnectionError::TestError { .. } => self.to_string(),
}
}
}
@@ -106,6 +116,8 @@ impl ReportableError for ConnectionError {
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
#[cfg(test)]
ConnectionError::TestError { kind, .. } => *kind,
}
}
}
@@ -252,6 +264,19 @@ impl AuthInfo {
.await?;
drop(pause);
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
pid = connection.process_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
let RawConnection {
stream: _,
parameters,
@@ -260,8 +285,6 @@ impl AuthInfo {
secret_key,
} = connection;
tracing::Span::current().record("pid", tracing::field::display(process_id));
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
@@ -288,6 +311,7 @@ impl ConnectInfo {
async fn connect_raw(
&self,
config: &ComputeConfig,
direct: bool,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
@@ -330,7 +354,7 @@ impl ConnectInfo {
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
tls::connect_tls(stream, self.ssl_mode, config, host, direct).await?,
)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
@@ -357,7 +381,7 @@ pub struct PostgresSettings {
pub struct ComputeConnection {
/// Socket connected to a compute node.
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub stream: MaybeRustlsStream,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
pub hostname: Host,
@@ -373,23 +397,12 @@ impl ConnectInfo {
ctx: &RequestContext,
aux: &MetricsAuxInfo,
config: &ComputeConfig,
direct: bool,
) -> Result<ComputeConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config).await?;
let (socket_addr, stream) = self.connect_raw(config, direct).await?;
drop(pause);
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.host,
self.ssl_mode,
ctx.get_proxy_latency(),
ctx.get_testodrome_id().unwrap_or_default(),
);
let connection = ComputeConnection {
stream,
socket_addr,

View File

@@ -11,8 +11,6 @@ use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
#[error(transparent)]
Connection(#[from] std::io::Error),
#[error("TLS required but not provided")]
@@ -22,7 +20,6 @@ pub enum TlsError {
impl CouldRetry for TlsError {
fn could_retry(&self) -> bool {
match self {
TlsError::Dns(_) => false,
TlsError::Connection(err) => err.could_retry(),
// perhaps compute didn't realise it supports TLS?
TlsError::Required => true,
@@ -35,6 +32,7 @@ pub async fn connect_tls<S, T>(
mode: SslMode,
tls: &T,
host: &str,
direct: bool,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
@@ -49,7 +47,7 @@ where
SslMode::Prefer | SslMode::Require => {}
}
if !request_tls(&mut stream).await? {
if !direct && !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}
@@ -57,7 +55,6 @@ where
return Ok(MaybeTlsStream::Raw(stream));
}
Ok(MaybeTlsStream::Tls(
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
))
let c = tls.make_tls_connect(host).map_err(std::io::Error::other)?;
Ok(MaybeTlsStream::Tls(c.connect(stream).boxed().await?))
}

View File

@@ -22,6 +22,7 @@ pub struct ProxyConfig {
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub region: String,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,

View File

@@ -89,7 +89,12 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let res = handle_client(
config,
@@ -217,6 +222,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
direct: false,
},
&node_info,
config.wake_compute_retry_config,

View File

@@ -46,6 +46,7 @@ struct RequestContextInner {
pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
region: &'static str,
pub(crate) span: Span,
// filled in as they are discovered
@@ -93,6 +94,7 @@ impl Clone for RequestContext {
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
region: inner.region,
span: info_span!("background_task"),
project: inner.project,
@@ -122,7 +124,12 @@ impl Clone for RequestContext {
}
impl RequestContext {
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
pub fn new(
session_id: Uuid,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
// TODO: be careful with long lived spans
let span = info_span!(
"connect_request",
@@ -138,6 +145,7 @@ impl RequestContext {
session_id,
protocol,
first_packet: Utc::now(),
region,
span,
project: None,
@@ -171,7 +179,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
}
pub(crate) fn console_application_name(&self) -> String {

View File

@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
#[derive(parquet_derive::ParquetRecordWriter)]
pub(crate) struct RequestData {
region: String,
region: &'static str,
protocol: &'static str,
/// Must be UTC. The derive macro doesn't like the timezones
timestamp: chrono::NaiveDateTime,
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
}),
jwt_issuer: value.jwt_issuer.clone(),
protocol: value.protocol.as_str(),
region: String::new(),
region: value.region,
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,
cold_start_info: value.cold_start_info.as_str(),
@@ -167,7 +167,6 @@ impl From<&RequestContextInner> for RequestData {
pub async fn worker(
cancellation_token: CancellationToken,
config: ParquetUploadArgs,
region: String,
) -> anyhow::Result<()> {
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
tracing::warn!("parquet request upload: no s3 bucket configured");
@@ -233,17 +232,12 @@ pub async fn worker(
.context("remote storage for disconnect events init")?;
let parquet_config_disconnect = parquet_config.clone();
tokio::try_join!(
worker_inner(storage, rx, parquet_config, &region),
worker_inner(
storage_disconnect,
rx_disconnect,
parquet_config_disconnect,
&region
)
worker_inner(storage, rx, parquet_config),
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
)
.map(|_| ())
} else {
worker_inner(storage, rx, parquet_config, &region).await
worker_inner(storage, rx, parquet_config).await
}
}
@@ -263,7 +257,6 @@ async fn worker_inner(
storage: GenericRemoteStorage,
rx: impl Stream<Item = RequestData>,
config: ParquetConfig,
region: &str,
) -> anyhow::Result<()> {
#[cfg(any(test, feature = "testing"))]
let storage = if config.test_remote_failures > 0 {
@@ -284,8 +277,7 @@ async fn worker_inner(
let mut last_upload = time::Instant::now();
let mut len = 0;
while let Some(mut row) = rx.next().await {
region.clone_into(&mut row.region);
while let Some(row) = rx.next().await {
rows.push(row);
let force = last_upload.elapsed() > config.max_duration;
if rows.len() == config.rows_per_group || force {
@@ -541,7 +533,7 @@ mod tests {
auth_method: None,
jwt_issuer: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: String::new(),
region: "us-east-1",
error: None,
success: rng.r#gen(),
cold_start_info: "no",
@@ -573,9 +565,7 @@ mod tests {
.await
.unwrap();
worker_inner(storage, rx, config, "us-east-1")
.await
.unwrap();
worker_inner(storage, rx, config).await.unwrap();
let mut files = WalkDir::new(tmpdir.as_std_path())
.into_iter()

View File

@@ -263,7 +263,12 @@ impl NeonControlPlaneClient {
None => SslMode::Disable,
};
let host = match body.server_name {
Some(host) => host.into(),
Some(host) => {
if rustls::pki_types::DnsName::try_from_str(&host).is_err() {
return Err(WakeComputeError::BadComputeAddress(host.into_boxed_str()));
}
host.into()
}
None => host.into(),
};

View File

@@ -77,8 +77,9 @@ impl NodeInfo {
&self,
ctx: &RequestContext,
config: &ComputeConfig,
direct: bool,
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
self.conn_info.connect(ctx, &self.aux, config).await
self.conn_info.connect(ctx, &self.aux, config, direct).await
}
}

View File

@@ -1,18 +1,15 @@
use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
use crate::types::Host;
@@ -35,42 +32,34 @@ pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> Node
node_info.invalidate()
}
#[async_trait]
pub(crate) trait ConnectMechanism {
type Connection;
type ConnectError: ReportableError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError>;
) -> Result<Self::Connection, compute::ConnectionError>;
}
pub(crate) struct TcpMechanism {
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
// whether to negotiate TLS for postgres protocol.
pub(crate) direct: bool,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism {
type Connection = ComputeConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
compute_id = tracing::field::Empty
))]
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<ComputeConnection, Self::Error> {
) -> Result<ComputeConnection, compute::ConnectionError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(node_info.connect(ctx, config).await)
permit.release_result(node_info.connect(ctx, config, self.direct).await)
}
}
@@ -82,11 +71,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBacken
user_info: &B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
) -> Result<M::Connection, compute::ConnectionError> {
let mut num_retries = 0;
let node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
@@ -120,7 +105,7 @@ where
},
num_retries.into(),
);
return Err(err.into());
return Err(err);
}
node_info
} else {
@@ -161,7 +146,7 @@ where
},
num_retries.into(),
);
return Err(e.into());
return Err(e);
}
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);

View File

@@ -122,7 +122,12 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let res = handle_client(
config,
@@ -353,6 +358,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
direct: false,
},
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
config.wake_compute_retry_config,

View File

@@ -1,4 +1,3 @@
use std::error::Error;
use std::io;
use tokio::time;
@@ -31,85 +30,25 @@ impl CouldRetry for io::Error {
}
}
impl CouldRetry for postgres_client::error::DbError {
fn could_retry(&self) -> bool {
use postgres_client::error::SqlState;
matches!(
self.code(),
&SqlState::CONNECTION_FAILURE
| &SqlState::CONNECTION_EXCEPTION
| &SqlState::CONNECTION_DOES_NOT_EXIST
| &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
)
}
}
impl ShouldRetryWakeCompute for postgres_client::error::DbError {
fn should_retry_wake_compute(&self) -> bool {
use postgres_client::error::SqlState;
// Here are errors that happens after the user successfully authenticated to the database.
// TODO: there are pgbouncer errors that should be retried, but they are not listed here.
let non_retriable_pg_errors = matches!(
self.code(),
&SqlState::TOO_MANY_CONNECTIONS
| &SqlState::OUT_OF_MEMORY
| &SqlState::SYNTAX_ERROR
| &SqlState::T_R_SERIALIZATION_FAILURE
| &SqlState::INVALID_CATALOG_NAME
| &SqlState::INVALID_SCHEMA_NAME
| &SqlState::INVALID_PARAMETER_VALUE,
);
if non_retriable_pg_errors {
return false;
}
// PGBouncer errors that should not trigger a wake_compute retry.
if self.code() == &SqlState::PROTOCOL_VIOLATION {
// Source for the error message:
// https://github.com/pgbouncer/pgbouncer/blob/f15997fe3effe3a94ba8bcc1ea562e6117d1a131/src/client.c#L1070
return !self
.message()
.contains("no more connections allowed (max_client_conn)");
}
true
}
}
impl CouldRetry for postgres_client::Error {
fn could_retry(&self) -> bool {
if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
io::Error::could_retry(io_err)
} else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
postgres_client::error::DbError::could_retry(db_err)
} else {
false
}
}
}
impl ShouldRetryWakeCompute for postgres_client::Error {
fn should_retry_wake_compute(&self) -> bool {
if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
postgres_client::error::DbError::should_retry_wake_compute(db_err)
} else {
// likely an IO error. Possible the compute has shutdown and the
// cache is stale.
true
}
}
}
impl CouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
#[cfg(test)]
compute::ConnectionError::TestError { retryable, .. } => *retryable,
}
}
}
impl ShouldRetryWakeCompute for compute::ConnectionError {
fn should_retry_wake_compute(&self) -> bool {
match self {
// the cache entry was not checked for validity
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
#[cfg(test)]
compute::ConnectionError::TestError { wakeable, .. } => *wakeable,
_ => true,
}
}
@@ -120,56 +59,3 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati
.base_delay
.mul_f64(config.backoff_factor.powi((num_retries as i32) - 1))
}
#[cfg(test)]
mod tests {
use postgres_client::error::{DbError, SqlState};
use super::ShouldRetryWakeCompute;
#[test]
fn should_retry_wake_compute_for_db_error() {
// These SQLStates should NOT trigger a wake_compute retry.
let non_retry_states = [
SqlState::TOO_MANY_CONNECTIONS,
SqlState::OUT_OF_MEMORY,
SqlState::SYNTAX_ERROR,
SqlState::T_R_SERIALIZATION_FAILURE,
SqlState::INVALID_CATALOG_NAME,
SqlState::INVALID_SCHEMA_NAME,
SqlState::INVALID_PARAMETER_VALUE,
];
for state in non_retry_states {
let err = DbError::new_test_error(state.clone(), "oops".to_string());
assert!(
!err.should_retry_wake_compute(),
"State {state:?} unexpectedly retried"
);
}
// Errors coming from pgbouncer should not trigger a wake_compute retry
let non_retry_pgbouncer_errors = ["no more connections allowed (max_client_conn)"];
for error in non_retry_pgbouncer_errors {
let err = DbError::new_test_error(SqlState::PROTOCOL_VIOLATION, error.to_string());
assert!(
!err.should_retry_wake_compute(),
"PGBouncer error {error:?} unexpectedly retried"
);
}
// These SQLStates should trigger a wake_compute retry.
let retry_states = [
SqlState::CONNECTION_FAILURE,
SqlState::CONNECTION_EXCEPTION,
SqlState::CONNECTION_DOES_NOT_EXIST,
SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
];
for state in retry_states {
let err = DbError::new_test_error(state.clone(), "oops".to_string());
assert!(
err.should_retry_wake_compute(),
"State {state:?} unexpectedly skipped retry"
);
}
}
}

View File

@@ -10,7 +10,7 @@ use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use retry::retry_after;
use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
@@ -20,6 +20,7 @@ use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::compute::ConnectionError;
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
@@ -423,71 +424,36 @@ impl TestConnectMechanism {
#[derive(Debug)]
struct TestConnection;
#[derive(Debug)]
struct TestConnectError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
}
impl ReportableError for TestConnectError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
self.kind
}
}
impl std::fmt::Display for TestConnectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for TestConnectError {}
impl CouldRetry for TestConnectError {
fn could_retry(&self) -> bool {
self.retryable
}
}
impl ShouldRetryWakeCompute for TestConnectError {
fn should_retry_wake_compute(&self) -> bool {
self.wakeable
}
}
#[async_trait]
impl ConnectMechanism for TestConnectMechanism {
type Connection = TestConnection;
type ConnectError = TestConnectError;
type Error = anyhow::Error;
async fn connect_once(
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
) -> Result<Self::Connection, ConnectionError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
*counter += 1;
match action {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(TestConnectError {
ConnectAction::Retry => Err(ConnectionError::TestError {
retryable: true,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::RetryNoWake => Err(TestConnectError {
ConnectAction::RetryNoWake => Err(ConnectionError::TestError {
retryable: true,
wakeable: false,
kind: ErrorKind::Compute,
}),
ConnectAction::Fail => Err(TestConnectError {
ConnectAction::Fail => Err(ConnectionError::TestError {
retryable: false,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::FailNoWake => Err(TestConnectError {
ConnectAction::FailNoWake => Err(ConnectionError::TestError {
retryable: false,
wakeable: false,
kind: ErrorKind::Compute,

View File

@@ -139,6 +139,12 @@ impl RateBucketInfo {
Self::new(200, Duration::from_secs(600)),
];
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
pub const DEFAULT_REDIS_SET: [Self; 2] = [
Self::new(100_000, Duration::from_secs(1)),
Self::new(50_000, Duration::from_secs(10)),
];
pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}

View File

@@ -5,9 +5,11 @@ use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider,
limiter: GlobalRateLimiter,
}
#[allow(async_fn_in_trait)]
@@ -28,8 +30,11 @@ impl Queryable for Cmd {
}
impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
Self { client }
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
client,
limiter: GlobalRateLimiter::new(info.into()),
}
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
@@ -44,6 +49,11 @@ impl RedisKVClient {
&mut self,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => e,

View File

@@ -141,19 +141,29 @@ where
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
region_id: String,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
region_id: self.region_id.clone(),
}
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>) -> Self {
Self { cache }
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
Self { cache, region_id }
}
pub(crate) async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}
pub(crate) async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -266,7 +276,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.cache.increment_active_listeners().await;
handler.increment_active_listeners().await;
conn
}
Err(e) => {
@@ -287,11 +297,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
}
if cancellation_token.is_cancelled() {
handler.cache.decrement_active_listeners().await;
handler.decrement_active_listeners().await;
return Ok(());
}
}
handler.cache.decrement_active_listeners().await;
handler.decrement_active_listeners().await;
}
}
@@ -300,11 +310,12 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let handler = MessageHandler::new(cache);
let handler = MessageHandler::new(cache, region_id);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));

View File

@@ -1,17 +1,12 @@
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use ed25519_dalek::SigningKey;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use jose_jwk::jose_b64;
use postgres_client::config::SslMode;
use postgres_client::SocketConfig;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use rand::rngs::OsRng;
use rustls::pki_types::{DnsName, ServerName};
use tokio::net::{TcpStream, lookup_host};
use tokio_rustls::TlsConnector;
use tracing::field::display;
use tracing::{debug, info};
@@ -23,21 +18,19 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute::{self, ComputeConnection};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::proxy::connect_compute::TcpMechanism;
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
@@ -157,11 +150,6 @@ impl PoolingBackend {
// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare few structures
// that this code expects.
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
compute_id = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_compute(
&self,
ctx: &RequestContext,
@@ -181,30 +169,24 @@ impl PoolingBackend {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
let connection = crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
&TcpMechanism {
locks: &self.config.connect_compute_locks,
keys: keys.keys,
direct: false,
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
.await
.await?;
authenticate(ctx, &self.pool, &conn_info, keys.keys, connection, conn_id).await
}
// Wake up the destination if needed
#[tracing::instrument(skip_all, fields(
compute_id = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestContext,
@@ -216,7 +198,6 @@ impl PoolingBackend {
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
@@ -226,19 +207,19 @@ impl PoolingBackend {
)),
options: conn_info.user_info.options.clone(),
});
crate::proxy::connect_compute::connect_to_compute(
let connection = crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
&TcpMechanism {
locks: &self.config.connect_compute_locks,
direct: true,
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
.await
.await?;
h2handshake(ctx, &self.http_conn_pool, &conn_info, connection, conn_id).await
}
/// Connect to postgres over localhost.
@@ -248,10 +229,6 @@ impl PoolingBackend {
/// # Panics
///
/// Panics if called with a non-local_proxy backend.
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_local_postgres(
&self,
ctx: &RequestContext,
@@ -373,6 +350,8 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connect to compute")]
ConnectError(#[from] compute::ConnectionError),
#[error("could not connect to postgres in compute")]
PostgresConnectionError(#[from] postgres_client::Error),
#[error("could not connect to local-proxy in compute")]
@@ -394,8 +373,6 @@ pub(crate) enum HttpConnError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper::Error),
}
@@ -403,6 +380,7 @@ pub(crate) enum LocalProxyConnError {
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectError(_) => ErrorKind::Compute,
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
@@ -419,6 +397,7 @@ impl ReportableError for HttpConnError {
impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectError(p) => p.to_string_client(),
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
@@ -434,36 +413,9 @@ impl UserFacingError for HttpConnError {
}
}
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ComputeCtl(_) => false,
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
HttpConnError::WakeCompute(_) => false,
HttpConnError::TooManyConnectionAttempts(_) => false,
}
}
}
impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true,
}
}
}
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
@@ -475,208 +427,106 @@ impl UserFacingError for LocalProxyConnError {
}
}
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
async fn authenticate(
ctx: &RequestContext,
pool: &Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: &ConnInfo,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for TokioMechanism {
type Connection = Client<postgres_client::Client>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let mut config = node_info.conn_info.to_postgres_client_config();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: ConnInfo,
compute: ComputeConnection,
conn_id: uuid::Uuid,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
// client config with stubbed connect info.
let mut config = postgres_client::Config::new(String::new(), 0);
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname);
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client<Send>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
if let ComputeCredentialKeys::AuthKeys(auth_keys) = keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = config.authenticate(compute.stream).await?;
drop(pause);
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
pid = connection.process_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
%conn_id,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
let (client, connection) = connection.into_managed_conn(
SocketConfig {
host_addr: Some(compute.socket_addr.ip()),
host: postgres_client::config::Host::Tcp(compute.hostname.to_string()),
port: compute.socket_addr.port(),
connect_timeout: None,
},
compute.ssl_mode,
);
Ok(poll_client(
pool.clone(),
ctx,
conn_info.clone(),
client,
connection,
conn_id,
compute.aux,
))
}
async fn connect_http2(
host_addr: Option<IpAddr>,
host: &str,
port: u16,
timeout: Duration,
tls: Option<&Arc<rustls::ClientConfig>>,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
let addrs = match host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?
.collect(),
};
let mut last_err = None;
let mut addrs = addrs.into_iter();
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
}
};
let stream = if let Some(tls) = tls {
let host = DnsName::try_from(host)
.map_err(io::Error::other)
.map_err(LocalProxyConnError::Io)?
.to_owned();
let stream = TlsConnector::from(tls.clone())
.connect(ServerName::DnsName(host), stream)
.await
.map_err(LocalProxyConnError::Io)?;
Box::pin(stream) as AsyncRW
} else {
Box::pin(stream) as AsyncRW
async fn h2handshake(
ctx: &RequestContext,
pool: &Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: &ConnInfo,
compute: ComputeConnection,
conn_id: uuid::Uuid,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
let stream = match compute.stream {
MaybeTlsStream::Raw(tcp) => Box::pin(tcp) as AsyncRW,
MaybeTlsStream::Tls(tls) => Box::into_pin(tls.0) as AsyncRW,
};
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;
.await
.map_err(LocalProxyConnError::H2)?;
drop(pause);
Ok((client, connection))
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
%conn_id,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
Ok(poll_http2_client(
pool.clone(),
ctx,
conn_info,
client,
connection,
conn_id,
compute.aux,
))
}

View File

@@ -69,7 +69,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let mut session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let span = info_span!(parent: None, "connection", %conn_id, pid=client.get_process_id(), compute_id=%aux.compute_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");

View File

@@ -518,15 +518,14 @@ impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
debug!(
info!(
conn_id = %client.get_conn_id(),
pid = client.inner.get_process_id(),
compute_id = &*client.aux.compute_id,
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
);
match client.get_data() {

View File

@@ -6,7 +6,7 @@ use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use smol_str::ToSmolStr;
use tracing::{Instrument, debug, error, info, info_span};
use tracing::{Instrument, error, info, info_span};
use super::AsyncRW;
use super::backend::HttpConnError;
@@ -115,7 +115,6 @@ impl<C: ClientInnerExt + Clone> Drop for HttpConnPool<C> {
}
impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
#[expect(unused_results)]
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestContext,
@@ -132,10 +131,13 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
return result;
};
tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
debug!(
info!(
conn_id = %client.conn.conn_id,
compute_id = &*client.conn.aux.compute_id,
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
@@ -197,7 +199,7 @@ pub(crate) fn poll_http2_client(
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let span = info_span!(parent: None, "connection", %conn_id, compute_id=%aux.compute_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
@@ -229,6 +231,8 @@ pub(crate) fn poll_http2_client(
tokio::spawn(
async move {
info!("new local proxy connection");
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {

View File

@@ -30,7 +30,7 @@ use serde_json::value::RawValue;
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, warn};
use tracing::{Instrument, error, info, info_span, warn};
use super::backend::HttpConnError;
use super::conn_pool_lib::{
@@ -107,15 +107,13 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
return Ok(None);
}
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
debug!(
info!(
pid = client.inner.get_process_id(),
conn_id = %client.get_conn_id(),
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"local_pool: reusing connection '{conn_info}'"
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
);
match client.get_data() {

View File

@@ -417,7 +417,12 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
ctx.set_user_agent(
request
@@ -457,7 +462,12 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
let span = ctx.span();
let testodrome_id = request

View File

@@ -60,7 +60,7 @@ mod private {
}
}
pub struct RustlsStream<S>(Box<TlsStream<S>>);
pub struct RustlsStream<S>(pub Box<TlsStream<S>>);
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
where

View File

@@ -1,6 +1,5 @@
use std::cmp::min;
use std::io::{self, ErrorKind};
use std::ops::RangeInclusive;
use std::sync::Arc;
use anyhow::{Context, Result, anyhow, bail};
@@ -35,7 +34,7 @@ use crate::control_file::CONTROL_FILE_NAME;
use crate::state::{EvictionState, TimelinePersistentState};
use crate::timeline::{Timeline, TimelineError, WalResidentTimeline};
use crate::timelines_global_map::{create_temp_timeline_dir, validate_temp_timeline};
use crate::wal_storage::{open_wal_file, wal_file_paths};
use crate::wal_storage::open_wal_file;
use crate::{GlobalTimelines, debug_dump, wal_backup};
/// Stream tar archive of timeline to tx.
@@ -96,8 +95,8 @@ pub async fn stream_snapshot(
/// State needed while streaming the snapshot.
pub struct SnapshotContext {
/// The interval of segment numbers. If None, the timeline hasn't had writes yet, so only send the control file
pub from_to_segno: Option<RangeInclusive<XLogSegNo>>,
pub from_segno: XLogSegNo, // including
pub upto_segno: XLogSegNo, // including
pub term: Term,
pub last_log_term: Term,
pub flush_lsn: Lsn,
@@ -175,35 +174,23 @@ pub async fn stream_snapshot_resident_guts(
.await?;
pausable_failpoint!("sk-snapshot-after-list-pausable");
if let Some(from_to_segno) = &bctx.from_to_segno {
let tli_dir = tli.get_timeline_dir();
info!(
"sending {} segments [{:#X}-{:#X}], term={}, last_log_term={}, flush_lsn={}",
from_to_segno.end() - from_to_segno.start() + 1,
from_to_segno.start(),
from_to_segno.end(),
bctx.term,
bctx.last_log_term,
bctx.flush_lsn,
);
for segno in from_to_segno.clone() {
let Some((mut sf, is_partial)) =
open_wal_file(&tli_dir, segno, bctx.wal_seg_size).await?
else {
// File is not found
let (wal_file_path, _wal_file_partial_path) =
wal_file_paths(&tli_dir, segno, bctx.wal_seg_size);
tracing::warn!("couldn't find WAL segment file {wal_file_path}");
bail!("couldn't find WAL segment file {wal_file_path}")
};
let mut wal_file_name = XLogFileName(PG_TLI, segno, bctx.wal_seg_size);
if is_partial {
wal_file_name.push_str(".partial");
}
ar.append_file(&wal_file_name, &mut sf).await?;
let tli_dir = tli.get_timeline_dir();
info!(
"sending {} segments [{:#X}-{:#X}], term={}, last_log_term={}, flush_lsn={}",
bctx.upto_segno - bctx.from_segno + 1,
bctx.from_segno,
bctx.upto_segno,
bctx.term,
bctx.last_log_term,
bctx.flush_lsn,
);
for segno in bctx.from_segno..=bctx.upto_segno {
let (mut sf, is_partial) = open_wal_file(&tli_dir, segno, bctx.wal_seg_size).await?;
let mut wal_file_name = XLogFileName(PG_TLI, segno, bctx.wal_seg_size);
if is_partial {
wal_file_name.push_str(".partial");
}
} else {
info!("Not including any segments into the snapshot");
ar.append_file(&wal_file_name, &mut sf).await?;
}
// Do the term check before ar.finish to make archive corrupted in case of
@@ -351,26 +338,19 @@ impl WalResidentTimeline {
// removed further than `backup_lsn`. Since we're holding shared_state
// lock and setting `wal_removal_on_hold` later, it guarantees that WAL
// won't be removed until we're done.
let timeline_state = shared_state.sk.state();
let from_lsn = min(
timeline_state.remote_consistent_lsn,
timeline_state.backup_lsn,
);
let flush_lsn = shared_state.sk.flush_lsn();
let (send_segments, msg) = if from_lsn == Lsn::INVALID {
(false, "snapshot is called on uninitialized timeline")
} else {
(true, "timeline is initialized")
};
tracing::info!(
remote_consistent_lsn=%timeline_state.remote_consistent_lsn,
backup_lsn=%timeline_state.backup_lsn,
%flush_lsn,
"{msg}"
shared_state.sk.state().remote_consistent_lsn,
shared_state.sk.state().backup_lsn,
);
if from_lsn == Lsn::INVALID {
// this is possible if snapshot is called before handling first
// elected message
bail!("snapshot is called on uninitialized timeline");
}
let from_segno = from_lsn.segment_number(wal_seg_size);
let term = shared_state.sk.state().acceptor_state.term;
let last_log_term = shared_state.sk.last_log_term();
let flush_lsn = shared_state.sk.flush_lsn();
let upto_segno = flush_lsn.segment_number(wal_seg_size);
// have some limit on max number of segments as a sanity check
const MAX_ALLOWED_SEGS: u64 = 1000;
@@ -396,9 +376,9 @@ impl WalResidentTimeline {
drop(shared_state);
let tli_copy = self.wal_residence_guard().await?;
let from_to_segno = send_segments.then_some(from_segno..=upto_segno);
let bctx = SnapshotContext {
from_to_segno,
from_segno,
upto_segno,
term,
last_log_term,
flush_lsn,

View File

@@ -9,7 +9,7 @@
use std::cmp::{max, min};
use std::future::Future;
use std::io::{ErrorKind, SeekFrom};
use std::io::{self, SeekFrom};
use std::pin::Pin;
use anyhow::{Context, Result, bail};
@@ -794,13 +794,26 @@ impl WalReader {
// Try to open local file, if we may have WAL locally
if self.pos >= self.local_start_lsn {
let res = open_wal_file(&self.timeline_dir, segno, self.wal_seg_size).await?;
if let Some((mut file, _)) = res {
file.seek(SeekFrom::Start(xlogoff as u64)).await?;
return Ok(Box::pin(file));
} else {
// NotFound is expected, fall through to remote read
}
let res = open_wal_file(&self.timeline_dir, segno, self.wal_seg_size).await;
match res {
Ok((mut file, _)) => {
file.seek(SeekFrom::Start(xlogoff as u64)).await?;
return Ok(Box::pin(file));
}
Err(e) => {
let is_not_found = e.chain().any(|e| {
if let Some(e) = e.downcast_ref::<io::Error>() {
e.kind() == io::ErrorKind::NotFound
} else {
false
}
});
if !is_not_found {
return Err(e);
}
// NotFound is expected, fall through to remote read
}
};
}
// Try to open remote file, if remote reads are enabled
@@ -819,31 +832,26 @@ pub(crate) async fn open_wal_file(
timeline_dir: &Utf8Path,
segno: XLogSegNo,
wal_seg_size: usize,
) -> Result<Option<(tokio::fs::File, bool)>> {
) -> Result<(tokio::fs::File, bool)> {
let (wal_file_path, wal_file_partial_path) = wal_file_paths(timeline_dir, segno, wal_seg_size);
// First try to open the .partial file.
let mut partial_path = wal_file_path.to_owned();
partial_path.set_extension("partial");
if let Ok(opened_file) = tokio::fs::File::open(&wal_file_partial_path).await {
return Ok(Some((opened_file, true)));
return Ok((opened_file, true));
}
// If that failed, try it without the .partial extension.
let pf_res = tokio::fs::File::open(&wal_file_path).await;
if let Err(e) = &pf_res {
if e.kind() == ErrorKind::NotFound {
return Ok(None);
}
}
let pf = pf_res
let pf = tokio::fs::File::open(&wal_file_path)
.await
.with_context(|| format!("failed to open WAL file {wal_file_path:#}"))
.map_err(|e| {
warn!("{e}");
warn!("{}", e);
e
})?;
Ok(Some((pf, false)))
Ok((pf, false))
}
/// Helper returning full path to WAL segment file and its .partial brother.

View File

@@ -27,7 +27,6 @@ governor.workspace = true
hex.workspace = true
hyper0.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
itertools.workspace = true
json-structural-diff.workspace = true
lasso.workspace = true
@@ -35,7 +34,6 @@ once_cell.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_connection.workspace = true
posthog_client_lite.workspace = true
rand.workspace = true
reqwest = { workspace = true, features = ["stream"] }
routerify.workspace = true

View File

@@ -14,13 +14,11 @@ use http_utils::tls_certs::ReloadingCertificateResolver;
use hyper0::Uri;
use metrics::BuildInfo;
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::config::PostHogConfig;
use reqwest::Certificate;
use storage_controller::http::make_router;
use storage_controller::metrics::preinitialize_metrics;
use storage_controller::persistence::Persistence;
use storage_controller::service::chaos_injector::ChaosInjector;
use storage_controller::service::feature_flag::FeatureFlagService;
use storage_controller::service::{
Config, HEARTBEAT_INTERVAL_DEFAULT, LONG_RECONCILE_THRESHOLD_DEFAULT,
MAX_OFFLINE_INTERVAL_DEFAULT, MAX_WARMING_UP_INTERVAL_DEFAULT,
@@ -254,8 +252,6 @@ struct Secrets {
peer_jwt_token: Option<String>,
}
const POSTHOG_CONFIG_ENV: &str = "POSTHOG_CONFIG";
impl Secrets {
const DATABASE_URL_ENV: &'static str = "DATABASE_URL";
const PAGESERVER_JWT_TOKEN_ENV: &'static str = "PAGESERVER_JWT_TOKEN";
@@ -413,18 +409,6 @@ async fn async_main() -> anyhow::Result<()> {
None => Vec::new(),
};
let posthog_config = if let Ok(json) = std::env::var(POSTHOG_CONFIG_ENV) {
let res: Result<PostHogConfig, _> = serde_json::from_str(&json);
if let Ok(config) = res {
Some(config)
} else {
tracing::warn!("Invalid posthog config: {json}");
None
}
} else {
None
};
let config = Config {
pageserver_jwt_token: secrets.pageserver_jwt_token,
safekeeper_jwt_token: secrets.safekeeper_jwt_token,
@@ -471,7 +455,6 @@ async fn async_main() -> anyhow::Result<()> {
timelines_onto_safekeepers: args.timelines_onto_safekeepers,
use_local_compute_notifications: args.use_local_compute_notifications,
timeline_safekeeper_count: args.timeline_safekeeper_count,
posthog_config: posthog_config.clone(),
#[cfg(feature = "testing")]
kick_secondary_downloads: args.kick_secondary_downloads,
};
@@ -554,29 +537,6 @@ async fn async_main() -> anyhow::Result<()> {
)
});
let feature_flag_task = if let Some(posthog_config) = posthog_config {
let service = service.clone();
let cancel = CancellationToken::new();
let cancel_bg = cancel.clone();
let task = tokio::task::spawn(
async move {
match FeatureFlagService::new(service, posthog_config) {
Ok(feature_flag_service) => {
let feature_flag_service = Arc::new(feature_flag_service);
feature_flag_service.run(cancel_bg).await
}
Err(e) => {
tracing::warn!("Failed to create feature flag service: {}", e);
}
};
}
.instrument(tracing::info_span!("feature_flag_service")),
);
Some((task, cancel))
} else {
None
};
// Wait until we receive a signal
let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
let mut sigquit = tokio::signal::unix::signal(SignalKind::quit())?;
@@ -624,12 +584,6 @@ async fn async_main() -> anyhow::Result<()> {
chaos_jh.await.ok();
}
// If we were running the feature flag service, stop that so that we're not calling into Service while it shuts down
if let Some((feature_flag_task, feature_flag_cancel)) = feature_flag_task {
feature_flag_cancel.cancel();
feature_flag_task.await.ok();
}
service.shutdown().await;
tracing::info!("Service shutdown complete");

View File

@@ -376,13 +376,4 @@ impl PageserverClient {
.await
)
}
pub(crate) async fn update_feature_flag_spec(&self, spec: String) -> Result<()> {
measured_request!(
"update_feature_flag_spec",
crate::metrics::Method::Post,
&self.node_id_label,
self.inner.update_feature_flag_spec(spec).await
)
}
}

View File

@@ -1,6 +1,5 @@
pub mod chaos_injector;
mod context_iterator;
pub mod feature_flag;
pub(crate) mod safekeeper_reconciler;
mod safekeeper_service;
@@ -26,7 +25,6 @@ use futures::stream::FuturesUnordered;
use http_utils::error::ApiError;
use hyper::Uri;
use itertools::Itertools;
use pageserver_api::config::PostHogConfig;
use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, MetadataHealthUpdateRequest, NodeAvailability,
NodeRegisterRequest, NodeSchedulingPolicy, NodeShard, NodeShardResponse, PlacementPolicy,
@@ -473,9 +471,6 @@ pub struct Config {
/// Safekeepers will be choosen from different availability zones.
pub timeline_safekeeper_count: i64,
/// PostHog integration config
pub posthog_config: Option<PostHogConfig>,
#[cfg(feature = "testing")]
pub kick_secondary_downloads: bool,
}

View File

@@ -1,111 +0,0 @@
use std::{sync::Arc, time::Duration};
use futures::StreamExt;
use pageserver_api::config::PostHogConfig;
use pageserver_client::mgmt_api;
use posthog_client_lite::PostHogClient;
use reqwest::StatusCode;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;
use crate::{pageserver_client::PageserverClient, service::Service};
pub struct FeatureFlagService {
service: Arc<Service>,
config: PostHogConfig,
client: PostHogClient,
http_client: reqwest::Client,
}
const DEFAULT_POSTHOG_REFRESH_INTERVAL: Duration = Duration::from_secs(30);
impl FeatureFlagService {
pub fn new(service: Arc<Service>, config: PostHogConfig) -> Result<Self, &'static str> {
let client = PostHogClient::new(config.clone().try_into_posthog_config()?);
Ok(Self {
service,
config,
client,
http_client: reqwest::Client::new(),
})
}
async fn refresh(self: Arc<Self>, cancel: CancellationToken) -> Result<(), anyhow::Error> {
let nodes = {
let inner = self.service.inner.read().unwrap();
inner.nodes.clone()
};
let feature_flag_spec = self.client.get_feature_flags_local_evaluation_raw().await?;
let stream = futures::stream::iter(nodes.values().cloned()).map(|node| {
let this = self.clone();
let feature_flag_spec = feature_flag_spec.clone();
async move {
let res = async {
let client = PageserverClient::new(
node.get_id(),
this.http_client.clone(),
node.base_url(),
// TODO: what if we rotate the token during storcon lifetime?
this.service.config.pageserver_jwt_token.as_deref(),
);
client.update_feature_flag_spec(feature_flag_spec).await?;
tracing::info!(
"Updated {}({}) with feature flag spec",
node.get_id(),
node.base_url()
);
Ok::<_, mgmt_api::Error>(())
};
if let Err(e) = res.await {
if let mgmt_api::Error::ApiError(status, _) = e {
if status == StatusCode::NOT_FOUND {
// This is expected during deployments where the API is not available, so we can ignore it
return;
}
}
tracing::warn!(
"Failed to update feature flag spec for {}: {e}",
node.get_id()
);
}
}
});
let mut stream = stream.buffer_unordered(8);
while stream.next().await.is_some() {
if cancel.is_cancelled() {
return Ok(());
}
}
Ok(())
}
pub async fn run(self: Arc<Self>, cancel: CancellationToken) {
let refresh_interval = self
.config
.refresh_interval
.unwrap_or(DEFAULT_POSTHOG_REFRESH_INTERVAL);
let mut interval = tokio::time::interval(refresh_interval);
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
tracing::info!(
"Starting feature flag service with refresh interval: {:?}",
refresh_interval
);
loop {
tokio::select! {
_ = interval.tick() => {}
_ = cancel.cancelled() => {
break;
}
}
let res = self.clone().refresh(cancel.clone()).await;
if let Err(e) = res {
tracing::error!("Failed to refresh feature flags: {e:#?}");
}
}
}
}

View File

@@ -4168,20 +4168,13 @@ class DeletionSubject(Enum):
TENANT = "tenant"
class EmptyTimeline(Enum):
EMPTY = "empty"
NONEMPTY = "nonempty"
@run_only_on_default_postgres("PG version is not interesting here")
@pytest.mark.parametrize("restart_storcon", [RestartStorcon.RESTART, RestartStorcon.ONLINE])
@pytest.mark.parametrize("deletetion_subject", [DeletionSubject.TENANT, DeletionSubject.TIMELINE])
@pytest.mark.parametrize("empty_timeline", [EmptyTimeline.EMPTY, EmptyTimeline.NONEMPTY])
def test_storcon_create_delete_sk_down(
neon_env_builder: NeonEnvBuilder,
restart_storcon: RestartStorcon,
deletetion_subject: DeletionSubject,
empty_timeline: EmptyTimeline,
):
"""
Test that the storcon can create and delete tenants and timelines with a safekeeper being down.
@@ -4233,11 +4226,10 @@ def test_storcon_create_delete_sk_down(
ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3])
ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)")
if empty_timeline == EmptyTimeline.NONEMPTY:
with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep:
# endpoint should start.
ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3])
ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)")
with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep:
# endpoint should start.
ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3])
ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)")
env.storage_controller.assert_log_contains("writing pending op for sk id 1")
env.safekeepers[0].start()