Compare commits

...

30 Commits

Author SHA1 Message Date
Vlad Lazar
a1cc1f33dc Merge remote-tracking branch 'origin/main' into vlad/hadron-jwt 2025-07-31 11:29:07 +01:00
Alexey Kondratov
8fe7596120 chore(compute_tools): Delete unused anon_ext_fn_reassign.sql (#12787)
It's an anon v1 failed launch artifact, I suppose.
2025-07-31 10:11:30 +00:00
Krzysztof Szafrański
f3ee6e818d [proxy] Correctly classify ConnectErrors (#12793)
As is, e.g. quota errors on wake compute are logged as "compute" errors.
2025-07-31 09:53:48 +00:00
Dmitrii Kovalkov
edd60730c8 safekeeper: use last_log_term in mconf switch + choose most advanced sk in pull timeline (#12778)
## Problem
I discovered two bugs corresponding to safekeeper migration, which
together might lead to a data loss during the migration. The second bug
is from a hadron patch and might lead to a data loss during the
safekeeper restore in hadron as well.

1. `switch_membership` returns the current `term` instead of
`last_log_term`. It is used to choose the `sync_position` in the
algorithm, so we might choose the wrong one and break the correctness
guarantees.
2. The current `term` is used to choose the most advanced SK in
`pull_timeline` with higher priority than `flush_lsn`. It is incorrect
because the most advanced safekeeper is the one with the highest
`(last_log_term, flush_lsn)` pair. The compute might bump term on the
least advanced sk, making it the best choice to pull from, and thus
making committed log entries "uncommitted" after `pull_timeline`

Part of https://databricks.atlassian.net/browse/LKB-1017

## Summary of changes
- Return `last_log_term` in `switch_membership`
- Use `(last_log_term, flush_lsn)` as a primary key for choosing the
most advanced sk in `pull_timeline` and deny pulling if the `max_term`
is higher than on the most advanced sk (hadron only)
- Write tests for both cases
- Retry `sync_safekeepers` in `compute_ctl`
- Take into the account the quorum size when calculating `sync_position`
2025-07-31 09:29:25 +00:00
Aleksandr Sarantsev
975b95f4cd Introduce deletion API improvement RFC (#12484)
## Problem

The deletion logic had become difficult to understand and maintain.

## Summary of changes

- Added an RFC detailing proposed improvements to all deletion-related
APIs.

---------

Co-authored-by: Aleksandr Sarantsev <aleksandr.sarantsev@databricks.com>
2025-07-31 08:34:47 +00:00
Mikhail
01c39f378e prewarm cancellation (#12785)
Add DELETE /lfc/prewarm route which handles ongoing prewarm
cancellation, update API spec, add prewarm Cancelled state
Add offload Cancelled state when LFC is not initialized
2025-07-30 22:05:51 +00:00
Dimitri Fontaine
4d3b28bd2e [Hadron] Always run databricks auth hook. (#12683) 2025-07-30 21:34:30 +00:00
Heikki Linnakangas
81ddd10be6 tests: Don't print Hostname on every test connection (#12782)
These lines are a significant fraction of the total log size of the
regression tests. And it seems very uninteresting, it's always
'localhost' in local tests.
2025-07-30 19:56:22 +00:00
Suhas Thalanki
e470997627 enable tests introduced in hadron commits (#12790)
Enables skipped tests introduced in hadron integration commits
2025-07-30 19:10:33 +00:00
Vlad Lazar
94dc55f405 chore: hakari 2025-07-29 10:07:09 +01:00
Vlad Lazar
50ed144689 fixup: don't create unused token for safeekeepers 2025-07-29 09:45:47 +01:00
Vlad Lazar
7de0e326a3 sq 2025-07-29 09:45:40 +01:00
Vlad Lazar
88b260bfc7 Merge remote-tracking branch 'origin' into vlad/hadron-jwt 2025-07-29 09:43:09 +01:00
Vlad Lazar
3bf55c8e93 review: bail out instead of panicking 2025-07-25 12:22:25 +01:00
Vlad Lazar
688d0771d3 review: validate that neon and hadron tokens aren't mixed 2025-07-25 12:15:08 +01:00
Vlad Lazar
8f7314c429 fixup: add OpenSSL license back to the allow list 2025-07-24 15:34:03 +01:00
Vlad Lazar
9d8a3c518b fixup: format doc comment 2025-07-24 15:33:18 +01:00
Vlad Lazar
c63b6c5bd3 chore: cargo hakari 2025-07-24 13:38:37 +01:00
Vlad Lazar
00699d86a2 fixup: bring back the SK peer jwt token 2025-07-24 13:31:07 +01:00
Vlad Lazar
10da740e65 fixup: pylints 2025-07-23 19:41:30 +01:00
Vlad Lazar
84dcfa26bb fixup: endpoint storage tests 2025-07-23 19:39:30 +01:00
Vlad Lazar
382ab511a6 Merge remote-tracking branch 'origin' into vlad/hadron-jwt 2025-07-23 19:15:31 +01:00
Vlad Lazar
2e8eeb3b50 fixup: put pg versions back 2025-07-23 19:14:04 +01:00
Vlad Lazar
bcecb03d2d fixup: bang it into shape 2025-07-23 15:58:43 +01:00
Vlad Lazar
3c5fad0184 sq 2025-07-22 18:00:14 +01:00
Vlad Lazar
9ab3203776 sq 2025-07-22 18:00:03 +01:00
Vlad Lazar
b762de56ff fixup: make it build 2025-07-22 15:49:55 +01:00
William Huang
2ddf8f64ce Augment the JwtAuth utility to support RS256 signatures and extracting decoding keys from X509 certificates (#165) 2025-07-22 12:29:45 +01:00
Vlad Lazar
f0ac89ff6f sq 2025-07-22 12:24:59 +01:00
William Huang
9661022e34 Enable JWT auth in Hadron API endpoints accepting untrusted connections (#179) 2025-07-22 12:23:57 +01:00
60 changed files with 1807 additions and 315 deletions

215
Cargo.lock generated
View File

@@ -173,6 +173,45 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048"
dependencies = [
"asn1-rs-derive",
"asn1-rs-impl",
"displaydoc",
"nom",
"num-traits",
"rusticata-macros",
"thiserror 1.0.69",
"time",
]
[[package]]
name = "asn1-rs-derive"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
"synstructure",
]
[[package]]
name = "asn1-rs-impl"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "assert-json-diff"
version = "2.0.2"
@@ -307,6 +346,30 @@ dependencies = [
"zeroize",
]
[[package]]
name = "aws-lc-rs"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08b5d4e069cbc868041a64bd68dc8cb39a0d79585cd6c5a24caa8c2d622121be"
dependencies = [
"aws-lc-sys",
"untrusted 0.7.1",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]]
name = "aws-runtime"
version = "1.4.4"
@@ -968,6 +1031,29 @@ dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.100",
"which",
]
[[package]]
name = "bindgen"
version = "0.71.1"
@@ -1260,6 +1346,15 @@ dependencies = [
"replace_with",
]
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.0"
@@ -1492,6 +1587,7 @@ dependencies = [
"postgres_connection",
"regex",
"reqwest",
"rsa",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
@@ -1511,6 +1607,7 @@ dependencies = [
"utils",
"whoami",
"workspace_hack",
"x509-parser",
]
[[package]]
@@ -1836,6 +1933,20 @@ dependencies = [
"zeroize",
]
[[package]]
name = "der-parser"
version = "9.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553"
dependencies = [
"asn1-rs",
"displaydoc",
"nom",
"num-bigint",
"num-traits",
"rusticata-macros",
]
[[package]]
name = "der_derive"
version = "0.7.3"
@@ -1992,6 +2103,12 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "dyn-clone"
version = "1.0.14"
@@ -2109,6 +2226,7 @@ dependencies = [
"http-body-util",
"itertools 0.10.5",
"jsonwebtoken",
"postgres_backend",
"prometheus",
"rand 0.9.1",
"remote_storage",
@@ -2391,6 +2509,12 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsevent-sys"
version = "4.1.0"
@@ -2840,6 +2964,15 @@ dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "hostname"
version = "0.4.0"
@@ -3614,6 +3747,12 @@ dependencies = [
"spin",
]
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.172"
@@ -4189,6 +4328,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "oid-registry"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9"
dependencies = [
"asn1-rs",
]
[[package]]
name = "once_cell"
version = "1.20.2"
@@ -5072,7 +5220,7 @@ name = "postgres_ffi"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.71.1",
"bytes",
"crc32c",
"criterion",
@@ -5737,6 +5885,7 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
dependencies = [
"aws-lc-rs",
"pem",
"ring",
"rustls-pki-types",
@@ -6052,7 +6201,7 @@ dependencies = [
"cfg-if",
"getrandom 0.2.11",
"libc",
"untrusted",
"untrusted 0.9.0",
"windows-sys 0.52.0",
]
@@ -6173,6 +6322,15 @@ dependencies = [
"semver",
]
[[package]]
name = "rusticata-macros"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632"
dependencies = [
"nom",
]
[[package]]
name = "rustix"
version = "0.38.41"
@@ -6300,7 +6458,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
dependencies = [
"ring",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -6311,7 +6469,7 @@ checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -6322,7 +6480,7 @@ checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -6484,7 +6642,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
"ring",
"untrusted",
"untrusted 0.9.0",
]
[[package]]
@@ -7067,6 +7225,7 @@ dependencies = [
"hyper 0.14.30",
"itertools 0.10.5",
"json-structural-diff",
"jsonwebtoken",
"lasso",
"measured",
"metrics",
@@ -8241,6 +8400,12 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "untrusted"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -8340,6 +8505,7 @@ dependencies = [
"jsonwebtoken",
"metrics",
"nix 0.30.1",
"oid-registry",
"once_cell",
"pem",
"pin-project-lite",
@@ -8347,7 +8513,10 @@ dependencies = [
"pprof",
"pq_proto",
"rand 0.9.1",
"rcgen",
"regex",
"rustls-pemfile 2.1.1",
"rustls-pki-types",
"scopeguard",
"sentry",
"serde",
@@ -8368,6 +8537,7 @@ dependencies = [
"tracing-utils",
"uuid",
"walkdir",
"x509-parser",
]
[[package]]
@@ -8482,7 +8652,7 @@ name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.71.1",
"postgres_ffi",
"utils",
]
@@ -8647,6 +8817,18 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "whoami"
version = "1.5.1"
@@ -9019,6 +9201,7 @@ dependencies = [
"der 0.7.8",
"deranged",
"digest",
"displaydoc",
"ecdsa 0.16.9",
"either",
"elliptic-curve 0.13.8",
@@ -9066,6 +9249,7 @@ dependencies = [
"prost 0.13.5",
"quote",
"rand 0.9.1",
"rcgen",
"regex",
"regex-automata 0.4.9",
"regex-syntax 0.8.5",
@@ -9151,6 +9335,23 @@ dependencies = [
"zeroize",
]
[[package]]
name = "x509-parser"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69"
dependencies = [
"asn1-rs",
"data-encoding",
"der-parser",
"lazy_static",
"nom",
"oid-registry",
"rusticata-macros",
"thiserror 1.0.69",
"time",
]
[[package]]
name = "xattr"
version = "1.0.0"

View File

@@ -142,6 +142,7 @@ nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket"
notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.19"
oid-registry = "0.7.1"
once_cell = "1.13"
opentelemetry = "0.30"
opentelemetry_sdk = "0.30"
@@ -173,6 +174,7 @@ rustc-hash = "2.1.1"
rustls = { version = "0.23.16", default-features = false }
rustls-pemfile = "2"
rustls-pki-types = "1.11"
rustls-split = "0.3"
scopeguard = "1.1"
sysinfo = "0.29.2"
sd-notify = "0.4.1"
@@ -235,6 +237,7 @@ rustls-native-certs = "0.8"
whoami = "1.5.1"
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
x509-parser = "0.16"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
zeroize = "1.8"

View File

@@ -32,8 +32,12 @@ use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::{spawn, sync::watch, task::JoinHandle, time};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, instrument, warn};
use url::Url;
use utils::backoff::{
DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS, exponential_backoff_duration,
};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::measured_stream::MeasuredReader;
@@ -192,6 +196,7 @@ pub struct ComputeState {
pub startup_span: Option<tracing::span::Span>,
pub lfc_prewarm_state: LfcPrewarmState,
pub lfc_prewarm_token: CancellationToken,
pub lfc_offload_state: LfcOffloadState,
/// WAL flush LSN that is set after terminating Postgres and syncing safekeepers if
@@ -217,6 +222,7 @@ impl ComputeState {
lfc_offload_state: LfcOffloadState::default(),
terminate_flush_lsn: None,
promote_state: None,
lfc_prewarm_token: CancellationToken::new(),
}
}
@@ -1554,6 +1560,41 @@ impl ComputeNode {
Ok(lsn)
}
fn sync_safekeepers_with_retries(&self, storage_auth_token: Option<String>) -> Result<Lsn> {
let max_retries = 5;
let mut attempts = 0;
loop {
let result = self.sync_safekeepers(storage_auth_token.clone());
match &result {
Ok(_) => {
if attempts > 0 {
tracing::info!("sync_safekeepers succeeded after {attempts} retries");
}
return result;
}
Err(e) if attempts < max_retries => {
tracing::info!(
"sync_safekeepers failed, will retry (attempt {attempts}): {e:#}"
);
}
Err(err) => {
tracing::warn!(
"sync_safekeepers still failed after {attempts} retries, giving up: {err:?}"
);
return result;
}
}
// sleep and retry
let backoff = exponential_backoff_duration(
attempts,
DEFAULT_BASE_BACKOFF_SECONDS,
DEFAULT_MAX_BACKOFF_SECONDS,
);
std::thread::sleep(backoff);
attempts += 1;
}
}
/// Do all the preparations like PGDATA directory creation, configuration,
/// safekeepers sync, basebackup, etc.
#[instrument(skip_all)]
@@ -1589,7 +1630,7 @@ impl ComputeNode {
lsn
} else {
info!("starting safekeepers syncing");
self.sync_safekeepers(pspec.storage_auth_token.clone())
self.sync_safekeepers_with_retries(pspec.storage_auth_token.clone())
.with_context(|| "failed to sync safekeepers")?
};
info!("safekeepers synced at LSN {}", lsn);

View File

@@ -7,7 +7,8 @@ use http::StatusCode;
use reqwest::Client;
use std::mem::replace;
use std::sync::Arc;
use tokio::{io::AsyncReadExt, spawn};
use tokio::{io::AsyncReadExt, select, spawn};
use tokio_util::sync::CancellationToken;
use tracing::{error, info};
#[derive(serde::Serialize, Default)]
@@ -92,34 +93,35 @@ impl ComputeNode {
/// If there is a prewarm request ongoing, return `false`, `true` otherwise.
/// Has a failpoint "compute-prewarm"
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
let token: CancellationToken;
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
if let LfcPrewarmState::Prewarming = replace(state, LfcPrewarmState::Prewarming) {
let state = &mut self.state.lock().unwrap();
token = state.lfc_prewarm_token.clone();
if let LfcPrewarmState::Prewarming =
replace(&mut state.lfc_prewarm_state, LfcPrewarmState::Prewarming)
{
return false;
}
}
crate::metrics::LFC_PREWARMS.inc();
let cloned = self.clone();
let this = self.clone();
spawn(async move {
let state = match cloned.prewarm_impl(from_endpoint).await {
Ok(true) => LfcPrewarmState::Completed,
Ok(false) => {
info!(
"skipping LFC prewarm because LFC state is not found in endpoint storage"
);
LfcPrewarmState::Skipped
}
let prewarm_state = match this.prewarm_impl(from_endpoint, token).await {
Ok(state) => state,
Err(err) => {
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "could not prewarm LFC");
LfcPrewarmState::Failed {
error: format!("{err:#}"),
}
let error = format!("{err:#}");
LfcPrewarmState::Failed { error }
}
};
cloned.state.lock().unwrap().lfc_prewarm_state = state;
let state = &mut this.state.lock().unwrap();
if let LfcPrewarmState::Cancelled = prewarm_state {
state.lfc_prewarm_token = CancellationToken::new();
}
state.lfc_prewarm_state = prewarm_state;
});
true
}
@@ -132,47 +134,70 @@ impl ComputeNode {
/// Request LFC state from endpoint storage and load corresponding pages into Postgres.
/// Returns a result with `false` if the LFC state is not found in endpoint storage.
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<bool> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
async fn prewarm_impl(
&self,
from_endpoint: Option<String>,
token: CancellationToken,
) -> Result<LfcPrewarmState> {
let EndpointStoragePair {
url,
token: storage_token,
} = self.endpoint_storage_pair(from_endpoint)?;
#[cfg(feature = "testing")]
fail::fail_point!("compute-prewarm", |_| {
bail!("prewarm configured to fail because of a failpoint")
});
fail::fail_point!("compute-prewarm", |_| bail!("compute-prewarm failpoint"));
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
let res = request.send().await.context("querying endpoint storage")?;
match res.status() {
let request = Client::new().get(&url).bearer_auth(storage_token);
let response = select! {
_ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled),
response = request.send() => response
}
.context("querying endpoint storage")?;
match response.status() {
StatusCode::OK => (),
StatusCode::NOT_FOUND => {
return Ok(false);
}
StatusCode::NOT_FOUND => return Ok(LfcPrewarmState::Skipped),
status => bail!("{status} querying endpoint storage"),
}
let mut uncompressed = Vec::new();
let lfc_state = res
.bytes()
.await
.context("getting request body from endpoint storage")?;
ZstdDecoder::new(lfc_state.iter().as_slice())
.read_to_end(&mut uncompressed)
.await
.context("decoding LFC state")?;
let lfc_state = select! {
_ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled),
lfc_state = response.bytes() => lfc_state
}
.context("getting request body from endpoint storage")?;
let mut decoder = ZstdDecoder::new(lfc_state.iter().as_slice());
select! {
_ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled),
read = decoder.read_to_end(&mut uncompressed) => read
}
.context("decoding LFC state")?;
let uncompressed_len = uncompressed.len();
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}");
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres");
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
// Client connection and prewarm info querying are fast and therefore don't need
// cancellation
let client = ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
.context("connecting to postgres")?
.query_one("select neon.prewarm_local_cache($1)", &[&uncompressed])
.await
.context("loading LFC state into postgres")
.map(|_| ())?;
.context("connecting to postgres")?;
let pg_token = client.cancel_token();
Ok(true)
let params: Vec<&(dyn postgres_types::ToSql + Sync)> = vec![&uncompressed];
select! {
res = client.query_one("select neon.prewarm_local_cache($1)", &params) => res,
_ = token.cancelled() => {
pg_token.cancel_query(postgres::NoTls).await
.context("cancelling neon.prewarm_local_cache()")?;
return Ok(LfcPrewarmState::Cancelled)
}
}
.context("loading LFC state into postgres")
.map(|_| ())?;
Ok(LfcPrewarmState::Completed)
}
/// If offload request is ongoing, return false, true otherwise
@@ -200,20 +225,20 @@ impl ComputeNode {
async fn offload_lfc_with_state_update(&self) {
crate::metrics::LFC_OFFLOADS.inc();
let Err(err) = self.offload_lfc_impl().await else {
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
let state = match self.offload_lfc_impl().await {
Ok(state) => state,
Err(err) => {
crate::metrics::LFC_OFFLOAD_ERRORS.inc();
error!(%err, "could not offload LFC");
let error = format!("{err:#}");
LfcOffloadState::Failed { error }
}
};
crate::metrics::LFC_OFFLOAD_ERRORS.inc();
error!(%err, "could not offload LFC state to endpoint storage");
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: format!("{err:#}"),
};
self.state.lock().unwrap().lfc_offload_state = state;
}
async fn offload_lfc_impl(&self) -> Result<()> {
async fn offload_lfc_impl(&self) -> Result<LfcOffloadState> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
info!(%url, "requesting LFC state from Postgres");
@@ -228,7 +253,7 @@ impl ComputeNode {
.context("deserializing LFC state")?;
let Some(state) = state else {
info!(%url, "empty LFC state, not exporting");
return Ok(());
return Ok(LfcOffloadState::Skipped);
};
let mut compressed = Vec::new();
@@ -242,7 +267,7 @@ impl ComputeNode {
let request = Client::new().put(url).bearer_auth(token).body(compressed);
match request.send().await {
Ok(res) if res.status() == StatusCode::OK => Ok(()),
Ok(res) if res.status() == StatusCode::OK => Ok(LfcOffloadState::Completed),
Ok(res) => bail!(
"Request to endpoint storage failed with status: {}",
res.status()
@@ -250,4 +275,8 @@ impl ComputeNode {
Err(err) => Err(err).context("writing to endpoint storage"),
}
}
pub fn cancel_prewarm(self: &Arc<Self>) {
self.state.lock().unwrap().lfc_prewarm_token.cancel();
}
}

View File

@@ -139,6 +139,15 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/LfcPrewarmState"
delete:
tags:
- Prewarm
summary: Cancel ongoing LFC prewarm
description: ""
operationId: cancelLfcPrewarm
responses:
202:
description: Prewarm cancelled
/lfc/offload:
post:
@@ -636,7 +645,7 @@ components:
properties:
status:
description: LFC offload status
enum: [not_offloaded, offloading, completed, failed]
enum: [not_offloaded, offloading, completed, skipped, failed]
type: string
error:
description: LFC offload error, if any

View File

@@ -46,3 +46,8 @@ pub(in crate::http) async fn offload(compute: Compute) -> Response {
)
}
}
pub(in crate::http) async fn cancel_prewarm(compute: Compute) -> StatusCode {
compute.cancel_prewarm();
StatusCode::ACCEPTED
}

View File

@@ -99,7 +99,12 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
);
let authenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))
.route(
"/lfc/prewarm",
get(lfc::prewarm_state)
.post(lfc::prewarm)
.delete(lfc::cancel_prewarm),
)
.route("/lfc/offload", get(lfc::offload_state).post(lfc::offload))
.route("/promote", post(promote::promote))
.route("/check_writability", post(check_writability::is_writable))

View File

@@ -1,13 +0,0 @@
DO $$
DECLARE
query varchar;
BEGIN
FOR query IN
SELECT pg_catalog.format('ALTER FUNCTION %I(%s) OWNER TO {db_owner};', p.oid::regproc, pg_catalog.pg_get_function_identity_arguments(p.oid))
FROM pg_catalog.pg_proc p
WHERE p.pronamespace OPERATOR(pg_catalog.=) 'anon'::regnamespace::oid
LOOP
EXECUTE query;
END LOOP;
END
$$;

View File

@@ -46,3 +46,5 @@ endpoint_storage.workspace = true
compute_api.workspace = true
workspace_hack.workspace = true
tracing.workspace = true
x509-parser.workspace = true
rsa = "0.9"

View File

@@ -1049,6 +1049,7 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
// User (likely interactive) did not provide a description of the environment, give them the default
NeonLocalInitConf {
control_plane_api: Some(DEFAULT_PAGESERVER_CONTROL_PLANE_API.parse().unwrap()),
auth_token_type: AuthType::NeonJWT,
broker: NeonBroker {
listen_addr: Some(DEFAULT_BROKER_ADDR.parse().unwrap()),
listen_https_addr: None,
@@ -1584,7 +1585,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
pageserver_conninfo.prefer_protocol = prefer_protocol;
let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
let auth_token = if matches!(
ps_conf.pg_auth_type,
AuthType::NeonJWT | AuthType::HadronJWT
) {
let claims = Claims::new(Some(endpoint.tenant_id), Scope::Tenant);
Some(env.generate_auth_token(&claims)?)

View File

@@ -37,18 +37,8 @@
//! <other PostgreSQL files>
//! ```
//!
use std::collections::{BTreeMap, HashMap};
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
@@ -66,20 +56,30 @@ pub use compute_api::spec::{PageserverConnectionInfo, PageserverShardConnectionI
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, RSAKeyParameters, RSAKeyType,
};
use nix::sys::signal::{Signal, kill};
use pem::Pem;
use reqwest::header::CONTENT_TYPE;
use rsa::{RsaPublicKey, pkcs1::DecodeRsaPublicKey, traits::PublicKeyParts};
use safekeeper_api::PgMajorVersion;
use safekeeper_api::membership::SafekeeperGeneration;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use spki::der::Decode;
use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
use std::collections::{BTreeMap, HashMap};
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::debug;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use x509_parser::parse_x509_certificate;
use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT;
use postgres_connection::parse_host_port;
@@ -161,23 +161,76 @@ impl ComputeControlPlane {
.unwrap_or(self.base_port)
}
// BEGIN HADRON
/// Extract SubjectPublicKeyInfo from a PEM that can be either a X509 certificate or a public key
fn extract_spki_from_pem(pem: &Pem) -> Result<Vec<u8>> {
if pem.tag() == "CERTIFICATE" {
// Handle X509 certificate
let (_, cert) = parse_x509_certificate(pem.contents())?;
let public_key = cert.public_key();
Ok(public_key.subject_public_key.data.to_vec())
} else {
// Handle public key directly
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
Ok(spki.subject_public_key.raw_bytes().to_vec())
}
}
/// Create RSA JWK from certificate PEM
fn create_rsa_jwk_from_cert(pem: &Pem, key_hash: &[u8]) -> Result<Jwk> {
let public_key = Self::extract_spki_from_pem(pem)?;
// Extract RSA parameters (n, e) from RSA public key DER data
let rsa_key = RsaPublicKey::from_pkcs1_der(&public_key)?;
let n = rsa_key.n().to_bytes_be();
let e = rsa_key.e().to_bytes_be();
Ok(Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::RS256),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
x509_sha256_fingerprint: None::<String>,
},
algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
key_type: RSAKeyType::RSA,
n: URL_SAFE_NO_PAD.encode(n),
e: URL_SAFE_NO_PAD.encode(e),
}),
})
}
// END HADRON
/// Create a JSON Web Key Set. This ideally matches the way we create a JWKS
/// from the production control plane.
fn create_jwks_from_pem(pem: &Pem) -> Result<JwkSet> {
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
let public_key = spki.subject_public_key.raw_bytes();
let public_key = Self::extract_spki_from_pem(pem)?;
let mut hasher = Sha256::new();
hasher.update(public_key);
hasher.update(&public_key);
let key_hash = hasher.finalize();
// BEGIN HADRON
if pem.tag() == "CERTIFICATE" {
// Assume RSA if we are parsing keys from a certificate.
let jwk = Self::create_rsa_jwk_from_cert(pem, &key_hash)?;
return Ok(JwkSet { keys: vec![jwk] });
}
// END HADRON
Ok(JwkSet {
keys: vec![Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::EdDSA),
key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
@@ -186,7 +239,7 @@ impl ComputeControlPlane {
algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters {
key_type: OctetKeyPairType::OctetKeyPair,
curve: EllipticCurve::Ed25519,
x: BASE64_URL_SAFE_NO_PAD.encode(public_key),
x: URL_SAFE_NO_PAD.encode(public_key),
}),
}],
})

View File

@@ -2,6 +2,7 @@ use crate::background_process::{self, start_process, stop_process};
use crate::local_env::LocalEnv;
use anyhow::{Context, Result};
use camino::Utf8PathBuf;
use postgres_backend::AuthType;
use std::io::Write;
use std::net::SocketAddr;
use std::time::Duration;
@@ -16,15 +17,22 @@ pub struct EndpointStorage {
pub data_dir: Utf8PathBuf,
pub pemfile: Utf8PathBuf,
pub addr: SocketAddr,
pub auth_type: AuthType,
}
impl EndpointStorage {
pub fn from_env(env: &LocalEnv) -> EndpointStorage {
let auth_type = match env.token_auth_type {
AuthType::HadronJWT => AuthType::HadronJWT,
AuthType::NeonJWT | AuthType::Trust => AuthType::NeonJWT,
};
EndpointStorage {
bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(),
pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(),
addr: env.endpoint_storage.listen_addr,
auth_type,
}
}
@@ -46,12 +54,14 @@ impl EndpointStorage {
pemfile: Utf8PathBuf,
local_path: Utf8PathBuf,
r#type: String,
auth_type: AuthType,
}
let cfg = Cfg {
listen: self.listen_addr(),
pemfile: parent.join(self.pemfile.clone()),
local_path: parent.join(ENDPOINT_STORAGE_REMOTE_STORAGE_DIR),
r#type: "LocalFs".to_string(),
auth_type: self.auth_type,
};
std::fs::create_dir_all(self.config_path().parent().unwrap())?;
std::fs::write(self.config_path(), serde_json::to_string(&cfg)?)

View File

@@ -18,7 +18,7 @@ use postgres_backend::AuthType;
use reqwest::{Certificate, Url};
use safekeeper_api::PgMajorVersion;
use serde::{Deserialize, Serialize};
use utils::auth::encode_from_key_file;
use utils::auth::{encode_from_key_file, encode_hadron_token};
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use crate::broker::StorageBroker;
@@ -60,6 +60,9 @@ pub struct LocalEnv {
// --tenant_id is not explicitly specified.
pub default_tenant_id: Option<TenantId>,
// The type of tokens to use for authentication in the test environment. Determines
// the type of key pairs and tokens generated in the test.
pub token_auth_type: AuthType,
// used to issue tokens during e.g pg start
pub private_key_path: PathBuf,
/// Path to environment's public key
@@ -105,6 +108,7 @@ pub struct OnDiskConfig {
pub pg_distrib_dir: PathBuf,
pub neon_distrib_dir: PathBuf,
pub default_tenant_id: Option<TenantId>,
pub token_auth_type: Option<AuthType>,
pub private_key_path: PathBuf,
pub public_key_path: PathBuf,
pub broker: NeonBroker,
@@ -153,6 +157,7 @@ pub struct NeonLocalInitConf {
pub control_plane_api: Option<Url>,
pub control_plane_hooks_api: Option<Url>,
pub generate_local_ssl_certs: bool,
pub auth_token_type: AuthType,
}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
@@ -374,7 +379,7 @@ pub struct SafekeeperConf {
pub sync: bool,
pub remote_storage: Option<String>,
pub backup_threads: Option<u32>,
pub auth_enabled: bool,
pub auth_type: AuthType,
pub listen_addr: Option<String>,
}
@@ -389,7 +394,7 @@ impl Default for SafekeeperConf {
sync: true,
remote_storage: None,
backup_threads: None,
auth_enabled: false,
auth_type: AuthType::Trust,
listen_addr: None,
}
}
@@ -663,6 +668,7 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id,
token_auth_type,
private_key_path,
public_key_path,
broker,
@@ -681,6 +687,7 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id,
token_auth_type: token_auth_type.unwrap_or(AuthType::NeonJWT),
private_key_path,
public_key_path,
broker,
@@ -796,6 +803,7 @@ impl LocalEnv {
pg_distrib_dir: self.pg_distrib_dir.clone(),
neon_distrib_dir: self.neon_distrib_dir.clone(),
default_tenant_id: self.default_tenant_id,
token_auth_type: Some(self.token_auth_type),
private_key_path: self.private_key_path.clone(),
public_key_path: self.public_key_path.clone(),
broker: self.broker.clone(),
@@ -825,8 +833,18 @@ impl LocalEnv {
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn generate_auth_token<S: Serialize>(&self, claims: &S) -> anyhow::Result<String> {
let key = self.read_private_key()?;
encode_from_key_file(claims, &key)
match self.token_auth_type {
AuthType::NeonJWT => {
let key_data = self.read_private_key()?;
encode_from_key_file(claims, &key_data)
}
AuthType::HadronJWT => {
let private_key_path = self.get_private_key_path();
let key_data = fs::read(private_key_path)?;
encode_hadron_token(claims, &key_data)
}
_ => panic!("unsupported token auth type {:?}", self.token_auth_type),
}
}
/// Get the path to the private key.
@@ -915,6 +933,7 @@ impl LocalEnv {
generate_local_ssl_certs,
control_plane_hooks_api,
endpoint_storage,
auth_token_type,
} = conf;
// Find postgres binaries.
@@ -943,6 +962,7 @@ impl LocalEnv {
generate_auth_keys(
base_path.join("auth_private_key.pem").as_path(),
base_path.join("auth_public_key.pem").as_path(),
auth_token_type,
)
.context("generate auth keys")?;
let private_key_path = PathBuf::from("auth_private_key.pem");
@@ -956,6 +976,7 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id: Some(default_tenant_id),
token_auth_type: auth_token_type,
private_key_path,
public_key_path,
broker,
@@ -1035,39 +1056,63 @@ pub fn base_path() -> PathBuf {
}
/// Generate a public/private key pair for JWT authentication
fn generate_auth_keys(private_key_path: &Path, public_key_path: &Path) -> anyhow::Result<()> {
// Generate the key pair
//
// openssl genpkey -algorithm ed25519 -out auth_private_key.pem
let keygen_output = Command::new("openssl")
.arg("genpkey")
.args(["-algorithm", "ed25519"])
.args(["-out", private_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
// Extract the public key from the private key file
//
// openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
let keygen_output = Command::new("openssl")
.arg("pkey")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-out", public_key_path.to_str().unwrap()])
.output()
.context("failed to extract public key from private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
fn generate_auth_keys(
private_key_path: &Path,
public_key_path: &Path,
auth_type: AuthType,
) -> anyhow::Result<()> {
if auth_type == AuthType::NeonJWT {
// Generate the key pair
//
// openssl genpkey -algorithm ed25519 -out auth_private_key.pem
let keygen_output = Command::new("openssl")
.arg("genpkey")
.args(["-algorithm", "ed25519"])
.args(["-out", private_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
// Extract the public key from the private key file
//
// openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
let keygen_output = Command::new("openssl")
.arg("pkey")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-out", public_key_path.to_str().unwrap()])
.output()
.context("failed to extract public key from private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
} else if auth_type == AuthType::HadronJWT {
// Generate the RSA key pair. Note that the public key is embedded in an X509 certificate.
//
// openssl req -x509 -newkey rsa:4096 -keyout auth_private_key.pem -out auth_public_key.pem -nodes -subj "/CN=eng-brickstore@databricks.com"
let keygen_output = Command::new("openssl")
.arg("req")
.args(["-x509", "-newkey", "rsa:4096", "-sha256"])
.args(["-keyout", private_key_path.to_str().unwrap()])
.args(["-out", public_key_path.to_str().unwrap()])
.args(["-nodes"])
.args(["-subj", "/CN=eng-brickstore@databricks.com"])
.output()
.context("Failed to generate RSA key pair for Hadron token auth")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
}
Ok(())

View File

@@ -73,7 +73,7 @@ impl PageServerNode {
{
match conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(
AuthType::NeonJWT | AuthType::HadronJWT => Some(
env.generate_auth_token(&Claims::new(None, Scope::PageServerApi))
.unwrap(),
),
@@ -117,7 +117,10 @@ impl PageServerNode {
// Storage controller uses the same auth as pageserver: if JWT is enabled
// for us, we will also need it to talk to them.
if matches!(conf.http_auth_type, AuthType::NeonJWT) {
// Note: In Hadron the "control plane" is HCC. HCC does not require a token on the trusted port PS connects
// to, so we do not need to set any tokens when using HadronJWT. In the future we may consider using mTLS
// instead of JWT for HTTP auth.
if matches!(conf.http_auth_type, AuthType::NeonJWT | AuthType::HadronJWT) {
let jwt_token = self
.env
.generate_auth_token(&Claims::new(None, Scope::GenerationsApi))
@@ -132,7 +135,8 @@ impl PageServerNode {
}
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.contains(&AuthType::NeonJWT)
.iter()
.any(|auth_type| *auth_type == AuthType::NeonJWT || *auth_type == AuthType::HadronJWT)
{
// Keys are generated in the toplevel repo dir, pageservers' workdirs
// are one level below that, so refer to keys with ../

View File

@@ -13,6 +13,7 @@ use std::{io, result};
use anyhow::Context;
use camino::Utf8PathBuf;
use postgres_backend::AuthType;
use postgres_connection::PgConnectionConfig;
use safekeeper_api::models::TimelineCreateRequest;
use safekeeper_client::mgmt_api;
@@ -110,7 +111,7 @@ impl SafekeeperNode {
}
// Generate a token file for authentication with other safekeepers
if self.conf.auth_enabled {
if self.conf.auth_type != AuthType::Trust {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::SafekeeperData))?;
@@ -156,7 +157,7 @@ impl SafekeeperNode {
"--id".to_owned(),
id_string,
"--listen-pg".to_owned(),
listen_pg,
listen_pg.clone(),
"--listen-http".to_owned(),
listen_http,
"--availability-zone".to_owned(),
@@ -186,7 +187,11 @@ impl SafekeeperNode {
}
let key_path = self.env.base_data_dir.join("auth_public_key.pem");
if self.conf.auth_enabled {
if self.conf.auth_type != AuthType::Trust {
args.extend([
"--token-auth-type".to_owned(),
self.conf.auth_type.to_string(),
]);
let key_path_string = key_path
.to_str()
.with_context(|| {
@@ -205,6 +210,15 @@ impl SafekeeperNode {
"--http-auth-public-key-path".to_owned(),
key_path_string.clone(),
]);
let token_path = self.datadir_path().join("peer_jwt_token");
let token_path_str = token_path
.to_str()
.with_context(|| {
format!("Token path {token_path:?} cannot be represented as a unicode string")
})?
.to_owned();
args.extend(["--auth-token-path".to_owned(), token_path_str]);
}
if let Some(https_port) = self.conf.https_port {
@@ -217,26 +231,14 @@ impl SafekeeperNode {
args.push(format!("--ssl-ca-file={}", ssl_ca_file.to_str().unwrap()));
}
if self.conf.auth_enabled {
let token_path = self.datadir_path().join("peer_jwt_token");
let token_path_str = token_path
.to_str()
.with_context(|| {
format!("Token path {token_path:?} cannot be represented as a unicode string")
})?
.to_owned();
args.extend(["--auth-token-path".to_owned(), token_path_str]);
}
args.extend_from_slice(extra_opts);
let env_variables = Vec::new();
background_process::start_process(
&format!("safekeeper-{id}"),
&datadir,
&self.env.safekeeper_bin(),
&args,
env_variables,
self.safekeeper_env_variables()?,
background_process::InitialPidFile::Expect(self.pid_file()),
retry_timeout,
|| async {
@@ -250,6 +252,11 @@ impl SafekeeperNode {
.await
}
fn safekeeper_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
// TODO: remove me
Ok(vec![])
}
///
/// Stop the server.
///

View File

@@ -30,14 +30,14 @@ use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tracing::instrument;
use url::Url;
use utils::auth::{Claims, Scope, encode_from_key_file};
use utils::auth::{Claims, Scope, encode_from_key_file, encode_hadron_token};
use utils::id::{NodeId, TenantId};
use whoami::username;
pub struct StorageController {
env: LocalEnv,
private_key: Option<Pem>,
public_key: Option<Pem>,
private_key: Option<StorageControllerPrivateKey>,
public_key: Option<StorageControllerPublicKey>,
client: reqwest::Client,
config: NeonStorageControllerConf,
@@ -108,6 +108,25 @@ pub struct InspectResponse {
pub attachment: Option<(u32, NodeId)>,
}
enum StorageControllerPublicKey {
RawPublicKey(Pem),
PublicKeyCertPath(Utf8PathBuf),
}
enum StorageControllerPrivateKey {
EdPrivateKey(Pem),
HadronPrivateKey(Utf8PathBuf, Vec<u8>),
}
impl StorageControllerPrivateKey {
pub fn encode_token(&self, claims: &Claims) -> anyhow::Result<String> {
match self {
Self::EdPrivateKey(key_data) => encode_from_key_file(claims, key_data),
Self::HadronPrivateKey(_, key_data) => encode_hadron_token(claims, key_data),
}
}
}
impl StorageController {
pub fn from_env(env: &LocalEnv) -> Self {
// Assume all pageservers have symmetric auth configuration: this service
@@ -152,7 +171,30 @@ impl StorageController {
)
.expect("Failed to parse PEM file")
};
(Some(private_key), Some(public_key))
(
Some(StorageControllerPrivateKey::EdPrivateKey(private_key)),
Some(StorageControllerPublicKey::RawPublicKey(public_key)),
)
}
AuthType::HadronJWT => {
let private_key_path = env.get_private_key_path();
let private_key =
fs::read(private_key_path.clone()).expect("failed to read private key");
// If pageserver auth is enabled, this implicitly enables auth for this service,
// using the same credentials.
let public_key_path =
camino::Utf8PathBuf::try_from(env.base_data_dir.join("auth_public_key.pem"))
.unwrap();
(
Some(StorageControllerPrivateKey::HadronPrivateKey(
camino::Utf8PathBuf::try_from(private_key_path).unwrap(),
private_key,
)),
Some(StorageControllerPublicKey::PublicKeyCertPath(
public_key_path,
)),
)
}
};
@@ -575,23 +617,38 @@ impl StorageController {
if let Some(private_key) = &self.private_key {
let claims = Claims::new(None, Scope::PageServerApi);
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
if let StorageControllerPrivateKey::HadronPrivateKey(key_path, _) = private_key {
args.push(format!("--private-key-path={key_path}"));
}
// We are setting all JWT tokens for Hadron as well in this test to avoid bifurcation between Neon and
// Hadron test cases. In production we do not need to set this as HTTP auth is not enabled on the
// pageserver. We use network segmentation to ensure that only trusted components can talk to
// pageserver's http port
let jwt_token = private_key.encode_token(&claims)?;
args.push(format!("--jwt-token={jwt_token}"));
let peer_claims = Claims::new(None, Scope::Admin);
let peer_jwt_token = encode_from_key_file(&peer_claims, private_key)
let peer_jwt_token = private_key
.encode_token(&peer_claims)
.expect("failed to generate jwt token");
args.push(format!("--peer-jwt-token={peer_jwt_token}"));
let claims = Claims::new(None, Scope::SafekeeperData);
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
let jwt_token = private_key
.encode_token(&claims)
.expect("failed to generate jwt token");
args.push(format!("--safekeeper-jwt-token={jwt_token}"));
}
if let Some(public_key) = &self.public_key {
args.push(format!("--public-key=\"{public_key}\""));
match public_key {
StorageControllerPublicKey::RawPublicKey(public_key) => {
args.push(format!("--public-key=\"{public_key}\""));
}
StorageControllerPublicKey::PublicKeyCertPath(public_key_path) => {
args.push(format!("--public-key-cert-path={public_key_path}"));
}
}
}
if let Some(control_plane_hooks_api) = &self.env.control_plane_hooks_api {
@@ -632,7 +689,13 @@ impl StorageController {
self.env.base_data_dir.display()
));
if self.env.safekeepers.iter().any(|sk| sk.auth_enabled) && self.private_key.is_none() {
if self
.env
.safekeepers
.iter()
.any(|sk| sk.auth_type != AuthType::Trust)
&& self.private_key.is_none()
{
anyhow::bail!("Safekeeper set up for auth but no private key specified");
}
@@ -847,7 +910,7 @@ impl StorageController {
println!("Getting claims for path {path}");
if let Some(required_claims) = Self::get_claims_for_path(&path)? {
println!("Got claims {required_claims:?} for path {path}");
let jwt_token = encode_from_key_file(&required_claims, private_key)?;
let jwt_token = private_key.encode_token(&required_claims)?;
builder = builder.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {jwt_token}"),

View File

@@ -46,6 +46,7 @@ allow = [
"ISC",
"MIT",
"MPL-2.0",
"OpenSSL",
"Unicode-3.0",
]
confidence-threshold = 0.8

View File

@@ -0,0 +1,246 @@
# Node deletion API improvement
Created on 2025-07-07
Implemented on _TBD_
## Summary
This RFC describes improvements to the storage controller API for gracefully deleting pageserver
nodes.
## Motivation
The basic node deletion API introduced in [#8226](https://github.com/neondatabase/neon/issues/8333)
has several limitations:
- Deleted nodes can re-add themselves if they restart (e.g., a flaky node that keeps restarting and
we cannot reach via SSH to stop the pageserver). This issue has been resolved by tombstone
mechanism in [#12036](https://github.com/neondatabase/neon/issues/12036)
- Process of node deletion is not graceful, i.e. it just imitates a node failure
In this context, "graceful" node deletion means that users do not experience any disruption or
negative effects, provided the system remains in a healthy state (i.e., the remaining pageservers
can handle the workload and all requirements are met). To achieve this, the system must perform
live migration of all tenant shards from the node being deleted while the node is still running
and continue processing all incoming requests. The node is removed only after all tenant shards
have been safely migrated.
Although live migrations can be achieved with the drain functionality, it leads to incorrect shard
placement, such as not matching availability zones. This results in unnecessary work to optimize
the placement that was just recently performed.
If we delete a node before its tenant shards are fully moved, the new node won't have all the
needed data (e.g. heatmaps) ready. This means user requests to the new node will be much slower at
first. If there are many tenant shards, this slowdown affects a huge amount of users.
Graceful node deletion is more complicated and can introduce new issues. It takes longer because
live migration of each tenant shard can last several minutes. Using non-blocking accessors may
also cause deletion to wait if other processes are holding inner state lock. It also gets trickier
because we need to handle other requests, like drain and fill, at the same time.
## Impacted components (e.g. pageserver, safekeeper, console, etc)
- storage controller
- pageserver (indirectly)
## Proposed implementation
### Tombstones
To resolve the problem of deleted nodes re-adding themselves, a tombstone mechanism was introduced
as part of the node stored information. Each node has a separate `NodeLifecycle` field with two
possible states: `Active` and `Deleted`. When node deletion completes, the database row is not
deleted but instead has its `NodeLifecycle` column switched to `Deleted`. Nodes with `Deleted`
lifecycle are treated as if the row is absent for most handlers, with several exceptions: reattach
and register functionality must be aware of tombstones. Additionally, new debug handlers are
available for listing and deleting tombstones via the `/debug/v1/tombstone` path.
### Gracefulness
The problem of making node deletion graceful is complex and involves several challenges:
- **Cancellable**: The operation must be cancellable to allow administrators to abort the process
if needed, e.g. if run by mistake.
- **Non-blocking**: We don't want to block deployment operations like draining/filling on the node
deletion process. We need clear policies for handling concurrent operations: what happens when a
drain/fill request arrives while deletion is in progress, and what happens when a delete request
arrives while drain/fill is in progress.
- **Persistent**: If the storage controller restarts during this long-running operation, we must
preserve progress and automatically resume the deletion process after the storage controller
restarts.
- **Migrated correctly**: We cannot simply use the existing drain mechanism for nodes scheduled
for deletion, as this would move shards to irrelevant locations. The drain process expects the
node to return, so it only moves shards to backup locations, not to their preferred AZs. It also
leaves secondary locations unmoved. This could result in unnecessary load on the storage
controller and inefficient resource utilization.
- **Force option**: Administrators need the ability to force immediate, non-graceful deletion when
time constraints or emergency situations require it, bypassing the normal graceful migration
process.
See below for a detailed breakdown of the proposed changes and mechanisms.
#### Node lifecycle
New `NodeLifecycle` enum and a matching database field with these values:
- `Active`: The normal state. All operations are allowed.
- `ScheduledForDeletion`: The node is marked to be deleted soon. Deletion may be in progress or
will happen later, but the node will eventually be removed. All operations are allowed.
- `Deleted`: The node is fully deleted. No operations are allowed, and the node cannot be brought
back. The only action left is to remove its record from the database. Any attempt to register a
node in this state will fail.
This state persists across storage controller restarts.
**State transition**
```
+--------------------+
+---| Active |<---------------------+
| +--------------------+ |
| ^ |
| start_node_delete | cancel_node_delete |
v | |
+----------------------------------+ |
| ScheduledForDeletion | |
+----------------------------------+ |
| |
| node_register |
| |
| delete_node (at the finish) |
| |
v |
+---------+ tombstone_delete +----------+
| Deleted |-------------------------------->| no row |
+---------+ +----------+
```
#### NodeSchedulingPolicy::Deleting
A `Deleting` variant to the `NodeSchedulingPolicy` enum. This means the deletion function is
running for the node right now. Only one node can have the `Deleting` policy at a time.
The `NodeSchedulingPolicy::Deleting` state is persisted in the database. However, after a storage
controller restart, any node previously marked as `Deleting` will have its scheduling policy reset
to `Pause`. The policy will only transition back to `Deleting` when the deletion operation is
actively started again, as triggered by the node's `NodeLifecycle::ScheduledForDeletion` state.
`NodeSchedulingPolicy` transition details:
1. When `node_delete` begins, set the policy to `NodeSchedulingPolicy::Deleting`.
2. If `node_delete` is cancelled (for example, due to a concurrent drain operation), revert the
policy to its previous value. The policy is persisted in storcon DB.
3. After `node_delete` completes, the final value of the scheduling policy is irrelevant, since
`NodeLifecycle::Deleted` prevents any further access to this field.
The deletion process cannot be initiated for nodes currently undergoing deployment-related
operations (`Draining`, `Filling`, or `PauseForRestart` policies). Deletion will only be triggered
once the node transitions to either the `Active` or `Pause` state.
#### OperationTracker
A replacement for `Option<OperationHandler> ongoing_operation`, the `OperationTracker` is a
dedicated service state object responsible for managing all long-running node operations (drain,
fill, delete) with robust concurrency control.
Key responsibilities:
- Orchestrates the execution of operations
- Supports cancellation of currently running operations
- Enforces operation constraints, e.g. allowing only single drain/fill operation at a time
- Persists deletion state, enabling recovery of pending deletions across restarts
- Ensures thread safety across concurrent requests
#### Attached tenant shard processing
When deleting a node, handle each attached tenant shard as follows:
1. Pick the best node to become the new attached (the candidate).
2. If the candidate already has this shard as a secondary:
- Create a new secondary for the shard on another suitable node.
Otherwise:
- Create a secondary for the shard on the candidate node.
3. Wait until all secondaries are ready and pre-warmed.
4. Promote the candidate's secondary to attached.
5. Remove the secondary from the node being deleted.
This process safely moves all attached shards before deleting the node.
#### Secondary tenant shard processing
When deleting a node, handle each secondary tenant shard as follows:
1. Choose the best node to become the new secondary.
2. Create a secondary for the shard on that node.
3. Wait until the new secondary is ready.
4. Remove the secondary from the node being deleted.
This ensures all secondary shards are safely moved before deleting the node.
### Reliability, failure modes and corner cases
In case of a storage controller failure and following restart, the system behavior depends on the
`NodeLifecycle` state:
- If `NodeLifecycle` is `Active`: No action is taken for this node.
- If `NodeLifecycle` is `Deleted`: The node will not be re-added.
- If `NodeLifecycle` is `ScheduledForDeletion`: A deletion background task will be launched for
this node.
In case of a pageserver node failure during deletion, the behavior depends on the `force` flag:
- If `force` is set: The node deletion will proceed regardless of the node's availability.
- If `force` is not set: The deletion will be retried a limited number of times. If the node
remains unavailable, the deletion process will pause and automatically resume when the node
becomes healthy again.
### Operations concurrency
The following sections describe the behavior when different types of requests arrive at the storage
controller and how they interact with ongoing operations.
#### Delete request
Handler: `PUT /control/v1/node/:node_id/delete`
1. If node lifecycle is `NodeLifecycle::ScheduledForDeletion`:
- Return `200 OK`: there is already an ongoing deletion request for this node
2. Update & persist lifecycle to `NodeLifecycle::ScheduledForDeletion`
3. Persist current scheduling policy
4. If there is no active operation (drain/fill/delete):
- Run deletion process for this node
#### Cancel delete request
Handler: `DELETE /control/v1/node/:node_id/delete`
1. If node lifecycle is not `NodeLifecycle::ScheduledForDeletion`:
- Return `404 Not Found`: there is no current deletion request for this node
2. If the active operation is deleting this node, cancel it
3. Update & persist lifecycle to `NodeLifecycle::Active`
4. Restore the last scheduling policy from persistence
#### Drain/fill request
1. If there are already ongoing drain/fill processes:
- Return `409 Conflict`: queueing of drain/fill processes is not supported
2. If there is an ongoing delete process:
- Cancel it and wait until it is cancelled
3. Run the drain/fill process
4. After the drain/fill process is cancelled or finished:
- Try to find another candidate to delete and run the deletion process for that node
#### Drain/fill cancel request
1. If the active operation is not the related process:
- Return `400 Bad Request`: cancellation request is incorrect, operations are not the same
2. Cancel the active operation
3. Try to find another candidate to delete and run the deletion process for that node
## Definition of Done
- [x] Fix flaky node scenario and introduce related debug handlers
- [ ] Node deletion intent is persistent - a node will be eventually deleted after a deletion
request regardless of draining/filling requests and restarts
- [ ] Node deletion can be graceful - deletion completes only after moving all tenant shards to
recommended locations
- [ ] Deploying does not break due to long deletions - drain/fill operations override deletion
process and deletion resumes after drain/fill completes
- [ ] `force` flag is implemented and provides fast, failure-tolerant node removal (e.g., when a
pageserver node does not respond)
- [ ] Legacy delete handler code is removed from storage_controller, test_runner, and storcon_cli

View File

@@ -20,6 +20,7 @@ tokio.workspace = true
tracing.workspace = true
utils = { path = "../libs/utils", default-features = false }
workspace_hack.workspace = true
postgres_backend.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true
http-body-util.workspace = true

View File

@@ -206,12 +206,16 @@ mod tests {
use axum::{body::Body, extract::Request, response::Response};
use http_body_util::BodyExt;
use itertools::iproduct;
use jsonwebtoken::DecodingKey;
use std::env::var;
use std::sync::Arc;
use std::time::Duration;
use test_log::test as testlog;
use tower::{Service, util::ServiceExt};
use utils::id::{TenantId, TimelineId};
use utils::{
auth::JwtAuth,
id::{TenantId, TimelineId},
};
// see libs/remote_storage/tests/test_real_s3.rs
const REAL_S3_ENV: &str = "ENABLE_REAL_S3_REMOTE_STORAGE";
@@ -251,7 +255,9 @@ mod tests {
};
let proxy = Storage {
auth: endpoint_storage::JwtAuth::new(TEST_PUB_KEY_ED25519).unwrap(),
auth: JwtAuth::new(vec![
DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap(),
]),
storage,
cancel: cancel.clone(),
max_upload_file_limit: usize::MAX,
@@ -352,7 +358,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(endpoint_storage::VALIDATION_ALGO);
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}
@@ -501,7 +507,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(endpoint_storage::VALIDATION_ALGO);
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}

View File

@@ -7,7 +7,6 @@ use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
use axum_extra::TypedHeader;
use axum_extra::headers::{Authorization, authorization::Bearer};
use camino::Utf8PathBuf;
use jsonwebtoken::{DecodingKey, Validation};
use remote_storage::{GenericRemoteStorage, RemotePath};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
@@ -15,28 +14,9 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use utils::auth::JwtAuth;
use utils::id::{EndpointId, TenantId, TimelineId};
// simplified version of utils::auth::JwtAuth
pub struct JwtAuth {
decoding_key: DecodingKey,
validation: Validation,
}
pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
impl JwtAuth {
pub fn new(key: &[u8]) -> Result<Self> {
Ok(Self {
decoding_key: DecodingKey::from_ed_pem(key)?,
validation: Validation::new(VALIDATION_ALGO),
})
}
pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
}
}
fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
let key = clean_utf8(&Utf8PathBuf::from(key));
if key.starts_with("..") || key == "." || key == "/" {
@@ -157,7 +137,8 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
let claims: EndpointStorageClaims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "decoding token"))?;
.map_err(|e| bad_request(e, "decoding token"))?
.claims;
// Read paths may have different endpoint ids. For readonly -> readwrite replica
// prewarming, endpoint must read other endpoint's data.
@@ -224,7 +205,8 @@ impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
let claims: DeletePrefixClaims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "invalid token"))?;
.map_err(|e| bad_request(e, "invalid token"))?
.claims;
let route = DeletePrefixClaims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,

View File

@@ -5,8 +5,10 @@
mod app;
use anyhow::Context;
use clap::Parser;
use postgres_backend::AuthType;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::auth::JwtAuth;
use utils::logging;
//see set()
@@ -18,6 +20,10 @@ const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
const fn default_auth_type() -> AuthType {
AuthType::NeonJWT
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
@@ -39,6 +45,8 @@ struct Config {
storage_kind: remote_storage::TypedRemoteStorageKind,
#[serde(default = "max_upload_file_limit")]
max_upload_file_limit: usize,
#[serde(default = "default_auth_type")]
auth_type: AuthType,
}
#[tokio::main]
@@ -61,10 +69,15 @@ async fn main() -> anyhow::Result<()> {
anyhow::bail!("Supply either config file path or --config=inline-config");
};
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
let auth = endpoint_storage::JwtAuth::new(&pemfile)?;
if config.auth_type == AuthType::Trust {
anyhow::bail!("Trust based auth is not supported");
}
let auth = match config.auth_type {
AuthType::NeonJWT => JwtAuth::from_key_path(&config.pemfile)?,
AuthType::HadronJWT => JwtAuth::from_cert_path(&config.pemfile)?,
AuthType::Trust => unreachable!(),
};
let listener = tokio::net::TcpListener::bind(config.listen).await.unwrap();
info!("listening on {}", listener.local_addr().unwrap());

View File

@@ -68,11 +68,15 @@ pub enum LfcPrewarmState {
/// We tried to fetch the corresponding LFC state from the endpoint storage,
/// but received `Not Found 404`. This should normally happen only during the
/// first endpoint start after creation with `autoprewarm: true`.
/// This may also happen if LFC is turned off or not initialized
///
/// During the orchestrated prewarm via API, when a caller explicitly
/// provides the LFC state key to prewarm from, it's the caller responsibility
/// to handle this status as an error state in this case.
Skipped,
/// LFC prewarm was cancelled. Some pages in LFC cache may be prewarmed if query
/// has started working before cancellation
Cancelled,
}
impl Display for LfcPrewarmState {
@@ -83,6 +87,7 @@ impl Display for LfcPrewarmState {
LfcPrewarmState::Completed => f.write_str("Completed"),
LfcPrewarmState::Skipped => f.write_str("Skipped"),
LfcPrewarmState::Failed { error } => write!(f, "Error({error})"),
LfcPrewarmState::Cancelled => f.write_str("Cancelled"),
}
}
}
@@ -97,6 +102,7 @@ pub enum LfcOffloadState {
Failed {
error: String,
},
Skipped,
}
#[derive(Serialize, Debug, Clone, PartialEq)]

View File

@@ -705,8 +705,10 @@ pub fn check_permission_with(
check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
) -> Result<(), ApiError> {
match req.context::<Claims>() {
Some(claims) => Ok(check_permission(&claims)
.map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
Some(claims) => Ok(check_permission(&claims).map_err(|err| {
tracing::info!("Authorization error: {err}");
ApiError::Forbidden("JWT authentication error".to_string())
})?),
None => Ok(()), // claims is None because auth is disabled
}
}

View File

@@ -194,6 +194,10 @@ pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
// Similar to above but uses Hadron JWT. Hadron JWTs are slightly different in that:
// 1. Decoding keys are loaded from PEM-encoded X509 certificates instead of plain key files.
// 2. Signature algorithm is RSA-based (may change in the future).
HadronJWT,
}
impl FromStr for AuthType {
@@ -203,6 +207,7 @@ impl FromStr for AuthType {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
"HadronJWT" => Ok(Self::HadronJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
@@ -213,6 +218,7 @@ impl fmt::Display for AuthType {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
AuthType::HadronJWT => "HadronJWT",
})
}
}
@@ -613,7 +619,10 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(self.auth_type == AuthType::NeonJWT);
assert!(matches!(
self.auth_type,
AuthType::NeonJWT | AuthType::HadronJWT
));
let (_, jwt_response) = m.split_last().context("protocol violation")?;
@@ -712,7 +721,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
AuthType::NeonJWT | AuthType::HadronJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;

View File

@@ -19,6 +19,7 @@ anyhow.workspace = true
bincode.workspace = true
bytes.workspace = true
camino.workspace = true
camino-tempfile.workspace = true
chrono.workspace = true
diatomic-waker.workspace = true
git-version.workspace = true
@@ -28,6 +29,7 @@ fail.workspace = true
futures = { workspace = true }
jsonwebtoken.workspace = true
nix = { workspace = true, features = ["ioctl"] }
oid-registry.workspace = true
once_cell.workspace = true
pem.workspace = true
pin-project-lite.workspace = true
@@ -48,9 +50,12 @@ tracing-utils.workspace = true
rand.workspace = true
scopeguard.workspace = true
uuid.workspace = true
rustls-pemfile.workspace = true
rustls-pki-types.workspace = true
strum.workspace = true
strum_macros.workspace = true
walkdir.workspace = true
x509-parser.workspace = true
pq_proto.workspace = true
postgres_connection.workspace = true
@@ -67,6 +72,7 @@ camino-tempfile.workspace = true
pprof.workspace = true
serde_assert.workspace = true
tokio = { workspace = true, features = ["test-util"] }
rcgen = { version = "=0.13.1", features = ["crypto", "aws_lc_rs"] }
[[bench]]
name = "benchmarks"

View File

@@ -1,9 +1,9 @@
// For details about authentication see docs/authentication.md
use std::borrow::Cow;
use std::fmt::Display;
use std::fs;
use std::sync::Arc;
use std::{borrow::Cow, io, path::Path};
use anyhow::Result;
use arc_swap::ArcSwap;
@@ -11,14 +11,17 @@ use camino::Utf8Path;
use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
};
use oid_registry::OID_PKCS1_RSAENCRYPTION;
use pem::Pem;
use rustls_pki_types::CertificateDer;
use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
use uuid::Uuid;
use crate::id::TenantId;
/// Algorithm to use. We require EdDSA.
/// Signature algorithms to use. We allow EdDSA and RSA/SHA-256.
const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
const HADRON_STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
#[serde(rename_all = "lowercase")]
@@ -95,6 +98,14 @@ impl Claims {
endpoint_id: None,
}
}
pub fn new_for_endpoint(endpoint_id: Uuid) -> Self {
Self {
tenant_id: None,
endpoint_id: Some(endpoint_id),
scope: Scope::TenantEndpoint,
}
}
}
pub struct SwappableJwtAuth(ArcSwap<JwtAuth>);
@@ -175,6 +186,96 @@ impl JwtAuth {
Ok(Self::new(decoding_keys))
}
// Helper function to parse a X509 certificate file and extract the RSA public keys from it as `DecodingKey`s.
// - `ceritificate_file_path`: the path to the certificate file. It must be a file, not a directory or anything else.
// Returns the successfully extracted decoding keys. Non-RSA keys and non-X509-parsable certificates are skipped.
// Multuple keys may be returned because a single file can contain multiple certificates.
fn extract_rsa_decoding_keys_from_certificate<P: AsRef<Path>>(
certificate_file_path: P,
) -> Result<Vec<DecodingKey>> {
let certs: io::Result<Vec<CertificateDer<'static>>> = rustls_pemfile::certs(
&mut io::BufReader::new(fs::File::open(certificate_file_path)?),
)
.collect();
Ok(certs?
.iter()
.filter_map(
|cert| match x509_parser::parse_x509_certificate(cert) {
Ok((_, cert)) => {
let public_key = cert.public_key();
// Note that we are just extracting the public key from the certificate, not the signature.
// So the algorithm is just the asymmetric crypto such as RSA, no hashes of or anything like
// that.
if *public_key.algorithm.oid() == OID_PKCS1_RSAENCRYPTION {
Some(DecodingKey::from_rsa_der(&public_key.subject_public_key.data))
} else {
tracing::warn!(
"Unsupported public key algorithm: {:?} found in certificate. Skipping.",
public_key.algorithm
);
None
}
}
Err(e) => {
tracing::warn!("Error parsing certificate: {}. Skipping.", e);
None
}
},
)
.collect())
}
/// Create a `JwtAuth` that can decode tokens using RSA public keys in X509 certificates from the given path.
/// - `cert_path`: the path to a directory or a file containing X509 certificates. If it is a directory, all files
/// under the first level of the directory will be inspected for certificates.
/// Returns the `JwtAuth` with the decoding keys extracted from the certificates, or error.
/// Used by Hadron.
pub fn from_cert_path(cert_path: &Utf8Path) -> Result<Self> {
tracing::info!(
"Loading public keys in certificates from path: {}",
cert_path
);
let mut decoding_keys = Vec::new();
let metadata = cert_path.metadata()?;
if metadata.is_dir() {
for entry in fs::read_dir(cert_path)? {
let path = entry?.path();
if !path.is_file() {
// Ignore directories (don't recurse)
continue;
}
decoding_keys.extend(
Self::extract_rsa_decoding_keys_from_certificate(path).unwrap_or_default(),
);
}
} else if metadata.is_file() {
decoding_keys.extend(
Self::extract_rsa_decoding_keys_from_certificate(cert_path).unwrap_or_default(),
);
} else {
anyhow::bail!("{cert_path} is neither a directory or a file")
}
if decoding_keys.is_empty() {
anyhow::bail!(
"Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected."
);
}
// Note that we need to create a `JwtAuth` with a different `validation` from the default one created by `new()` in this case
// because the `jsonwebtoken` crate requires that all algorithms in `validation.algorithms` belong to the same algorithm family
// (all RSA or all EdDSA).
let mut validation = Validation::default();
validation.algorithms = vec![HADRON_STORAGE_TOKEN_ALGORITHM];
validation.required_spec_claims = [].into();
Ok(Self {
validation,
decoding_keys,
})
}
pub fn from_key(key: String) -> Result<Self> {
Ok(Self::new(vec![DecodingKey::from_ed_pem(key.as_bytes())?]))
}
@@ -217,8 +318,28 @@ pub fn encode_from_key_file<S: Serialize>(claims: &S, pem: &Pem) -> Result<Strin
Ok(encode(&Header::new(STORAGE_TOKEN_ALGORITHM), claims, &key)?)
}
/// Encode (i.e., sign) a Hadron auth token with the given claims and RSA private key. This is used
/// by HCC to sign tokens when deploying compute or returning the compute spec. The resulting token
/// is used by the compute node to authenticate with HCC and PS/SK.
pub fn encode_hadron_token<S: Serialize>(claims: &S, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_rsa_pem(key_data)?;
encode_hadron_token_with_encoding_key(claims, &key)
}
pub fn encode_hadron_token_with_encoding_key<S: Serialize>(
claims: &S,
encoding_key: &EncodingKey,
) -> Result<String> {
Ok(encode(
&Header::new(HADRON_STORAGE_TOKEN_ALGORITHM),
claims,
encoding_key,
)?)
}
#[cfg(test)]
mod tests {
use io::Write;
use std::str::FromStr;
use super::*;
@@ -243,8 +364,8 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
fn test_decode() {
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
scope: Scope::Tenant,
};
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
@@ -272,8 +393,8 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
fn test_encode() {
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
scope: Scope::Tenant,
};
let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
@@ -287,4 +408,72 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
assert_eq!(decoded.claims, claims);
}
#[test]
fn test_decode_with_key_from_certificate() {
// Tests that we can sign (encode) a token with a RSA private key and verify (decode) it with the
// corresponding public key extracted from a certificate.
// Generate two RSA key pairs and create self-signed certificates with it.
let key_pair_1 = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap();
let key_pair_2 = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap();
let mut params = rcgen::CertificateParams::default();
params
.distinguished_name
.push(rcgen::DnType::CommonName, "eng-brickstore@databricks.com");
let cert_1 = params.clone().self_signed(&key_pair_1).unwrap();
let cert_2 = params.self_signed(&key_pair_2).unwrap();
// Write the certificates and keys to a temporary dir.
let dir = camino_tempfile::tempdir().unwrap();
{
fs::File::create(dir.path().join("cert_1.pem"))
.unwrap()
.write_all(cert_1.pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("key_1.pem"))
.unwrap()
.write_all(key_pair_1.serialize_pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("cert_2.pem"))
.unwrap()
.write_all(cert_2.pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("key_2.pem"))
.unwrap()
.write_all(key_pair_2.serialize_pem().as_bytes())
.unwrap();
}
// Instantiate a `JwtAuth` with the certificate path. The resulting `JwtAuth` should extract the RSA public
// keys out of the X509 certificates and use them as the decoding keys. Since we specified a directory, both
// X509 certificates will be loaded, but the private key files are skipped.
let auth = JwtAuth::from_cert_path(dir.path()).unwrap();
assert_eq!(auth.decoding_keys.len(), 2);
// Also create a `JwtAuth`, specifying a single certificate file for it to get the decoding key from.
let auth_cert_1 = JwtAuth::from_cert_path(&dir.path().join("cert_1.pem")).unwrap();
assert_eq!(auth_cert_1.decoding_keys.len(), 1);
// Encode tokens with some claims.
let claims = Claims {
tenant_id: Some(TenantId::generate()),
endpoint_id: None,
scope: Scope::Tenant,
};
let encoded_1 =
encode_hadron_token(&claims, key_pair_1.serialize_pem().as_bytes()).unwrap();
let encoded_2 =
encode_hadron_token(&claims, key_pair_2.serialize_pem().as_bytes()).unwrap();
// Verify that we can decode the token with matching decoding keys (decoding also verifies the signature).
assert_eq!(auth.decode::<Claims>(&encoded_1).unwrap().claims, claims);
assert_eq!(auth.decode::<Claims>(&encoded_2).unwrap().claims, claims);
assert_eq!(
auth_cert_1.decode::<Claims>(&encoded_1).unwrap().claims,
claims
);
// Verify that the token cannot be decoded with a mismatched decode key.
assert!(auth_cert_1.decode::<Claims>(&encoded_2).is_err());
}
}

View File

@@ -458,25 +458,37 @@ fn start_pageserver(
let http_auth;
let pg_auth;
let grpc_auth;
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type].contains(&AuthType::NeonJWT) {
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.iter()
.any(|auth_type| *auth_type == AuthType::NeonJWT || *auth_type == AuthType::HadronJWT)
{
// unwrap is ok because check is performed when creating config, so path is set and exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
info!("Loading public key(s) for verifying JWT tokens from {key_path:?}");
let jwt_auth = JwtAuth::from_key_path(key_path)?;
let use_hadron_jwt = conf.http_auth_type == AuthType::HadronJWT
|| conf.pg_auth_type == AuthType::HadronJWT
|| conf.grpc_auth_type == AuthType::HadronJWT;
let jwt_auth = if use_hadron_jwt {
// To validate Hadron JWTs we need to extract decoding keys from X509 certificates.
JwtAuth::from_cert_path(key_path)?
} else {
JwtAuth::from_key_path(key_path)?
};
let auth: Arc<SwappableJwtAuth> = Arc::new(SwappableJwtAuth::new(jwt_auth));
http_auth = match conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth.clone()),
};
pg_auth = match conf.pg_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth.clone()),
};
grpc_auth = match conf.grpc_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth),
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth),
};
} else {
http_auth = None;

View File

@@ -629,6 +629,13 @@ impl PageServerConf {
}
};
let auth_types = [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type];
if auth_types.contains(&AuthType::NeonJWT) && auth_types.contains(&AuthType::HadronJWT) {
return Err(anyhow::anyhow!(
"Mixing neon and hadron style JWT tokens is not supported"
));
}
Ok(conf)
}

View File

@@ -44,6 +44,7 @@ use pageserver_api::models::{
TopTenantShardItem, TopTenantShardsRequest, TopTenantShardsResponse,
};
use pageserver_api::shard::{ShardCount, TenantShardId};
use postgres_backend::AuthType;
use postgres_ffi::PgMajorVersion;
use remote_storage::{DownloadError, GenericRemoteStorage, TimeTravelError};
use scopeguard::defer;
@@ -55,6 +56,7 @@ use tokio::time::Instant;
use tokio_util::io::StreamReader;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::auth::JwtAuth;
use utils::auth::SwappableJwtAuth;
use utils::generation::Generation;
use utils::id::{TenantId, TimelineId};
@@ -560,6 +562,10 @@ async fn reload_auth_validation_keys_handler(
request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
// Note to Bricksters: This API returns 400 if HTTP auth is not enabled. This is because `state.auth` is only
// determined by HTTP auth.
// TODO(william.huang): In practice both HTTP and PG auth point to the same SwappableJwtAuth object. Refactor
// this code so that we can swap out the underlying shared auth object even if HTTP auth is None.
check_permission(&request, None)?;
let config = get_config(&request);
let state = get_state(&request);
@@ -570,7 +576,12 @@ async fn reload_auth_validation_keys_handler(
let key_path = config.auth_validation_public_key_path.as_ref().unwrap();
info!("Reloading public key(s) for verifying JWT tokens from {key_path:?}");
match utils::auth::JwtAuth::from_key_path(key_path) {
let new_jwt_auth = if config.http_auth_type == AuthType::HadronJWT {
JwtAuth::from_cert_path(key_path)
} else {
JwtAuth::from_key_path(key_path)
};
match new_jwt_auth {
Ok(new_auth) => {
shared_auth.swap(new_auth);
json_response(StatusCode::OK, ())

View File

@@ -458,7 +458,7 @@ pub(crate) enum LocalProxyConnError {
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectError(_) => ErrorKind::Compute,
HttpConnError::ConnectError(e) => e.get_error_kind(),
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => match p.as_db_error() {
// user provided a wrong database name

View File

@@ -15,6 +15,7 @@ use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use http_utils::tls_certs::ReloadingCertificateResolver;
use metrics::set_build_info_metric;
use postgres_backend::AuthType;
use remote_storage::RemoteStorageConfig;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT,
@@ -109,10 +110,15 @@ struct Args {
/// Listen https endpoint for management and metrics in the form host:port.
#[arg(long, default_value = None)]
listen_https: Option<String>,
/// Advertised endpoint for receiving/sending WAL in the form host:port. If not
/// Advertised endpoint to PS for receiving/sending WAL in the form host:port. If not
/// specified, listen_pg is used to advertise instead.
#[arg(long, default_value = None)]
advertise_pg: Option<String>,
/// Advertised endpoint to compute for receiving/sending WAL in the form host:port.
/// Required if --hcc-base-url is specified.
// TODO(vlad): pull in hcc-base-url too
#[arg(long, default_value = None)]
advertise_pg_tenant_only: Option<String>,
/// Availability zone of the safekeeper.
#[arg(long)]
availability_zone: Option<String>,
@@ -164,6 +170,12 @@ struct Args {
/// WAL backup horizon.
#[arg(long)]
disable_wal_backup: bool,
/// Token authentication type. Allowed values are "NeonJWT" and "HadronJWT". Any specified value only takes effect if
/// --pg-auth-public-key-path, --pg-tenant-only-auth-public-key-path, or --http-auth-public-key-path is specified.
/// NeonJWT: Decoding keys are loaded from plain public key files in the specified key path.
/// HadronJWT: Decoding keys are loaded from X509 certificates in the specified key path.
#[arg(long, verbatim_doc_comment, default_value = "NeonJWT")]
token_auth_type: AuthType,
/// If given, enables auth on incoming connections to WAL service endpoint
/// (--listen-pg). Value specifies path to a .pem public key used for
/// validations of JWT tokens. Empty string is allowed and means disabling
@@ -361,9 +373,19 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading pg auth JWT key from {path}");
Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
))
match args.token_auth_type {
AuthType::NeonJWT => Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
)),
AuthType::HadronJWT => Some(Arc::new(
JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
)),
_ => panic!(
"AuthType {auth_type} is not allowed when --pg-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
}
}
};
let pg_tenant_only_auth = match args.pg_tenant_only_auth_public_key_path.as_ref() {
@@ -373,9 +395,19 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading pg tenant only auth JWT key from {path}");
Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
))
match args.token_auth_type {
AuthType::NeonJWT => Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
)),
AuthType::HadronJWT => Some(Arc::new(
JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
)),
_ => panic!(
"AuthType {auth_type} is not allowed when --pg-tenant-only-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
}
}
};
let http_auth = match args.http_auth_public_key_path.as_ref() {
@@ -385,7 +417,17 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading http auth JWT key(s) from {path}");
let jwt_auth = JwtAuth::from_key_path(path).context("failed to load the auth key")?;
let jwt_auth = match args.token_auth_type {
AuthType::NeonJWT => {
JwtAuth::from_key_path(path).context("failed to load the auth key")?
}
AuthType::HadronJWT => JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
_ => panic!(
"AuthType {auth_type} is not allowed when --http-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
};
Some(Arc::new(SwappableJwtAuth::new(jwt_auth)))
}
};
@@ -434,6 +476,7 @@ async fn main() -> anyhow::Result<()> {
/* END_HADRON */
wal_backup_enabled: !args.disable_wal_backup,
backup_parallel_jobs: args.wal_backup_parallel_jobs,
auth_type: args.token_auth_type,
pg_auth,
pg_tenant_only_auth,
http_auth,
@@ -457,7 +500,7 @@ async fn main() -> anyhow::Result<()> {
enable_tls_wal_service_api: args.enable_tls_wal_service_api,
force_metric_collection_on_scrape: args.force_metric_collection_on_scrape,
/* BEGIN_HADRON */
advertise_pg_addr_tenant_only: None,
advertise_pg_addr_tenant_only: args.advertise_pg_tenant_only,
enable_pull_timeline_on_startup: args.enable_pull_timeline_on_startup,
hcc_base_url: None,
global_disk_check_interval: args.global_disk_check_interval,

View File

@@ -1,6 +1,7 @@
#![deny(clippy::undocumented_unsafe_blocks)]
extern crate hyper0 as hyper;
use postgres_backend::AuthType;
use std::time::Duration;
@@ -128,6 +129,7 @@ pub struct SafeKeeperConf {
/* END_HADRON */
pub backup_parallel_jobs: usize,
pub wal_backup_enabled: bool,
pub auth_type: AuthType,
pub pg_auth: Option<Arc<JwtAuth>>,
pub pg_tenant_only_auth: Option<Arc<JwtAuth>>,
pub http_auth: Option<Arc<SwappableJwtAuth>>,
@@ -173,6 +175,7 @@ impl SafeKeeperConf {
peer_recovery_enabled: true,
wal_backup_enabled: true,
backup_parallel_jobs: 1,
auth_type: AuthType::HadronJWT,
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,

View File

@@ -612,19 +612,25 @@ pub async fn handle_request(
}
}
let max_term = statuses
.iter()
.map(|(status, _)| status.acceptor_state.term)
.max()
.unwrap();
// Find the most advanced safekeeper
let (status, i) = statuses
.into_iter()
.max_by_key(|(status, _)| {
(
status.acceptor_state.epoch,
status.flush_lsn,
/* BEGIN_HADRON */
// We need to pull from the SK with the highest term.
// This is because another compute may come online and vote the same highest term again on the other two SKs.
// Then, there will be 2 computes running on the same term.
status.acceptor_state.term,
/* END_HADRON */
status.flush_lsn,
status.commit_lsn,
)
})
@@ -634,6 +640,22 @@ pub async fn handle_request(
assert!(status.tenant_id == request.tenant_id);
assert!(status.timeline_id == request.timeline_id);
// TODO(diko): This is hadron only check to make sure that we pull the timeline
// from the safekeeper with the highest term during timeline restore.
// We could avoid returning the error by calling bump_term after pull_timeline.
// However, this is not a big deal because we retry the pull_timeline requests.
// The check should be removed together with removing custom hadron logic for
// safekeeper restore.
if wait_for_peer_timeline_status && status.acceptor_state.term != max_term {
return Err(ApiError::PreconditionFailed(
format!(
"choosen safekeeper {} has term {}, but the most advanced term is {}",
safekeeper_host, status.acceptor_state.term, max_term
)
.into(),
));
}
match pull_timeline(
status,
safekeeper_host,

View File

@@ -195,12 +195,14 @@ impl StateSK {
to: Configuration,
) -> Result<TimelineMembershipSwitchResponse> {
let result = self.state_mut().membership_switch(to).await?;
let flush_lsn = self.flush_lsn();
let last_log_term = self.state().acceptor_state.get_last_log_term(flush_lsn);
Ok(TimelineMembershipSwitchResponse {
previous_conf: result.previous_conf,
current_conf: result.current_conf,
last_log_term: self.state().acceptor_state.term,
flush_lsn: self.flush_lsn(),
last_log_term,
flush_lsn,
})
}

View File

@@ -103,7 +103,7 @@ async fn handle_socket(
};
let auth_type = match auth_key {
None => AuthType::Trust,
Some(_) => AuthType::NeonJWT,
Some(_) => conf.auth_type,
};
let auth_pair = auth_key.map(|key| (allowed_auth_scope, key));
let mut conn_handler = SafekeeperPostgresHandler::new(

View File

@@ -14,6 +14,7 @@ use desim::network::TCP;
use desim::node_os::NodeOs;
use desim::proto::{AnyMessage, NetEvent, NodeEvent};
use http::Uri;
use postgres_backend::AuthType;
use safekeeper::SafeKeeperConf;
use safekeeper::safekeeper::{
ProposerAcceptorMessage, SK_PROTO_VERSION_3, SafeKeeper, UNKNOWN_SERVER_VERSION,
@@ -169,6 +170,7 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
availability_zone: None,
peer_recovery_enabled: false,
backup_parallel_jobs: 0,
auth_type: AuthType::NeonJWT,
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,

View File

@@ -31,6 +31,7 @@ humantime.workspace = true
humantime-serde.workspace = true
itertools.workspace = true
json-structural-diff.workspace = true
jsonwebtoken.workspace = true
lasso.workspace = true
once_cell.workspace = true
pageserver_api.workspace = true
@@ -74,4 +75,4 @@ http-utils = { path = "../libs/http-utils/" }
utils = { path = "../libs/utils/" }
metrics = { path = "../libs/metrics/" }
control_plane = { path = "../control_plane" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -9,7 +9,6 @@ pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), Au
Ok(())
}
#[allow(dead_code)]
pub fn check_endpoint_permission(claims: &Claims, endpoint_id: Uuid) -> Result<(), AuthError> {
if claims.scope != Scope::TenantEndpoint {
return Err(AuthError("Scope mismatch. Permission denied".into()));

View File

@@ -0,0 +1,52 @@
use anyhow::{Result, bail};
use camino::Utf8Path;
use jsonwebtoken::EncodingKey;
use std::fs;
use utils::{
auth::{Claims, Scope, encode_hadron_token_with_encoding_key},
id::TenantId,
};
use uuid::Uuid;
pub struct HadronTokenGenerator {
encoding_key: EncodingKey,
}
impl HadronTokenGenerator {
pub fn new(path: &Utf8Path) -> anyhow::Result<Self> {
let key_data = match fs::read(path) {
Ok(ok) => ok,
Err(e) => bail!("Error reading private key file {path:?}. Error: {e}"),
};
let encoding_key = match EncodingKey::from_rsa_pem(&key_data) {
Ok(ok) => ok,
Err(e) => {
bail!("Error reading private key file {path:?} as RSA private key. Error: {e}")
}
};
Ok(Self { encoding_key })
}
pub fn generate_tenant_scope_token(&self, tenant_id: TenantId) -> Result<String> {
let claims = Claims::new(Some(tenant_id), Scope::Tenant);
self.internal_encode_token(&claims)
}
pub fn generate_tenant_endpoint_scope_token(&self, endpoint_id: Uuid) -> Result<String> {
let claims = Claims::new_for_endpoint(endpoint_id);
self.internal_encode_token(&claims)
}
pub fn generate_ps_sk_auth_token(&self) -> Result<String> {
let claims = Claims {
tenant_id: None,
endpoint_id: None,
scope: Scope::SafekeeperData,
};
self.internal_encode_token(&claims)
}
fn internal_encode_token(&self, claims: &Claims) -> Result<String> {
encode_hadron_token_with_encoding_key(claims, &self.encoding_key)
}
}

View File

@@ -40,6 +40,7 @@ use tokio_util::sync::CancellationToken;
use tracing::warn;
use utils::auth::{Scope, SwappableJwtAuth};
use utils::id::{NodeId, TenantId, TimelineId};
use uuid::Uuid;
use crate::http;
use crate::metrics::{
@@ -1801,6 +1802,23 @@ fn check_permissions(request: &Request<Body>, required_scope: Scope) -> Result<(
}
})
}
/// Similar to `check_permissions()` above, but checks for TenantEndpoint scope specifically. Used by the compute spec-fetch API.
/// Access by Admin-scope tokens is also permitted.
/// TODO(william.huang): Merge with the previous function by refactoring `Scope` to make it carry the dependent arguments.
/// E.g., `Scope::TenantEndpoint(EndpointId)`, `Scope::Tenant(TenantId)`, etc.
#[allow(unused)]
fn check_endpoint_permission(request: &Request<Body>, endpoint_id: Uuid) -> Result<(), ApiError> {
check_permission_with(
request,
|claims| match crate::auth::check_endpoint_permission(claims, endpoint_id) {
Err(e) => match crate::auth::check_permission(claims, Scope::Admin) {
Ok(()) => Ok(()),
Err(_) => Err(e),
},
Ok(()) => Ok(()),
},
)
}
#[derive(Clone, Debug)]
struct RequestMeta {

View File

@@ -6,6 +6,7 @@ extern crate hyper0 as hyper;
mod auth;
mod background_node_operations;
mod compute_hook;
pub mod hadron_token;
pub mod hadron_utils;
mod heartbeater;
pub mod http;

View File

@@ -14,6 +14,7 @@ use metrics::BuildInfo;
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::config::PostHogConfig;
use reqwest::Certificate;
use storage_controller::hadron_token::HadronTokenGenerator;
use storage_controller::http::make_router;
use storage_controller::metrics::preinitialize_metrics;
use storage_controller::persistence::Persistence;
@@ -70,10 +71,26 @@ struct Cli {
#[arg(long)]
listen_https: Option<std::net::SocketAddr>,
/// Public key for JWT authentication of clients
/// PEM-encoded public key string for JWT authentication of clients.
#[arg(long)]
public_key: Option<String>,
/// Path to public key certificates used for JWT authentiation of clients.
/// Only one of `public_key` and `public_key_cert_path` should be set.
/// `public_key` or `public_key_cert_path` can point to either a file or a directory.
/// When pointed to a directory, public keys in all files in the first level of
/// the directory (i.e., no subdirectories) will be loaded.
#[arg(long)]
public_key_cert_path: Option<Utf8PathBuf>,
/// Path to the file containing the private key used to generate JWTs for client
/// authentication. The file should contain a single PEM-encoded private key.
/// The HCC uses this key to sign JWTs handed out to other components.
/// Note that unlike the `public_key` and `public_key_cert_path` args above,
/// `private_key_path` must specify a file path, not a directory.
#[arg(long)]
private_key_path: Option<Utf8PathBuf>,
/// Token for authenticating this service with the pageservers it controls
#[arg(long)]
jwt_token: Option<String>,
@@ -256,6 +273,7 @@ struct Secrets {
safekeeper_jwt_token: Option<String>,
control_plane_jwt_token: Option<String>,
peer_jwt_token: Option<String>,
token_generator: Option<HadronTokenGenerator>,
}
const POSTHOG_CONFIG_ENV: &str = "POSTHOG_CONFIG";
@@ -281,7 +299,16 @@ impl Secrets {
let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV) {
Some(v) => Some(JwtAuth::from_key(v).context("Loading public key")?),
None => None,
None => {
if let Some(path) = args.public_key_cert_path.as_ref() {
Some(
JwtAuth::from_cert_path(path)
.context("Loading public key from certificates")?,
)
} else {
None
}
}
};
let this = Self {
@@ -300,6 +327,11 @@ impl Secrets {
Self::CONTROL_PLANE_JWT_TOKEN_ENV,
),
peer_jwt_token: Self::load_secret(&args.peer_jwt_token, Self::PEER_JWT_TOKEN_ENV),
token_generator: args
.private_key_path
.as_ref()
.map(|path| HadronTokenGenerator::new(path))
.transpose()?,
};
Ok(this)
@@ -489,12 +521,12 @@ async fn async_main() -> anyhow::Result<()> {
let persistence = Arc::new(Persistence::new(secrets.database_url).await);
let service = Service::spawn(config, persistence.clone()).await?;
let service = Service::spawn(config, persistence.clone(), secrets.token_generator).await?;
let auth = secrets
let jwt_auth = secrets
.public_key
.map(|jwt_auth| Arc::new(SwappableJwtAuth::new(jwt_auth)));
let router = make_router(service.clone(), auth, build_info)
let router = make_router(service.clone(), jwt_auth, build_info)
.build()
.map_err(|err| anyhow!(err))?;
let http_service =

View File

@@ -4,6 +4,7 @@ pub(crate) mod safekeeper_reconciler;
mod safekeeper_service;
mod tenant_shard_iterator;
use crate::hadron_token::HadronTokenGenerator;
use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap, HashSet};
@@ -518,6 +519,11 @@ pub struct Service {
inner: Arc<std::sync::RwLock<ServiceState>>,
config: Config,
persistence: Arc<Persistence>,
// HadronTokenGenerator to generate (sign) JWTs during compute deployment and compute-spec generation.
#[allow(unused)]
token_generator: Option<HadronTokenGenerator>,
compute_hook: Arc<ComputeHook>,
result_tx: tokio::sync::mpsc::UnboundedSender<ReconcileResultRequest>,
@@ -1668,7 +1674,11 @@ impl Service {
}
}
pub async fn spawn(config: Config, persistence: Arc<Persistence>) -> anyhow::Result<Arc<Self>> {
pub async fn spawn(
config: Config,
persistence: Arc<Persistence>,
token_generator: Option<HadronTokenGenerator>,
) -> anyhow::Result<Arc<Self>> {
let (result_tx, result_rx) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, abort_rx) = tokio::sync::mpsc::unbounded_channel();
@@ -1925,6 +1935,7 @@ impl Service {
))),
config: config.clone(),
persistence,
token_generator,
compute_hook: Arc::new(ComputeHook::new(config.clone())?),
result_tx,
heartbeater_ps,

View File

@@ -24,12 +24,12 @@ use pageserver_api::controller_api::{
};
use pageserver_api::models::{SafekeeperInfo, SafekeepersInfo, TimelineInfo};
use safekeeper_api::PgVersionId;
use safekeeper_api::Term;
use safekeeper_api::membership::{self, MemberSet, SafekeeperGeneration};
use safekeeper_api::models::{
PullTimelineRequest, TimelineLocateResponse, TimelineMembershipSwitchRequest,
TimelineMembershipSwitchResponse,
};
use safekeeper_api::{INITIAL_TERM, Term};
use safekeeper_client::mgmt_api;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -1298,13 +1298,7 @@ impl Service {
)
.await?;
let mut sync_position = (INITIAL_TERM, Lsn::INVALID);
for res in results.into_iter().flatten() {
let sk_position = (res.last_log_term, res.flush_lsn);
if sync_position < sk_position {
sync_position = sk_position;
}
}
let sync_position = Self::get_sync_position(&results)?;
tracing::info!(
%generation,
@@ -1598,4 +1592,36 @@ impl Service {
Ok(())
}
/// Get membership switch responses from all safekeepers and return the sync position.
///
/// Sync position is a position equal or greater than the commit position.
/// It is guaranteed that all WAL entries with (last_log_term, flush_lsn)
/// greater than the sync position are not committed (= not on a quorum).
///
/// Returns error if there is no quorum of successful responses.
fn get_sync_position(
responses: &[mgmt_api::Result<TimelineMembershipSwitchResponse>],
) -> Result<(Term, Lsn), ApiError> {
let quorum_size = responses.len() / 2 + 1;
let mut wal_positions = responses
.iter()
.flatten()
.map(|res| (res.last_log_term, res.flush_lsn))
.collect::<Vec<_>>();
// Should be already checked if the responses are from tenant_timeline_set_membership_quorum.
if wal_positions.len() < quorum_size {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"not enough successful responses to get sync position: {}/{}",
wal_positions.len(),
quorum_size,
)));
}
wal_positions.sort();
Ok(wal_positions[quorum_size - 1])
}
}

View File

@@ -13,10 +13,11 @@ if TYPE_CHECKING:
@dataclass
class AuthKeys:
priv: str
algorithm: str
def generate_token(self, *, scope: TokenScope, **token_data: Any) -> str:
token_data = {key: str(val) for key, val in token_data.items()}
token = jwt.encode({"scope": scope, **token_data}, self.priv, algorithm="EdDSA")
token = jwt.encode({"scope": scope, **token_data}, self.priv, algorithm=self.algorithm)
# cast(Any, self.priv)
# jwt.encode can return 'bytes' or 'str', depending on Python version or type
@@ -46,3 +47,4 @@ class TokenScope(StrEnum):
TENANT = "tenant"
SCRUBBER = "scrubber"
INFRA = "infra"
TENANT_ENDPOINT = "tenantendpoint"

View File

@@ -78,20 +78,26 @@ class EndpointHttpClient(requests.Session):
json: dict[str, str] = res.json()
return json
def prewarm_lfc(self, from_endpoint_id: str | None = None):
def prewarm_lfc(self, from_endpoint_id: str | None = None) -> dict[str, str]:
"""
Prewarm LFC cache from given endpoint and wait till it finishes or errors
"""
params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict()
self.post(self.prewarm_url, params=params).raise_for_status()
self.prewarm_lfc_wait()
return self.prewarm_lfc_wait()
def prewarm_lfc_wait(self):
def cancel_prewarm_lfc(self):
"""
Cancel LFC prewarm if any is ongoing
"""
self.delete(self.prewarm_url).raise_for_status()
def prewarm_lfc_wait(self) -> dict[str, str]:
"""
Wait till LFC prewarm returns with error or success.
If prewarm was not requested before calling this function, it will error
"""
statuses = "failed", "completed", "skipped"
statuses = "failed", "completed", "skipped", "cancelled"
def prewarmed():
json = self.prewarm_lfc_status()
@@ -101,6 +107,7 @@ class EndpointHttpClient(requests.Session):
wait_until(prewarmed, timeout=60)
res = self.prewarm_lfc_status()
assert res["status"] != "failed", res
return res
def offload_lfc_status(self) -> dict[str, str]:
res = self.get(self.offload_url)
@@ -108,29 +115,31 @@ class EndpointHttpClient(requests.Session):
json: dict[str, str] = res.json()
return json
def offload_lfc(self):
def offload_lfc(self) -> dict[str, str]:
"""
Offload LFC cache to endpoint storage and wait till offload finishes or errors
"""
self.post(self.offload_url).raise_for_status()
self.offload_lfc_wait()
return self.offload_lfc_wait()
def offload_lfc_wait(self):
def offload_lfc_wait(self) -> dict[str, str]:
"""
Wait till LFC offload returns with error or success.
If offload was not requested before calling this function, it will error
"""
statuses = "failed", "completed", "skipped"
def offloaded():
json = self.offload_lfc_status()
status, err = json["status"], json.get("error")
assert status in ["failed", "completed"], f"{status}, {err=}"
assert status in statuses, f"{status}, {err=}"
wait_until(offloaded, timeout=60)
res = self.offload_lfc_status()
assert res["status"] != "failed", res
return res
def promote(self, promote_spec: dict[str, Any], disconnect: bool = False):
def promote(self, promote_spec: dict[str, Any], disconnect: bool = False) -> dict[str, str]:
url = f"http://localhost:{self.external_port}/promote"
if disconnect:
try: # send first request to start promote and disconnect

View File

@@ -28,11 +28,15 @@ import asyncpg
import backoff
import boto3
import httpx
import jwt
import psycopg2
import psycopg2.sql
import pytest
import requests
import toml
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from jwcrypto import jwk
# Type-related stuff
@@ -262,7 +266,6 @@ class PgProtocol:
# pooler does not support statement_timeout
# Check if the hostname contains the string 'pooler'
hostname = result.get("host", "")
log.info(f"Hostname: {hostname}")
options = result.get("options", "")
if "statement_timeout" not in options and "pooler" not in hostname:
options = f"-cstatement_timeout=120s {options}"
@@ -403,6 +406,15 @@ class PageserverImportConfig:
return ("timeline_import_config", value)
@dataclass
class HadronTokenDecoder:
public_key: str
algorithm: str
def decode_token(self, token: str) -> dict[str, Any]:
return jwt.decode(token, self.public_key, algorithms=[self.algorithm])
class NeonEnvBuilder:
"""
Builder object to create a Neon runtime environment
@@ -473,6 +485,7 @@ class NeonEnvBuilder:
self.safekeepers_id_start = safekeepers_id_start
self.safekeepers_enable_fsync = safekeepers_enable_fsync
self.auth_enabled = auth_enabled
self.use_hadron_auth_tokens = False
self.default_branch_name = default_branch_name
self.env: NeonEnv | None = None
self.keep_remote_storage_contents: bool = True
@@ -1122,6 +1135,11 @@ class NeonEnv:
self.repo_dir.joinpath("rootCA.crt") if self.generate_local_ssl_certs else None
)
# The auth token type used in the test environment. neon_local is instruted to generate key pairs
# according to the auth token type. The keys are always generated but are only used if
# config.auth_enabled == True.
self.auth_token_type: str = "HadronJWT" if config.use_hadron_auth_tokens else "NeonJWT"
neon_local_env_vars = {}
if self.rust_log_override is not None:
neon_local_env_vars["RUST_LOG"] = self.rust_log_override
@@ -1199,6 +1217,7 @@ class NeonEnv:
"listen_addr": f"127.0.0.1:{self.port_distributor.get_port()}",
},
"generate_local_ssl_certs": self.generate_local_ssl_certs,
"auth_token_type": self.auth_token_type,
}
if config.use_https_storage_broker_api:
@@ -1246,9 +1265,9 @@ class NeonEnv:
)
# Create config for pageserver
http_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
pg_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
grpc_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
http_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
pg_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
grpc_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
for ps_id in range(
self.BASE_PAGESERVER_ID, self.BASE_PAGESERVER_ID + config.num_pageservers
):
@@ -1386,9 +1405,8 @@ class NeonEnv:
"https_port": port.https,
"sync": config.safekeepers_enable_fsync,
"use_https_safekeeper_api": config.use_https_safekeeper_api,
"auth_type": self.auth_token_type if config.auth_enabled else "Trust",
}
if config.auth_enabled:
sk_cfg["auth_enabled"] = True
if self.safekeepers_remote_storage is not None:
sk_cfg["remote_storage"] = (
self.safekeepers_remote_storage.to_toml_inline_table().strip()
@@ -1579,29 +1597,66 @@ class NeonEnv:
@cached_property
def auth_keys(self) -> AuthKeys:
priv = (Path(self.repo_dir) / "auth_private_key.pem").read_text()
return AuthKeys(priv=priv)
algorithm = "EdDSA" if self.auth_token_type == "NeonJWT" else "RS256"
return AuthKeys(priv=priv, algorithm=algorithm)
@cached_property
def hadron_token_decoder(self) -> HadronTokenDecoder:
cert = (Path(self.repo_dir) / "auth_public_key.pem").read_text()
x509_cert = x509.load_pem_x509_certificate(cert.encode(), default_backend())
pem_public_key = (
x509_cert.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
.decode()
)
return HadronTokenDecoder(public_key=pem_public_key, algorithm="RS256")
def regenerate_keys_at(self, privkey_path: Path, pubkey_path: Path):
# compare generate_auth_keys() in local_env.rs
subprocess.run(
["openssl", "genpkey", "-algorithm", "ed25519", "-out", privkey_path],
cwd=self.repo_dir,
check=True,
)
if self.auth_token_type == "NeonJWT":
# compare generate_auth_keys() in local_env.rs
subprocess.run(
["openssl", "genpkey", "-algorithm", "ed25519", "-out", privkey_path],
cwd=self.repo_dir,
check=True,
)
subprocess.run(
[
"openssl",
"pkey",
"-in",
privkey_path,
"-pubout",
"-out",
pubkey_path,
],
cwd=self.repo_dir,
check=True,
)
subprocess.run(
[
"openssl",
"pkey",
"-in",
privkey_path,
"-pubout",
"-out",
pubkey_path,
],
cwd=self.repo_dir,
check=True,
)
elif self.auth_token_type == "HadronJWT":
# compare generate_auth_keys() in local_env.rs
subprocess.run(
[
"openssl",
"req",
"-x509",
"-newkey",
"rsa:4096",
"-sha256",
"-keyout",
privkey_path,
"-out",
pubkey_path,
"-nodes",
"-subj",
"/CN=eng-brickstore@databricks.com",
],
cwd=self.repo_dir,
check=True,
)
del self.auth_keys
def generate_endpoint_id(self) -> str:
@@ -2022,10 +2077,10 @@ class NeonStorageController(MetricsGetter, LogUtils):
return resp
def headers(self, scope: TokenScope | None) -> dict[str, str]:
def headers(self, scope: TokenScope | None, **token_data: Any) -> dict[str, str]:
headers = {}
if self.auth_enabled and scope is not None:
jwt_token = self.env.auth_keys.generate_token(scope=scope)
jwt_token = self.env.auth_keys.generate_token(scope=scope, **token_data)
headers["Authorization"] = f"Bearer {jwt_token}"
return headers
@@ -2314,6 +2369,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
timeline_id: TimelineId,
new_sk_set: list[int],
):
log.info(f"migrate_safekeepers({tenant_id}, {timeline_id}, {new_sk_set})")
response = self.request(
"POST",
f"{self.api}/v1/tenant/{tenant_id}/timeline/{timeline_id}/safekeeper_migrate",

View File

@@ -32,8 +32,11 @@ def assert_client_not_authorized(env: NeonEnv, http_client: PageserverHttpClient
assert_client_authorized(env, http_client)
def test_pageserver_auth(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_auth(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
ps = env.pageserver
@@ -72,8 +75,10 @@ def test_pageserver_auth(neon_env_builder: NeonEnvBuilder):
env.pageserver.tenant_create(TenantId.generate(), auth_token=tenant_token)
def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()
@@ -91,8 +96,10 @@ def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder):
assert cur.fetchone() == (5000050000,)
def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(
[".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"]
@@ -145,8 +152,10 @@ def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder):
assert_client_authorized(env, pageserver_http_client_new)
def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(
[".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"]
@@ -183,7 +192,12 @@ def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("auth_enabled", [False, True])
def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_auth_failures(
neon_env_builder: NeonEnvBuilder, auth_enabled: bool, use_hadron_auth_tokens: bool
):
neon_env_builder.auth_enabled = auth_enabled
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
neon_env_builder.auth_enabled = auth_enabled
env = neon_env_builder.init_start()

View File

@@ -863,7 +863,6 @@ def test_pageserver_compaction_circuit_breaker(neon_env_builder: NeonEnvBuilder)
assert not env.pageserver.log_contains(".*Circuit breaker failure ended.*")
@pytest.mark.skip(reason="Lakebase mode")
def test_ps_corruption_detection_feedback(neon_env_builder: NeonEnvBuilder):
"""
Test that when the pageserver detects corruption during image layer creation,
@@ -890,7 +889,9 @@ def test_ps_corruption_detection_feedback(neon_env_builder: NeonEnvBuilder):
timeline_id = env.initial_timeline
pageserver_http = env.pageserver.http_client()
workload = Workload(env, tenant_id, timeline_id)
workload = Workload(
env, tenant_id, timeline_id, endpoint_opts={"config_lines": ["neon.lakebase_mode=true"]}
)
workload.init()
# Enable the failpoint that will cause image layer creation to fail due to a (simulated) detected

View File

@@ -1,6 +1,6 @@
import random
import threading
from enum import StrEnum
from threading import Thread
from time import sleep
from typing import Any
@@ -47,19 +47,23 @@ def offload_lfc(method: PrewarmMethod, client: EndpointHttpClient, cur: Cursor)
# With autoprewarm, we need to be sure LFC was offloaded after all writes
# finish, so we sleep. Otherwise we'll have less prewarmed pages than we want
sleep(AUTOOFFLOAD_INTERVAL_SECS)
client.offload_lfc_wait()
return
offload_res = client.offload_lfc_wait()
log.info(offload_res)
return offload_res
if method == PrewarmMethod.COMPUTE_CTL:
status = client.prewarm_lfc_status()
assert status["status"] == "not_prewarmed"
assert "error" not in status
client.offload_lfc()
offload_res = client.offload_lfc()
log.info(offload_res)
assert client.prewarm_lfc_status()["status"] == "not_prewarmed"
parsed = prom_parse(client)
desired = {OFFLOAD_LABEL: 1, PREWARM_LABEL: 0, OFFLOAD_ERR_LABEL: 0, PREWARM_ERR_LABEL: 0}
assert parsed == desired, f"{parsed=} != {desired=}"
return
return offload_res
raise AssertionError(f"{method} not in PrewarmMethod")
@@ -68,21 +72,30 @@ def prewarm_endpoint(
method: PrewarmMethod, client: EndpointHttpClient, cur: Cursor, lfc_state: str | None
):
if method == PrewarmMethod.AUTOPREWARM:
client.prewarm_lfc_wait()
prewarm_res = client.prewarm_lfc_wait()
log.info(prewarm_res)
elif method == PrewarmMethod.COMPUTE_CTL:
client.prewarm_lfc()
prewarm_res = client.prewarm_lfc()
log.info(prewarm_res)
return prewarm_res
elif method == PrewarmMethod.POSTGRES:
cur.execute("select neon.prewarm_local_cache(%s)", (lfc_state,))
def check_prewarmed(
def check_prewarmed_contains(
method: PrewarmMethod, client: EndpointHttpClient, desired_status: dict[str, str | int]
):
if method == PrewarmMethod.AUTOPREWARM:
assert client.prewarm_lfc_status() == desired_status
prewarm_status = client.prewarm_lfc_status()
for k in desired_status:
assert desired_status[k] == prewarm_status[k]
assert prom_parse(client)[PREWARM_LABEL] == 1
elif method == PrewarmMethod.COMPUTE_CTL:
assert client.prewarm_lfc_status() == desired_status
prewarm_status = client.prewarm_lfc_status()
for k in desired_status:
assert desired_status[k] == prewarm_status[k]
desired = {OFFLOAD_LABEL: 0, PREWARM_LABEL: 1, PREWARM_ERR_LABEL: 0, OFFLOAD_ERR_LABEL: 0}
assert prom_parse(client) == desired
@@ -149,9 +162,6 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, method: PrewarmMethod):
log.info(f"Used LFC size: {lfc_used_pages}")
pg_cur.execute("select * from neon.get_prewarm_info()")
total, prewarmed, skipped, _ = pg_cur.fetchall()[0]
log.info(f"Prewarm info: {total=} {prewarmed=} {skipped=}")
progress = (prewarmed + skipped) * 100 // total
log.info(f"Prewarm progress: {progress}%")
assert lfc_used_pages > 10000
assert total > 0
assert prewarmed > 0
@@ -161,7 +171,54 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, method: PrewarmMethod):
assert lfc_cur.fetchall()[0][0] == n_records * (n_records + 1) / 2
desired = {"status": "completed", "total": total, "prewarmed": prewarmed, "skipped": skipped}
check_prewarmed(method, client, desired)
check_prewarmed_contains(method, client, desired)
@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping")
def test_lfc_prewarm_cancel(neon_simple_env: NeonEnv):
"""
Test we can cancel LFC prewarm and prewarm successfully after
"""
env = neon_simple_env
n_records = 1000000
cfg = [
"autovacuum = off",
"shared_buffers=1MB",
"neon.max_file_cache_size=1GB",
"neon.file_cache_size_limit=1GB",
"neon.file_cache_prewarm_limit=1000",
]
endpoint = env.endpoints.create_start(branch_name="main", config_lines=cfg)
pg_conn = endpoint.connect()
pg_cur = pg_conn.cursor()
pg_cur.execute("create schema neon; create extension neon with schema neon")
pg_cur.execute("create database lfc")
lfc_conn = endpoint.connect(dbname="lfc")
lfc_cur = lfc_conn.cursor()
log.info(f"Inserting {n_records} rows")
lfc_cur.execute("create table t(pk integer primary key, payload text default repeat('?', 128))")
lfc_cur.execute(f"insert into t (pk) values (generate_series(1,{n_records}))")
log.info(f"Inserted {n_records} rows")
client = endpoint.http_client()
method = PrewarmMethod.COMPUTE_CTL
offload_lfc(method, client, pg_cur)
endpoint.stop()
endpoint.start()
thread = Thread(target=lambda: prewarm_endpoint(method, client, pg_cur, None))
thread.start()
# wait 2 seconds to ensure we cancel prewarm SQL query
sleep(2)
client.cancel_prewarm_lfc()
thread.join()
assert client.prewarm_lfc_status()["status"] == "cancelled"
prewarm_endpoint(method, client, pg_cur, None)
assert client.prewarm_lfc_status()["status"] == "completed"
@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping")
@@ -178,9 +235,8 @@ def test_lfc_prewarm_empty(neon_simple_env: NeonEnv):
cur = conn.cursor()
cur.execute("create schema neon; create extension neon with schema neon")
method = PrewarmMethod.COMPUTE_CTL
offload_lfc(method, client, cur)
prewarm_endpoint(method, client, cur, None)
assert client.prewarm_lfc_status()["status"] == "skipped"
assert offload_lfc(method, client, cur)["status"] == "skipped"
assert prewarm_endpoint(method, client, cur, None)["status"] == "skipped"
# autoprewarm isn't needed as we prewarm manually
@@ -251,11 +307,11 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, method: PrewarmMet
workload_threads = []
for _ in range(n_threads):
t = threading.Thread(target=workload)
t = Thread(target=workload)
workload_threads.append(t)
t.start()
prewarm_thread = threading.Thread(target=prewarm)
prewarm_thread = Thread(target=prewarm)
prewarm_thread.start()
def prewarmed():

View File

@@ -286,3 +286,177 @@ def test_sk_generation_aware_tombstones(neon_env_builder: NeonEnvBuilder):
assert re.match(r".*Timeline .* deleted.*", exc.value.response.text)
# The timeline should remain deleted.
expect_deleted(second_sk)
def test_safekeeper_migration_stale_timeline(neon_env_builder: NeonEnvBuilder):
"""
Test that safekeeper migration handles stale timeline correctly by migrating to
a safekeeper with a stale timeline.
1. Check that we are waiting for the stale timeline to catch up with the commit lsn.
The migration might fail if there is no compute to advance the WAL.
2. Check that we rely on last_log_term (and not the current term) when waiting for the
sync_position on step 7.
3. Check that migration succeeds if the compute is running.
"""
neon_env_builder.num_safekeepers = 2
neon_env_builder.storage_controller_config = {
"timelines_onto_safekeepers": True,
"timeline_safekeeper_count": 1,
}
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS)
env.storage_controller.allowed_errors.append(".*not enough successful .* to reach quorum.*")
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
active_sk = env.get_safekeeper(mconf["sk_set"][0])
other_sk = [sk for sk in env.safekeepers if sk.id != active_sk.id][0]
ep = env.endpoints.create("main", tenant_id=env.initial_tenant)
ep.start(safekeeper_generation=1, safekeepers=[active_sk.id])
ep.safe_psql("CREATE TABLE t(a int)")
ep.safe_psql("INSERT INTO t VALUES (0)")
# Pull the timeline to other_sk, so other_sk now has a "stale" timeline on it.
other_sk.pull_timeline([active_sk], env.initial_tenant, env.initial_timeline)
# Advance the WAL on active_sk.
ep.safe_psql("INSERT INTO t VALUES (1)")
# The test is more tricky if we have the same last_log_term but different term/flush_lsn.
# Stop the active_sk during the endpoint shutdown because otherwise compute_ctl runs
# sync_safekeepers and advances last_log_term on active_sk.
active_sk.stop()
ep.stop(mode="immediate")
active_sk.start()
active_sk_status = active_sk.http_client().timeline_status(
env.initial_tenant, env.initial_timeline
)
other_sk_status = other_sk.http_client().timeline_status(
env.initial_tenant, env.initial_timeline
)
# other_sk should have the same last_log_term, but a stale flush_lsn.
assert active_sk_status.last_log_term == other_sk_status.last_log_term
assert active_sk_status.flush_lsn > other_sk_status.flush_lsn
commit_lsn = active_sk_status.flush_lsn
# Bump the term on other_sk to make it higher than active_sk.
# This is to make sure we don't use current term instead of last_log_term in the algorithm.
other_sk.http_client().term_bump(
env.initial_tenant, env.initial_timeline, active_sk_status.term + 100
)
# TODO(diko): now it fails because the timeline on other_sk is stale and there is no compute
# to catch up it with active_sk. It might be fixed in https://databricks.atlassian.net/browse/LKB-946
# if we delete stale timelines before starting the migration.
# But the rest of the test is still valid: we should not lose committed WAL after the migration.
with pytest.raises(
StorageControllerApiException, match="not enough successful .* to reach quorum"
):
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, [other_sk.id]
)
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert mconf["new_sk_set"] == [other_sk.id]
assert mconf["sk_set"] == [active_sk.id]
assert mconf["generation"] == 2
# Start the endpoint, so it advances the WAL on other_sk.
ep.start(safekeeper_generation=2, safekeepers=[active_sk.id, other_sk.id])
# Now the migration should succeed.
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, [other_sk.id]
)
# Check that we didn't lose committed WAL.
assert (
other_sk.http_client().timeline_status(env.initial_tenant, env.initial_timeline).flush_lsn
>= commit_lsn
)
assert ep.safe_psql("SELECT * FROM t") == [(0,), (1,)]
def test_pull_from_most_advanced_sk(neon_env_builder: NeonEnvBuilder):
"""
Test that we pull the timeline from the most advanced safekeeper during the
migration and do not lose committed WAL.
"""
neon_env_builder.num_safekeepers = 4
neon_env_builder.storage_controller_config = {
"timelines_onto_safekeepers": True,
"timeline_safekeeper_count": 3,
}
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS)
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
sk_set = mconf["sk_set"]
assert len(sk_set) == 3
other_sk = [sk.id for sk in env.safekeepers if sk.id not in sk_set][0]
ep = env.endpoints.create("main", tenant_id=env.initial_tenant)
ep.start(safekeeper_generation=1, safekeepers=sk_set)
ep.safe_psql("CREATE TABLE t(a int)")
ep.safe_psql("INSERT INTO t VALUES (0)")
# Stop one sk, so we have a lagging WAL on it.
env.get_safekeeper(sk_set[0]).stop()
# Advance the WAL on the other sks.
ep.safe_psql("INSERT INTO t VALUES (1)")
# Stop other sks to make sure compute_ctl doesn't advance the last_log_term on them during shutdown.
for sk_id in sk_set[1:]:
env.get_safekeeper(sk_id).stop()
ep.stop(mode="immediate")
for sk_id in sk_set:
env.get_safekeeper(sk_id).start()
# Bump the term on the lagging sk to make sure we don't use it to choose the most advanced sk.
env.get_safekeeper(sk_set[0]).http_client().term_bump(
env.initial_tenant, env.initial_timeline, 100
)
def get_commit_lsn(sk_set: list[int]):
flush_lsns = []
last_log_terms = []
for sk_id in sk_set:
sk = env.get_safekeeper(sk_id)
status = sk.http_client().timeline_status(env.initial_tenant, env.initial_timeline)
flush_lsns.append(status.flush_lsn)
last_log_terms.append(status.last_log_term)
# In this test we assume that all sks have the same last_log_term.
assert len(set(last_log_terms)) == 1
flush_lsns.sort(reverse=True)
commit_lsn = flush_lsns[len(sk_set) // 2]
log.info(f"sk_set: {sk_set}, flush_lsns: {flush_lsns}, commit_lsn: {commit_lsn}")
return commit_lsn
commit_lsn_before_migration = get_commit_lsn(sk_set)
# Make two migrations, so the lagging sk stays in the sk_set, but other sks are replaced.
new_sk_set1 = [sk_set[0], sk_set[1], other_sk] # remove sk_set[2], add other_sk
new_sk_set2 = [sk_set[0], other_sk, sk_set[2]] # remove sk_set[1], add sk_set[2] back
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, new_sk_set1
)
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, new_sk_set2
)
commit_lsn_after_migration = get_commit_lsn(new_sk_set2)
# We should not lose committed WAL.
# If we have choosen the lagging sk to pull the timeline from, this might fail.
assert commit_lsn_before_migration <= commit_lsn_after_migration
ep.start(safekeeper_generation=5, safekeepers=new_sk_set2)
assert ep.safe_psql("SELECT * FROM t") == [(0,), (1,)]

View File

@@ -1406,6 +1406,9 @@ def test_storage_controller_s3_time_travel_recovery(
def test_storage_controller_auth(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
env = neon_env_builder.init_start()
assert env.auth_token_type == "NeonJWT"
svc = env.storage_controller
api = env.storage_controller_api

View File

@@ -2742,7 +2742,6 @@ def test_pull_timeline_partial_segment_integrity(neon_env_builder: NeonEnvBuilde
wait_until(unevicted)
@pytest.mark.skip(reason="Lakebase mode")
def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
"""
Test that the timeline disk usage circuit breaker works as expected. We test that:
@@ -2762,7 +2761,12 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
# Create a timeline and endpoint
env.create_branch("test_timeline_disk_usage_limit")
endpoint = env.endpoints.create_start("test_timeline_disk_usage_limit")
endpoint = env.endpoints.create_start(
"test_timeline_disk_usage_limit",
config_lines=[
"neon.lakebase_mode=true",
],
)
# Install the neon extension in the test database. We need it to query perf counter metrics.
with closing(endpoint.connect()) as conn:

View File

@@ -1,18 +1,18 @@
{
"v17": [
"17.5",
"fa1788475e3146cc9c7c6a1b74f48fd296898fcd"
"1e01fcea2a6b38180021aa83e0051d95286d9096"
],
"v16": [
"16.9",
"9b9cb4b3e33347aea8f61e606bb6569979516de5"
"a42351fcd41ea01edede1daed65f651e838988fc"
],
"v15": [
"15.13",
"aaaeff2550d5deba58847f112af9b98fa3a58b00"
"2aaab3bb4a13557aae05bb2ae0ef0a132d0c4f85"
],
"v14": [
"14.18",
"c9f9fdd0113b52c0bd535afdb09d3a543aeee25f"
"2155cb165d05f617eb2c8ad7e43367189b627703"
]
}

View File

@@ -78,6 +78,7 @@ parquet = { version = "53", default-features = false, features = ["zstd"] }
portable-atomic = { version = "1", features = ["require-cas"] }
prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] }
rand = { version = "0.9" }
rcgen = { version = "0.13", features = ["aws_lc_rs"] }
regex = { version = "1" }
regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }
regex-syntax = { version = "0.8" }
@@ -126,6 +127,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
clap = { version = "4", features = ["derive", "env", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
displaydoc = { version = "0.2" }
either = { version = "1" }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
half = { version = "2", default-features = false, features = ["num-traits"] }