Compare commits

..

5 Commits

Author SHA1 Message Date
Alexey Kondratov
70a273081b fix(compute_ctl): Use coalesce without schema as it's not a catalog func 2025-07-31 18:48:51 +02:00
Dmitrii Kovalkov
312a74f11f storcon: implement safekeeper_migrate_abort handler (#12705)
## Problem
Right now if we commit a joint configuration to DB, there is no way
back. The only way to get the clean mconf is to continue the migration.
The RFC also described an abort mechanism, which allows to abort current
migration and revert mconf change. It might be needed if the migration
is stuck and cannot have any progress, e.g. if the sk we are migrating
to went down during the migration. This PR implements this abort
algorithm.

- Closes: https://databricks.atlassian.net/browse/LKB-899
- Closes: https://github.com/neondatabase/neon/issues/12549

## Summary of changes
- Implement `safekeeper_migrate_abort` handler with the algorithm
described in RFC
- Add `timeline-safekeeper-migrate-abort` subcommand to `storcon_cli`
- Add test for the migration abort algorithm.
2025-07-31 12:40:32 +00:00
Mikhail
df4e37b7cc Report timespans for promotion and prewarm (#12730)
- Return sub-actions time spans for prewarm, prewarm offload, and
promotion in http handlers.
- Set `synchronous_standby_names=walproposer` for promoted endpoints.
Otherwise, walproposer on promoted standby ignores reply from safekeeper
and is stuck on lsn COMMIT eternally.
2025-07-31 11:51:19 +00:00
Heikki Linnakangas
b4a63e0a34 Fix how neon.stripe_size option is set in postgresql.conf file (#12776)
Commit 1dce2a9e74 changed how the `neon.pageserver_connstring` setting
is formed, but it messed up setting the `neon.stripe_size` setting so
that it was set twice. That got mixed up during development of the
patch, as commit 7fef4435c1 landed first and was merged incorrectly.
2025-07-31 11:46:57 +00:00
Erik Grinaker
f8fc0bf3c0 neon_local: use doc comments for help texts (#12270)
Clap automatically uses doc comments as help/about texts. Doc comments
are strictly better, since they're also used e.g. for IDE documentation,
and are better formatted.

This patch updates all `neon_local` commands to use doc comments
(courtesy of GPT-o3).
2025-07-31 10:25:33 +00:00
52 changed files with 960 additions and 1537 deletions

215
Cargo.lock generated
View File

@@ -173,45 +173,6 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048"
dependencies = [
"asn1-rs-derive",
"asn1-rs-impl",
"displaydoc",
"nom",
"num-traits",
"rusticata-macros",
"thiserror 1.0.69",
"time",
]
[[package]]
name = "asn1-rs-derive"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
"synstructure",
]
[[package]]
name = "asn1-rs-impl"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "assert-json-diff"
version = "2.0.2"
@@ -346,30 +307,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "aws-lc-rs"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08b5d4e069cbc868041a64bd68dc8cb39a0d79585cd6c5a24caa8c2d622121be"
dependencies = [
"aws-lc-sys",
"untrusted 0.7.1",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.30.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]]
name = "aws-runtime"
version = "1.4.4"
@@ -1031,29 +968,6 @@ dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.100",
"which",
]
[[package]]
name = "bindgen"
version = "0.71.1"
@@ -1346,15 +1260,6 @@ dependencies = [
"replace_with",
]
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.0"
@@ -1587,7 +1492,6 @@ dependencies = [
"postgres_connection",
"regex",
"reqwest",
"rsa",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
@@ -1607,7 +1511,6 @@ dependencies = [
"utils",
"whoami",
"workspace_hack",
"x509-parser",
]
[[package]]
@@ -1933,20 +1836,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "der-parser"
version = "9.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553"
dependencies = [
"asn1-rs",
"displaydoc",
"nom",
"num-bigint",
"num-traits",
"rusticata-macros",
]
[[package]]
name = "der_derive"
version = "0.7.3"
@@ -2103,12 +1992,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "dyn-clone"
version = "1.0.14"
@@ -2226,7 +2109,6 @@ dependencies = [
"http-body-util",
"itertools 0.10.5",
"jsonwebtoken",
"postgres_backend",
"prometheus",
"rand 0.9.1",
"remote_storage",
@@ -2509,12 +2391,6 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsevent-sys"
version = "4.1.0"
@@ -2964,15 +2840,6 @@ dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "hostname"
version = "0.4.0"
@@ -3747,12 +3614,6 @@ dependencies = [
"spin",
]
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.172"
@@ -4328,15 +4189,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "oid-registry"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9"
dependencies = [
"asn1-rs",
]
[[package]]
name = "once_cell"
version = "1.20.2"
@@ -5220,7 +5072,7 @@ name = "postgres_ffi"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen 0.71.1",
"bindgen",
"bytes",
"crc32c",
"criterion",
@@ -5885,7 +5737,6 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
dependencies = [
"aws-lc-rs",
"pem",
"ring",
"rustls-pki-types",
@@ -6201,7 +6052,7 @@ dependencies = [
"cfg-if",
"getrandom 0.2.11",
"libc",
"untrusted 0.9.0",
"untrusted",
"windows-sys 0.52.0",
]
@@ -6322,15 +6173,6 @@ dependencies = [
"semver",
]
[[package]]
name = "rusticata-macros"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632"
dependencies = [
"nom",
]
[[package]]
name = "rustix"
version = "0.38.41"
@@ -6458,7 +6300,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
dependencies = [
"ring",
"untrusted 0.9.0",
"untrusted",
]
[[package]]
@@ -6469,7 +6311,7 @@ checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted 0.9.0",
"untrusted",
]
[[package]]
@@ -6480,7 +6322,7 @@ checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted 0.9.0",
"untrusted",
]
[[package]]
@@ -6642,7 +6484,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
"ring",
"untrusted 0.9.0",
"untrusted",
]
[[package]]
@@ -7225,7 +7067,6 @@ dependencies = [
"hyper 0.14.30",
"itertools 0.10.5",
"json-structural-diff",
"jsonwebtoken",
"lasso",
"measured",
"metrics",
@@ -8400,12 +8241,6 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "untrusted"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -8505,7 +8340,6 @@ dependencies = [
"jsonwebtoken",
"metrics",
"nix 0.30.1",
"oid-registry",
"once_cell",
"pem",
"pin-project-lite",
@@ -8513,10 +8347,7 @@ dependencies = [
"pprof",
"pq_proto",
"rand 0.9.1",
"rcgen",
"regex",
"rustls-pemfile 2.1.1",
"rustls-pki-types",
"scopeguard",
"sentry",
"serde",
@@ -8537,7 +8368,6 @@ dependencies = [
"tracing-utils",
"uuid",
"walkdir",
"x509-parser",
]
[[package]]
@@ -8652,7 +8482,7 @@ name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen 0.71.1",
"bindgen",
"postgres_ffi",
"utils",
]
@@ -8817,18 +8647,6 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "whoami"
version = "1.5.1"
@@ -9201,7 +9019,6 @@ dependencies = [
"der 0.7.8",
"deranged",
"digest",
"displaydoc",
"ecdsa 0.16.9",
"either",
"elliptic-curve 0.13.8",
@@ -9249,7 +9066,6 @@ dependencies = [
"prost 0.13.5",
"quote",
"rand 0.9.1",
"rcgen",
"regex",
"regex-automata 0.4.9",
"regex-syntax 0.8.5",
@@ -9335,23 +9151,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "x509-parser"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69"
dependencies = [
"asn1-rs",
"data-encoding",
"der-parser",
"lazy_static",
"nom",
"oid-registry",
"rusticata-macros",
"thiserror 1.0.69",
"time",
]
[[package]]
name = "xattr"
version = "1.0.0"

View File

@@ -142,7 +142,6 @@ nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket"
notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.19"
oid-registry = "0.7.1"
once_cell = "1.13"
opentelemetry = "0.30"
opentelemetry_sdk = "0.30"
@@ -174,7 +173,6 @@ rustc-hash = "2.1.1"
rustls = { version = "0.23.16", default-features = false }
rustls-pemfile = "2"
rustls-pki-types = "1.11"
rustls-split = "0.3"
scopeguard = "1.1"
sysinfo = "0.29.2"
sd-notify = "0.4.1"
@@ -237,7 +235,6 @@ rustls-native-certs = "0.8"
whoami = "1.5.1"
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
x509-parser = "0.16"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
zeroize = "1.8"

View File

@@ -2780,7 +2780,7 @@ LIMIT 100",
// 4. We start again and try to prewarm with the state from 2. instead of the previous complete state
if matches!(
prewarm_state,
LfcPrewarmState::Completed
LfcPrewarmState::Completed { .. }
| LfcPrewarmState::NotPrewarmed
| LfcPrewarmState::Skipped
) {

View File

@@ -7,19 +7,11 @@ use http::StatusCode;
use reqwest::Client;
use std::mem::replace;
use std::sync::Arc;
use std::time::Instant;
use tokio::{io::AsyncReadExt, select, spawn};
use tokio_util::sync::CancellationToken;
use tracing::{error, info};
#[derive(serde::Serialize, Default)]
pub struct LfcPrewarmStateWithProgress {
#[serde(flatten)]
base: LfcPrewarmState,
total: i32,
prewarmed: i32,
skipped: i32,
}
/// A pair of url and a token to query endpoint storage for LFC prewarm-related tasks
struct EndpointStoragePair {
url: String,
@@ -28,7 +20,7 @@ struct EndpointStoragePair {
const KEY: &str = "lfc_state";
impl EndpointStoragePair {
/// endpoint_id is set to None while prewarming from other endpoint, see replica promotion
/// endpoint_id is set to None while prewarming from other endpoint, see compute_promote.rs
/// If not None, takes precedence over pspec.spec.endpoint_id
fn from_spec_and_endpoint(
pspec: &crate::compute::ParsedSpec,
@@ -54,36 +46,8 @@ impl EndpointStoragePair {
}
impl ComputeNode {
// If prewarm failed, we want to get overall number of segments as well as done ones.
// However, this function should be reliable even if querying postgres failed.
pub async fn lfc_prewarm_state(&self) -> LfcPrewarmStateWithProgress {
info!("requesting LFC prewarm state from postgres");
let mut state = LfcPrewarmStateWithProgress::default();
{
state.base = self.state.lock().unwrap().lfc_prewarm_state.clone();
}
let client = match ComputeNode::get_maintenance_client(&self.tokio_conn_conf).await {
Ok(client) => client,
Err(err) => {
error!(%err, "connecting to postgres");
return state;
}
};
let row = match client
.query_one("select * from neon.get_prewarm_info()", &[])
.await
{
Ok(row) => row,
Err(err) => {
error!(%err, "querying LFC prewarm status");
return state;
}
};
state.total = row.try_get(0).unwrap_or_default();
state.prewarmed = row.try_get(1).unwrap_or_default();
state.skipped = row.try_get(2).unwrap_or_default();
state
pub async fn lfc_prewarm_state(&self) -> LfcPrewarmState {
self.state.lock().unwrap().lfc_prewarm_state.clone()
}
pub fn lfc_offload_state(&self) -> LfcOffloadState {
@@ -133,7 +97,6 @@ impl ComputeNode {
}
/// Request LFC state from endpoint storage and load corresponding pages into Postgres.
/// Returns a result with `false` if the LFC state is not found in endpoint storage.
async fn prewarm_impl(
&self,
from_endpoint: Option<String>,
@@ -148,6 +111,7 @@ impl ComputeNode {
fail::fail_point!("compute-prewarm", |_| bail!("compute-prewarm failpoint"));
info!(%url, "requesting LFC state from endpoint storage");
let mut now = Instant::now();
let request = Client::new().get(&url).bearer_auth(storage_token);
let response = select! {
_ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled),
@@ -160,6 +124,8 @@ impl ComputeNode {
StatusCode::NOT_FOUND => return Ok(LfcPrewarmState::Skipped),
status => bail!("{status} querying endpoint storage"),
}
let state_download_time_ms = now.elapsed().as_millis() as u32;
now = Instant::now();
let mut uncompressed = Vec::new();
let lfc_state = select! {
@@ -174,6 +140,8 @@ impl ComputeNode {
read = decoder.read_to_end(&mut uncompressed) => read
}
.context("decoding LFC state")?;
let uncompress_time_ms = now.elapsed().as_millis() as u32;
now = Instant::now();
let uncompressed_len = uncompressed.len();
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}");
@@ -196,15 +164,34 @@ impl ComputeNode {
}
.context("loading LFC state into postgres")
.map(|_| ())?;
let prewarm_time_ms = now.elapsed().as_millis() as u32;
Ok(LfcPrewarmState::Completed)
let row = client
.query_one("select * from neon.get_prewarm_info()", &[])
.await
.context("querying prewarm info")?;
let total = row.try_get(0).unwrap_or_default();
let prewarmed = row.try_get(1).unwrap_or_default();
let skipped = row.try_get(2).unwrap_or_default();
Ok(LfcPrewarmState::Completed {
total,
prewarmed,
skipped,
state_download_time_ms,
uncompress_time_ms,
prewarm_time_ms,
})
}
/// If offload request is ongoing, return false, true otherwise
pub fn offload_lfc(self: &Arc<Self>) -> bool {
{
let state = &mut self.state.lock().unwrap().lfc_offload_state;
if replace(state, LfcOffloadState::Offloading) == LfcOffloadState::Offloading {
if matches!(
replace(state, LfcOffloadState::Offloading),
LfcOffloadState::Offloading
) {
return false;
}
}
@@ -216,7 +203,10 @@ impl ComputeNode {
pub async fn offload_lfc_async(self: &Arc<Self>) {
{
let state = &mut self.state.lock().unwrap().lfc_offload_state;
if replace(state, LfcOffloadState::Offloading) == LfcOffloadState::Offloading {
if matches!(
replace(state, LfcOffloadState::Offloading),
LfcOffloadState::Offloading
) {
return;
}
}
@@ -234,7 +224,6 @@ impl ComputeNode {
LfcOffloadState::Failed { error }
}
};
self.state.lock().unwrap().lfc_offload_state = state;
}
@@ -242,6 +231,7 @@ impl ComputeNode {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
info!(%url, "requesting LFC state from Postgres");
let mut now = Instant::now();
let row = ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
.context("connecting to postgres")?
@@ -255,25 +245,36 @@ impl ComputeNode {
info!(%url, "empty LFC state, not exporting");
return Ok(LfcOffloadState::Skipped);
};
let state_query_time_ms = now.elapsed().as_millis() as u32;
now = Instant::now();
let mut compressed = Vec::new();
ZstdEncoder::new(state)
.read_to_end(&mut compressed)
.await
.context("compressing LFC state")?;
let compress_time_ms = now.elapsed().as_millis() as u32;
now = Instant::now();
let compressed_len = compressed.len();
info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");
info!(%url, "downloaded LFC state, compressed size {compressed_len}");
let request = Client::new().put(url).bearer_auth(token).body(compressed);
match request.send().await {
Ok(res) if res.status() == StatusCode::OK => Ok(LfcOffloadState::Completed),
Ok(res) => bail!(
"Request to endpoint storage failed with status: {}",
res.status()
),
Err(err) => Err(err).context("writing to endpoint storage"),
let response = request
.send()
.await
.context("writing to endpoint storage")?;
let state_upload_time_ms = now.elapsed().as_millis() as u32;
let status = response.status();
if status != StatusCode::OK {
bail!("request to endpoint storage failed: {status}");
}
Ok(LfcOffloadState::Completed {
compress_time_ms,
state_query_time_ms,
state_upload_time_ms,
})
}
pub fn cancel_prewarm(self: &Arc<Self>) {

View File

@@ -1,32 +1,24 @@
use crate::compute::ComputeNode;
use anyhow::{Context, Result, bail};
use anyhow::{Context, bail};
use compute_api::responses::{LfcPrewarmState, PromoteConfig, PromoteState};
use compute_api::spec::ComputeMode;
use itertools::Itertools;
use std::collections::HashMap;
use std::{sync::Arc, time::Duration};
use tokio::time::sleep;
use std::time::Instant;
use tracing::info;
use utils::lsn::Lsn;
impl ComputeNode {
/// Returns only when promote fails or succeeds. If a network error occurs
/// and http client disconnects, this does not stop promotion, and subsequent
/// calls block until promote finishes.
/// Returns only when promote fails or succeeds. If http client calling this function
/// disconnects, this does not stop promotion, and subsequent calls block until promote finishes.
/// Called by control plane on secondary after primary endpoint is terminated
/// Has a failpoint "compute-promotion"
pub async fn promote(self: &Arc<Self>, cfg: PromoteConfig) -> PromoteState {
let cloned = self.clone();
let promote_fn = async move || {
let Err(err) = cloned.promote_impl(cfg).await else {
return PromoteState::Completed;
};
tracing::error!(%err, "promoting");
PromoteState::Failed {
error: format!("{err:#}"),
pub async fn promote(self: &std::sync::Arc<Self>, cfg: PromoteConfig) -> PromoteState {
let this = self.clone();
let promote_fn = async move || match this.promote_impl(cfg).await {
Ok(state) => state,
Err(err) => {
tracing::error!(%err, "promoting replica");
let error = format!("{err:#}");
PromoteState::Failed { error }
}
};
let start_promotion = || {
let (tx, rx) = tokio::sync::watch::channel(PromoteState::NotPromoted);
tokio::spawn(async move { tx.send(promote_fn().await) });
@@ -34,36 +26,31 @@ impl ComputeNode {
};
let mut task;
// self.state is unlocked after block ends so we lock it in promote_impl
// and task.changed() is reached
// promote_impl locks self.state so we need to unlock it before calling task.changed()
{
task = self
.state
.lock()
.unwrap()
.promote_state
.get_or_insert_with(start_promotion)
.clone()
let promote_state = &mut self.state.lock().unwrap().promote_state;
task = promote_state.get_or_insert_with(start_promotion).clone()
}
if task.changed().await.is_err() {
let error = "promote sender dropped".to_string();
return PromoteState::Failed { error };
}
task.changed().await.expect("promote sender dropped");
task.borrow().clone()
}
async fn promote_impl(&self, mut cfg: PromoteConfig) -> Result<()> {
async fn promote_impl(&self, cfg: PromoteConfig) -> anyhow::Result<PromoteState> {
{
let state = self.state.lock().unwrap();
let mode = &state.pspec.as_ref().unwrap().spec.mode;
if *mode != ComputeMode::Replica {
bail!("{} is not replica", mode.to_type_str());
if *mode != compute_api::spec::ComputeMode::Replica {
bail!("compute mode \"{}\" is not replica", mode.to_type_str());
}
// we don't need to query Postgres so not self.lfc_prewarm_state()
match &state.lfc_prewarm_state {
LfcPrewarmState::NotPrewarmed | LfcPrewarmState::Prewarming => {
bail!("prewarm not requested or pending")
status @ (LfcPrewarmState::NotPrewarmed | LfcPrewarmState::Prewarming) => {
bail!("compute {status}")
}
LfcPrewarmState::Failed { error } => {
tracing::warn!(%error, "replica prewarm failed")
tracing::warn!(%error, "compute prewarm failed")
}
_ => {}
}
@@ -72,9 +59,10 @@ impl ComputeNode {
let client = ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
.context("connecting to postgres")?;
let mut now = Instant::now();
let primary_lsn = cfg.wal_flush_lsn;
let mut last_wal_replay_lsn: Lsn = Lsn::INVALID;
let mut standby_lsn = utils::lsn::Lsn::INVALID;
const RETRIES: i32 = 20;
for i in 0..=RETRIES {
let row = client
@@ -82,16 +70,18 @@ impl ComputeNode {
.await
.context("getting last replay lsn")?;
let lsn: u64 = row.get::<usize, postgres_types::PgLsn>(0).into();
last_wal_replay_lsn = lsn.into();
if last_wal_replay_lsn >= primary_lsn {
standby_lsn = lsn.into();
if standby_lsn >= primary_lsn {
break;
}
info!("Try {i}, replica lsn {last_wal_replay_lsn}, primary lsn {primary_lsn}");
sleep(Duration::from_secs(1)).await;
info!(%standby_lsn, %primary_lsn, "catching up, try {i}");
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
if last_wal_replay_lsn < primary_lsn {
if standby_lsn < primary_lsn {
bail!("didn't catch up with primary in {RETRIES} retries");
}
let lsn_wait_time_ms = now.elapsed().as_millis() as u32;
now = Instant::now();
// using $1 doesn't work with ALTER SYSTEM SET
let safekeepers_sql = format!(
@@ -102,27 +92,33 @@ impl ComputeNode {
.query(&safekeepers_sql, &[])
.await
.context("setting safekeepers")?;
client
.query(
"ALTER SYSTEM SET synchronous_standby_names=walproposer",
&[],
)
.await
.context("setting synchronous_standby_names")?;
client
.query("SELECT pg_catalog.pg_reload_conf()", &[])
.await
.context("reloading postgres config")?;
#[cfg(feature = "testing")]
fail::fail_point!("compute-promotion", |_| {
bail!("promotion configured to fail because of a failpoint")
});
fail::fail_point!("compute-promotion", |_| bail!(
"compute-promotion failpoint"
));
let row = client
.query_one("SELECT * FROM pg_catalog.pg_promote()", &[])
.await
.context("pg_promote")?;
if !row.get::<usize, bool>(0) {
bail!("pg_promote() returned false");
bail!("pg_promote() failed");
}
let pg_promote_time_ms = now.elapsed().as_millis() as u32;
let now = Instant::now();
let client = ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
.context("connecting to postgres")?;
let row = client
.query_one("SHOW transaction_read_only", &[])
.await
@@ -131,36 +127,47 @@ impl ComputeNode {
bail!("replica in read only mode after promotion");
}
// Already checked validity in http handler
#[allow(unused_mut)]
let mut new_pspec = crate::compute::ParsedSpec::try_from(cfg.spec).expect("invalid spec");
{
let mut state = self.state.lock().unwrap();
let spec = &mut state.pspec.as_mut().unwrap().spec;
spec.mode = ComputeMode::Primary;
let new_conf = cfg.spec.cluster.postgresql_conf.as_mut().unwrap();
let existing_conf = spec.cluster.postgresql_conf.as_ref().unwrap();
Self::merge_spec(new_conf, existing_conf);
// Local setup has different ports for pg process (port=) for primary and secondary.
// Primary is stopped so we need secondary's "port" value
#[cfg(feature = "testing")]
{
let old_spec = &state.pspec.as_ref().unwrap().spec;
let Some(old_conf) = old_spec.cluster.postgresql_conf.as_ref() else {
bail!("pspec.spec.cluster.postgresql_conf missing for endpoint");
};
let set: std::collections::HashMap<&str, &str> = old_conf
.split_terminator('\n')
.map(|e| e.split_once("=").expect("invalid item"))
.collect();
let Some(new_conf) = new_pspec.spec.cluster.postgresql_conf.as_mut() else {
bail!("pspec.spec.cluster.postgresql_conf missing for supplied config");
};
new_conf.push_str(&format!("port={}\n", set["port"]));
}
tracing::debug!("applied spec: {:#?}", new_pspec.spec);
if self.params.lakebase_mode {
ComputeNode::set_spec(&self.params, &mut state, new_pspec);
} else {
state.pspec = Some(new_pspec);
}
}
info!("applied new spec, reconfiguring as primary");
self.reconfigure()
}
self.reconfigure()?;
let reconfigure_time_ms = now.elapsed().as_millis() as u32;
/// Merge old and new Postgres conf specs to apply on secondary.
/// Change new spec's port and safekeepers since they are supplied
/// differenly
fn merge_spec(new_conf: &mut String, existing_conf: &str) {
let mut new_conf_set: HashMap<&str, &str> = new_conf
.split_terminator('\n')
.map(|e| e.split_once("=").expect("invalid item"))
.collect();
new_conf_set.remove("neon.safekeepers");
let existing_conf_set: HashMap<&str, &str> = existing_conf
.split_terminator('\n')
.map(|e| e.split_once("=").expect("invalid item"))
.collect();
new_conf_set.insert("port", existing_conf_set["port"]);
*new_conf = new_conf_set
.iter()
.map(|(k, v)| format!("{k}={v}"))
.join("\n");
Ok(PromoteState::Completed {
lsn_wait_time_ms,
pg_promote_time_ms,
reconfigure_time_ms,
})
}
}

View File

@@ -65,14 +65,19 @@ pub fn write_postgres_conf(
writeln!(file, "{conf}")?;
}
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
// Add options for connecting to storage
writeln!(file, "# Neon storage settings")?;
writeln!(file)?;
if let Some(conninfo) = &spec.pageserver_connection_info {
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = conninfo.stripe_size {
writeln!(
file,
"# from compute spec's pageserver_connection_info.stripe_size field"
)?;
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
let mut libpq_urls: Option<Vec<String>> = Some(Vec::new());
let num_shards = if conninfo.shard_count.0 == 0 {
1 // unsharded, treat it as a single shard
@@ -110,7 +115,7 @@ pub fn write_postgres_conf(
if let Some(libpq_urls) = libpq_urls {
writeln!(
file,
"# derived from compute spec's pageserver_conninfo field"
"# derived from compute spec's pageserver_connection_info field"
)?;
writeln!(
file,
@@ -120,24 +125,16 @@ pub fn write_postgres_conf(
} else {
writeln!(file, "# no neon.pageserver_connstring")?;
}
if let Some(stripe_size) = conninfo.stripe_size {
writeln!(
file,
"# from compute spec's pageserver_conninfo.stripe_size field"
)?;
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
} else {
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "# from compute spec's pageserver_connstring field")?;
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "# from compute spec's shard_stripe_size field")?;
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "# from compute spec's pageserver_connstring field")?;
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
}
if !spec.safekeeper_connstrings.is_empty() {

View File

@@ -617,9 +617,6 @@ components:
type: object
required:
- status
- total
- prewarmed
- skipped
properties:
status:
description: LFC prewarm status
@@ -637,6 +634,15 @@ components:
skipped:
description: Pages processed but not prewarmed
type: integer
state_download_time_ms:
description: Time it takes to download LFC state to compute
type: integer
uncompress_time_ms:
description: Time it takes to uncompress LFC state
type: integer
prewarm_time_ms:
description: Time it takes to prewarm LFC state in Postgres
type: integer
LfcOffloadState:
type: object
@@ -650,6 +656,16 @@ components:
error:
description: LFC offload error, if any
type: string
state_query_time_ms:
description: Time it takes to get LFC state from Postgres
type: integer
compress_time_ms:
description: Time it takes to compress LFC state
type: integer
state_upload_time_ms:
description: Time it takes to upload LFC state to endpoint storage
type: integer
PromoteState:
type: object
@@ -663,6 +679,15 @@ components:
error:
description: Promote error, if any
type: string
lsn_wait_time_ms:
description: Time it takes for secondary to catch up with primary WAL flush LSN
type: integer
pg_promote_time_ms:
description: Time it takes to call pg_promote on secondary
type: integer
reconfigure_time_ms:
description: Time it takes to reconfigure promoted secondary
type: integer
SetRoleGrantsRequest:
type: object

View File

@@ -1,12 +1,11 @@
use crate::compute_prewarm::LfcPrewarmStateWithProgress;
use crate::http::JsonResponse;
use axum::response::{IntoResponse, Response};
use axum::{Json, http::StatusCode};
use axum_extra::extract::OptionalQuery;
use compute_api::responses::LfcOffloadState;
use compute_api::responses::{LfcOffloadState, LfcPrewarmState};
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
pub(in crate::http) async fn prewarm_state(compute: Compute) -> Json<LfcPrewarmStateWithProgress> {
pub(in crate::http) async fn prewarm_state(compute: Compute) -> Json<LfcPrewarmState> {
Json(compute.lfc_prewarm_state().await)
}

View File

@@ -1,11 +1,22 @@
use crate::http::JsonResponse;
use axum::extract::Json;
use compute_api::responses::PromoteConfig;
use http::StatusCode;
pub(in crate::http) async fn promote(
compute: axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>,
Json(cfg): Json<compute_api::responses::PromoteConfig>,
Json(cfg): Json<PromoteConfig>,
) -> axum::response::Response {
// Return early at the cost of extra parsing spec
let pspec = match crate::compute::ParsedSpec::try_from(cfg.spec) {
Ok(p) => p,
Err(e) => return JsonResponse::error(StatusCode::BAD_REQUEST, e),
};
let cfg = PromoteConfig {
spec: pspec.spec,
wal_flush_lsn: cfg.wal_flush_lsn,
};
let state = compute.promote(cfg).await;
if let compute_api::responses::PromoteState::Failed { error: _ } = state {
return JsonResponse::create_response(StatusCode::INTERNAL_SERVER_ERROR, state);

View File

@@ -407,8 +407,8 @@ fn get_database_stats(cli: &mut Client) -> anyhow::Result<(f64, i64)> {
// like `postgres_exporter` use it to query Postgres statistics.
// Use explicit 8 bytes type casts to match Rust types.
let stats = cli.query_one(
"SELECT pg_catalog.coalesce(pg_catalog.sum(active_time), 0.0)::pg_catalog.float8 AS total_active_time,
pg_catalog.coalesce(pg_catalog.sum(sessions), 0)::pg_catalog.bigint AS total_sessions
"SELECT COALESCE(pg_catalog.sum(active_time), 0.0)::pg_catalog.float8 AS total_active_time,
COALESCE(pg_catalog.sum(sessions), 0)::pg_catalog.int8 AS total_sessions
FROM pg_catalog.pg_stat_database
WHERE datname NOT IN (
'postgres',

View File

@@ -46,5 +46,3 @@ endpoint_storage.workspace = true
compute_api.workspace = true
workspace_hack.workspace = true
tracing.workspace = true
x509-parser.workspace = true
rsa = "0.9"

View File

@@ -71,8 +71,9 @@ const DEFAULT_PG_VERSION_NUM: &str = "17";
const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/";
/// Neon CLI.
#[derive(clap::Parser)]
#[command(version = GIT_VERSION, about, name = "Neon CLI")]
#[command(version = GIT_VERSION, name = "Neon CLI")]
struct Cli {
#[command(subcommand)]
command: NeonLocalCmd,
@@ -107,30 +108,31 @@ enum NeonLocalCmd {
Stop(StopCmdArgs),
}
/// Initialize a new Neon repository, preparing configs for services to start with.
#[derive(clap::Args)]
#[clap(about = "Initialize a new Neon repository, preparing configs for services to start with")]
struct InitCmdArgs {
#[clap(long, help("How many pageservers to create (default 1)"))]
/// How many pageservers to create (default 1).
#[clap(long)]
num_pageservers: Option<u16>,
#[clap(long)]
config: Option<PathBuf>,
#[clap(long, help("Force initialization even if the repository is not empty"))]
/// Force initialization even if the repository is not empty.
#[clap(long, default_value = "must-not-exist")]
#[arg(value_parser)]
#[clap(default_value = "must-not-exist")]
force: InitForceMode,
}
/// Start pageserver and safekeepers.
#[derive(clap::Args)]
#[clap(about = "Start pageserver and safekeepers")]
struct StartCmdArgs {
#[clap(long = "start-timeout", default_value = "10s")]
timeout: humantime::Duration,
}
/// Stop pageserver and safekeepers.
#[derive(clap::Args)]
#[clap(about = "Stop pageserver and safekeepers")]
struct StopCmdArgs {
#[arg(value_enum)]
#[clap(long, default_value_t = StopMode::Fast)]
@@ -143,8 +145,8 @@ enum StopMode {
Immediate,
}
/// Manage tenants.
#[derive(clap::Subcommand)]
#[clap(about = "Manage tenants")]
enum TenantCmd {
List,
Create(TenantCreateCmdArgs),
@@ -155,38 +157,36 @@ enum TenantCmd {
#[derive(clap::Args)]
struct TenantCreateCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_id: Option<TenantId>,
#[clap(
long,
help = "Use a specific timeline id when creating a tenant and its initial timeline"
)]
/// Use a specific timeline id when creating a tenant and its initial timeline.
#[clap(long)]
timeline_id: Option<TimelineId>,
#[clap(short = 'c')]
config: Vec<String>,
/// Postgres version to use for the initial timeline.
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version to use for the initial timeline")]
#[clap(long)]
pg_version: PgMajorVersion,
#[clap(
long,
help = "Use this tenant in future CLI commands where tenant_id is needed, but not specified"
)]
/// Use this tenant in future CLI commands where tenant_id is needed, but not specified.
#[clap(long)]
set_default: bool,
#[clap(long, help = "Number of shards in the new tenant")]
/// Number of shards in the new tenant.
#[clap(long)]
#[arg(default_value_t = 0)]
shard_count: u8,
#[clap(long, help = "Sharding stripe size in pages")]
/// Sharding stripe size in pages.
#[clap(long)]
shard_stripe_size: Option<u32>,
#[clap(long, help = "Placement policy shards in this tenant")]
/// Placement policy shards in this tenant.
#[clap(long)]
#[arg(value_parser = parse_placement_policy)]
placement_policy: Option<PlacementPolicy>,
}
@@ -195,44 +195,35 @@ fn parse_placement_policy(s: &str) -> anyhow::Result<PlacementPolicy> {
Ok(serde_json::from_str::<PlacementPolicy>(s)?)
}
/// Set a particular tenant as default in future CLI commands where tenant_id is needed, but not
/// specified.
#[derive(clap::Args)]
#[clap(
about = "Set a particular tenant as default in future CLI commands where tenant_id is needed, but not specified"
)]
struct TenantSetDefaultCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_id: TenantId,
}
#[derive(clap::Args)]
struct TenantConfigCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_id: Option<TenantId>,
#[clap(short = 'c')]
config: Vec<String>,
}
/// Import a tenant that is present in remote storage, and create branches for its timelines.
#[derive(clap::Args)]
#[clap(
about = "Import a tenant that is present in remote storage, and create branches for its timelines"
)]
struct TenantImportCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_id: TenantId,
}
/// Manage timelines.
#[derive(clap::Subcommand)]
#[clap(about = "Manage timelines")]
enum TimelineCmd {
List(TimelineListCmdArgs),
Branch(TimelineBranchCmdArgs),
@@ -240,98 +231,87 @@ enum TimelineCmd {
Import(TimelineImportCmdArgs),
}
/// List all timelines available to this pageserver.
#[derive(clap::Args)]
#[clap(about = "List all timelines available to this pageserver")]
struct TimelineListCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_shard_id: Option<TenantShardId>,
}
/// Create a new timeline, branching off from another timeline.
#[derive(clap::Args)]
#[clap(about = "Create a new timeline, branching off from another timeline")]
struct TimelineBranchCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_id: Option<TenantId>,
#[clap(long, help = "New timeline's ID")]
/// New timeline's ID, as a 32-byte hexadecimal string.
#[clap(long)]
timeline_id: Option<TimelineId>,
#[clap(long, help = "Human-readable alias for the new timeline")]
/// Human-readable alias for the new timeline.
#[clap(long)]
branch_name: String,
#[clap(
long,
help = "Use last Lsn of another timeline (and its data) as base when creating the new timeline. The timeline gets resolved by its branch name."
)]
/// Use last Lsn of another timeline (and its data) as base when creating the new timeline. The
/// timeline gets resolved by its branch name.
#[clap(long)]
ancestor_branch_name: Option<String>,
#[clap(
long,
help = "When using another timeline as base, use a specific Lsn in it instead of the latest one"
)]
/// When using another timeline as base, use a specific Lsn in it instead of the latest one.
#[clap(long)]
ancestor_start_lsn: Option<Lsn>,
}
/// Create a new blank timeline.
#[derive(clap::Args)]
#[clap(about = "Create a new blank timeline")]
struct TimelineCreateCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_id: Option<TenantId>,
#[clap(long, help = "New timeline's ID")]
/// New timeline's ID, as a 32-byte hexadecimal string.
#[clap(long)]
timeline_id: Option<TimelineId>,
#[clap(long, help = "Human-readable alias for the new timeline")]
/// Human-readable alias for the new timeline.
#[clap(long)]
branch_name: String,
/// Postgres version.
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version")]
#[clap(long)]
pg_version: PgMajorVersion,
}
/// Import a timeline from a basebackup directory.
#[derive(clap::Args)]
#[clap(about = "Import timeline from a basebackup directory")]
struct TimelineImportCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_id: Option<TenantId>,
#[clap(long, help = "New timeline's ID")]
/// New timeline's ID, as a 32-byte hexadecimal string.
#[clap(long)]
timeline_id: TimelineId,
#[clap(long, help = "Human-readable alias for the new timeline")]
/// Human-readable alias for the new timeline.
#[clap(long)]
branch_name: String,
#[clap(long, help = "Basebackup tarfile to import")]
/// Basebackup tarfile to import.
#[clap(long)]
base_tarfile: PathBuf,
#[clap(long, help = "Lsn the basebackup starts at")]
/// LSN the basebackup starts at.
#[clap(long)]
base_lsn: Lsn,
#[clap(long, help = "Wal to add after base")]
/// WAL to add after base.
#[clap(long)]
wal_tarfile: Option<PathBuf>,
#[clap(long, help = "Lsn the basebackup ends at")]
/// LSN the basebackup ends at.
#[clap(long)]
end_lsn: Option<Lsn>,
/// Postgres version of the basebackup being imported.
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version of the backup being imported")]
#[clap(long)]
pg_version: PgMajorVersion,
}
/// Manage pageservers.
#[derive(clap::Subcommand)]
#[clap(about = "Manage pageservers")]
enum PageserverCmd {
Status(PageserverStatusCmdArgs),
Start(PageserverStartCmdArgs),
@@ -339,223 +319,202 @@ enum PageserverCmd {
Restart(PageserverRestartCmdArgs),
}
/// Show status of a local pageserver.
#[derive(clap::Args)]
#[clap(about = "Show status of a local pageserver")]
struct PageserverStatusCmdArgs {
#[clap(long = "id", help = "pageserver id")]
/// Pageserver ID.
#[clap(long = "id")]
pageserver_id: Option<NodeId>,
}
/// Start local pageserver.
#[derive(clap::Args)]
#[clap(about = "Start local pageserver")]
struct PageserverStartCmdArgs {
#[clap(long = "id", help = "pageserver id")]
/// Pageserver ID.
#[clap(long = "id")]
pageserver_id: Option<NodeId>,
#[clap(short = 't', long, help = "timeout until we fail the command")]
/// Timeout until we fail the command.
#[clap(short = 't', long)]
#[arg(default_value = "10s")]
start_timeout: humantime::Duration,
}
/// Stop local pageserver.
#[derive(clap::Args)]
#[clap(about = "Stop local pageserver")]
struct PageserverStopCmdArgs {
#[clap(long = "id", help = "pageserver id")]
/// Pageserver ID.
#[clap(long = "id")]
pageserver_id: Option<NodeId>,
#[clap(
short = 'm',
help = "If 'immediate', don't flush repository data at shutdown"
)]
/// If 'immediate', don't flush repository data at shutdown
#[clap(short = 'm')]
#[arg(value_enum, default_value = "fast")]
stop_mode: StopMode,
}
/// Restart local pageserver.
#[derive(clap::Args)]
#[clap(about = "Restart local pageserver")]
struct PageserverRestartCmdArgs {
#[clap(long = "id", help = "pageserver id")]
/// Pageserver ID.
#[clap(long = "id")]
pageserver_id: Option<NodeId>,
#[clap(short = 't', long, help = "timeout until we fail the command")]
/// Timeout until we fail the command.
#[clap(short = 't', long)]
#[arg(default_value = "10s")]
start_timeout: humantime::Duration,
}
/// Manage storage controller.
#[derive(clap::Subcommand)]
#[clap(about = "Manage storage controller")]
enum StorageControllerCmd {
Start(StorageControllerStartCmdArgs),
Stop(StorageControllerStopCmdArgs),
}
/// Start storage controller.
#[derive(clap::Args)]
#[clap(about = "Start storage controller")]
struct StorageControllerStartCmdArgs {
#[clap(short = 't', long, help = "timeout until we fail the command")]
/// Timeout until we fail the command.
#[clap(short = 't', long)]
#[arg(default_value = "10s")]
start_timeout: humantime::Duration,
#[clap(
long,
help = "Identifier used to distinguish storage controller instances"
)]
/// Identifier used to distinguish storage controller instances.
#[clap(long)]
#[arg(default_value_t = 1)]
instance_id: u8,
#[clap(
long,
help = "Base port for the storage controller instance idenfified by instance-id (defaults to pageserver cplane api)"
)]
/// Base port for the storage controller instance identified by instance-id (defaults to
/// pageserver cplane api).
#[clap(long)]
base_port: Option<u16>,
#[clap(
long,
help = "Whether the storage controller should handle pageserver-reported local disk loss events."
)]
/// Whether the storage controller should handle pageserver-reported local disk loss events.
#[clap(long)]
handle_ps_local_disk_loss: Option<bool>,
}
/// Stop storage controller.
#[derive(clap::Args)]
#[clap(about = "Stop storage controller")]
struct StorageControllerStopCmdArgs {
#[clap(
short = 'm',
help = "If 'immediate', don't flush repository data at shutdown"
)]
/// If 'immediate', don't flush repository data at shutdown
#[clap(short = 'm')]
#[arg(value_enum, default_value = "fast")]
stop_mode: StopMode,
#[clap(
long,
help = "Identifier used to distinguish storage controller instances"
)]
/// Identifier used to distinguish storage controller instances.
#[clap(long)]
#[arg(default_value_t = 1)]
instance_id: u8,
}
/// Manage storage broker.
#[derive(clap::Subcommand)]
#[clap(about = "Manage storage broker")]
enum StorageBrokerCmd {
Start(StorageBrokerStartCmdArgs),
Stop(StorageBrokerStopCmdArgs),
}
/// Start broker.
#[derive(clap::Args)]
#[clap(about = "Start broker")]
struct StorageBrokerStartCmdArgs {
#[clap(short = 't', long, help = "timeout until we fail the command")]
#[arg(default_value = "10s")]
/// Timeout until we fail the command.
#[clap(short = 't', long, default_value = "10s")]
start_timeout: humantime::Duration,
}
/// Stop broker.
#[derive(clap::Args)]
#[clap(about = "stop broker")]
struct StorageBrokerStopCmdArgs {
#[clap(
short = 'm',
help = "If 'immediate', don't flush repository data at shutdown"
)]
/// If 'immediate', don't flush repository data on shutdown.
#[clap(short = 'm')]
#[arg(value_enum, default_value = "fast")]
stop_mode: StopMode,
}
/// Manage safekeepers.
#[derive(clap::Subcommand)]
#[clap(about = "Manage safekeepers")]
enum SafekeeperCmd {
Start(SafekeeperStartCmdArgs),
Stop(SafekeeperStopCmdArgs),
Restart(SafekeeperRestartCmdArgs),
}
/// Manage object storage.
#[derive(clap::Subcommand)]
#[clap(about = "Manage object storage")]
enum EndpointStorageCmd {
Start(EndpointStorageStartCmd),
Stop(EndpointStorageStopCmd),
}
/// Start object storage.
#[derive(clap::Args)]
#[clap(about = "Start object storage")]
struct EndpointStorageStartCmd {
#[clap(short = 't', long, help = "timeout until we fail the command")]
/// Timeout until we fail the command.
#[clap(short = 't', long)]
#[arg(default_value = "10s")]
start_timeout: humantime::Duration,
}
/// Stop object storage.
#[derive(clap::Args)]
#[clap(about = "Stop object storage")]
struct EndpointStorageStopCmd {
/// If 'immediate', don't flush repository data on shutdown.
#[clap(short = 'm')]
#[arg(value_enum, default_value = "fast")]
#[clap(
short = 'm',
help = "If 'immediate', don't flush repository data at shutdown"
)]
stop_mode: StopMode,
}
/// Start local safekeeper.
#[derive(clap::Args)]
#[clap(about = "Start local safekeeper")]
struct SafekeeperStartCmdArgs {
#[clap(help = "safekeeper id")]
/// Safekeeper ID.
#[arg(default_value_t = NodeId(1))]
id: NodeId,
#[clap(
short = 'e',
long = "safekeeper-extra-opt",
help = "Additional safekeeper invocation options, e.g. -e=--http-auth-public-key-path=foo"
)]
/// Additional safekeeper invocation options, e.g. -e=--http-auth-public-key-path=foo.
#[clap(short = 'e', long = "safekeeper-extra-opt")]
extra_opt: Vec<String>,
#[clap(short = 't', long, help = "timeout until we fail the command")]
/// Timeout until we fail the command.
#[clap(short = 't', long)]
#[arg(default_value = "10s")]
start_timeout: humantime::Duration,
}
/// Stop local safekeeper.
#[derive(clap::Args)]
#[clap(about = "Stop local safekeeper")]
struct SafekeeperStopCmdArgs {
#[clap(help = "safekeeper id")]
/// Safekeeper ID.
#[arg(default_value_t = NodeId(1))]
id: NodeId,
/// If 'immediate', don't flush repository data on shutdown.
#[arg(value_enum, default_value = "fast")]
#[clap(
short = 'm',
help = "If 'immediate', don't flush repository data at shutdown"
)]
#[clap(short = 'm')]
stop_mode: StopMode,
}
/// Restart local safekeeper.
#[derive(clap::Args)]
#[clap(about = "Restart local safekeeper")]
struct SafekeeperRestartCmdArgs {
#[clap(help = "safekeeper id")]
/// Safekeeper ID.
#[arg(default_value_t = NodeId(1))]
id: NodeId,
/// If 'immediate', don't flush repository data on shutdown.
#[arg(value_enum, default_value = "fast")]
#[clap(
short = 'm',
help = "If 'immediate', don't flush repository data at shutdown"
)]
#[clap(short = 'm')]
stop_mode: StopMode,
#[clap(
short = 'e',
long = "safekeeper-extra-opt",
help = "Additional safekeeper invocation options, e.g. -e=--http-auth-public-key-path=foo"
)]
/// Additional safekeeper invocation options, e.g. -e=--http-auth-public-key-path=foo.
#[clap(short = 'e', long = "safekeeper-extra-opt")]
extra_opt: Vec<String>,
#[clap(short = 't', long, help = "timeout until we fail the command")]
/// Timeout until we fail the command.
#[clap(short = 't', long)]
#[arg(default_value = "10s")]
start_timeout: humantime::Duration,
}
/// Manage Postgres instances.
#[derive(clap::Subcommand)]
#[clap(about = "Manage Postgres instances")]
enum EndpointCmd {
List(EndpointListCmdArgs),
Create(EndpointCreateCmdArgs),
@@ -567,33 +526,27 @@ enum EndpointCmd {
GenerateJwt(EndpointGenerateJwtCmdArgs),
}
/// List endpoints.
#[derive(clap::Args)]
#[clap(about = "List endpoints")]
struct EndpointListCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_shard_id: Option<TenantShardId>,
}
/// Create a compute endpoint.
#[derive(clap::Args)]
#[clap(about = "Create a compute endpoint")]
struct EndpointCreateCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long = "tenant-id")]
tenant_id: Option<TenantId>,
#[clap(help = "Postgres endpoint id")]
/// Postgres endpoint ID.
endpoint_id: Option<String>,
#[clap(long, help = "Name of the branch the endpoint will run on")]
/// Name of the branch the endpoint will run on.
#[clap(long)]
branch_name: Option<String>,
#[clap(
long,
help = "Specify Lsn on the timeline to start from. By default, end of the timeline would be used"
)]
/// Specify LSN on the timeline to start from. By default, end of the timeline would be used.
#[clap(long)]
lsn: Option<Lsn>,
#[clap(long)]
pg_port: Option<u16>,
@@ -604,16 +557,13 @@ struct EndpointCreateCmdArgs {
#[clap(long = "pageserver-id")]
endpoint_pageserver_id: Option<NodeId>,
#[clap(
long,
help = "Don't do basebackup, create endpoint directory with only config files",
action = clap::ArgAction::Set,
default_value_t = false
)]
/// Don't do basebackup, create endpoint directory with only config files.
#[clap(long, action = clap::ArgAction::Set, default_value_t = false)]
config_only: bool,
/// Postgres version.
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version")]
#[clap(long)]
pg_version: PgMajorVersion,
/// Use gRPC to communicate with Pageservers, by generating grpc:// connstrings.
@@ -624,170 +574,140 @@ struct EndpointCreateCmdArgs {
#[clap(long)]
grpc: bool,
#[clap(
long,
help = "If set, the node will be a hot replica on the specified timeline",
action = clap::ArgAction::Set,
default_value_t = false
)]
/// If set, the node will be a hot replica on the specified timeline.
#[clap(long, action = clap::ArgAction::Set, default_value_t = false)]
hot_standby: bool,
#[clap(long, help = "If set, will set up the catalog for neon_superuser")]
/// If set, will set up the catalog for neon_superuser.
#[clap(long)]
update_catalog: bool,
#[clap(
long,
help = "Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but useful for tests."
)]
/// Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but
/// useful for tests.
#[clap(long)]
allow_multiple: bool,
/// Only allow changing it on creation
#[clap(long, help = "Name of the privileged role for the endpoint")]
/// Name of the privileged role for the endpoint.
// Only allow changing it on creation.
#[clap(long)]
privileged_role_name: Option<String>,
}
/// Start Postgres. If the endpoint doesn't exist yet, it is created.
#[derive(clap::Args)]
#[clap(about = "Start postgres. If the endpoint doesn't exist yet, it is created.")]
struct EndpointStartCmdArgs {
#[clap(help = "Postgres endpoint id")]
/// Postgres endpoint ID.
endpoint_id: String,
/// Pageserver ID.
#[clap(long = "pageserver-id")]
endpoint_pageserver_id: Option<NodeId>,
#[clap(
long,
help = "Safekeepers membership generation to prefix neon.safekeepers with. Normally neon_local sets it on its own, but this option allows to override. Non zero value forces endpoint to use membership configurations."
)]
/// Safekeepers membership generation to prefix neon.safekeepers with.
#[clap(long)]
safekeepers_generation: Option<u32>,
#[clap(
long,
help = "List of safekeepers endpoint will talk to. Normally neon_local chooses them on its own, but this option allows to override."
)]
/// List of safekeepers endpoint will talk to.
#[clap(long)]
safekeepers: Option<String>,
#[clap(
long,
help = "Configure the remote extensions storage proxy gateway URL to request for extensions.",
alias = "remote-ext-config"
)]
/// Configure the remote extensions storage proxy gateway URL to request for extensions.
#[clap(long, alias = "remote-ext-config")]
remote_ext_base_url: Option<String>,
#[clap(
long,
help = "If set, will create test user `user` and `neondb` database. Requires `update-catalog = true`"
)]
/// If set, will create test user `user` and `neondb` database. Requires `update-catalog = true`
#[clap(long)]
create_test_user: bool,
#[clap(
long,
help = "Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but useful for tests."
)]
/// Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but
/// useful for tests.
#[clap(long)]
allow_multiple: bool,
#[clap(short = 't', long, value_parser= humantime::parse_duration, help = "timeout until we fail the command")]
/// Timeout until we fail the command.
#[clap(short = 't', long, value_parser= humantime::parse_duration)]
#[arg(default_value = "90s")]
start_timeout: Duration,
#[clap(
long,
help = "Download LFC cache from endpoint storage on endpoint startup",
default_value = "false"
)]
/// Download LFC cache from endpoint storage on endpoint startup
#[clap(long, default_value = "false")]
autoprewarm: bool,
#[clap(long, help = "Upload LFC cache to endpoint storage periodically")]
/// Upload LFC cache to endpoint storage periodically
#[clap(long)]
offload_lfc_interval_seconds: Option<std::num::NonZeroU64>,
#[clap(
long,
help = "Run in development mode, skipping VM-specific operations like process termination",
action = clap::ArgAction::SetTrue
)]
/// Run in development mode, skipping VM-specific operations like process termination
#[clap(long, action = clap::ArgAction::SetTrue)]
dev: bool,
}
/// Reconfigure an endpoint.
#[derive(clap::Args)]
#[clap(about = "Reconfigure an endpoint")]
struct EndpointReconfigureCmdArgs {
#[clap(
long = "tenant-id",
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant id. Represented as a hexadecimal string 32 symbols length
#[clap(long = "tenant-id")]
tenant_id: Option<TenantId>,
#[clap(help = "Postgres endpoint id")]
/// Postgres endpoint ID.
endpoint_id: String,
/// Pageserver ID.
#[clap(long = "pageserver-id")]
endpoint_pageserver_id: Option<NodeId>,
#[clap(long)]
safekeepers: Option<String>,
}
/// Refresh the endpoint's configuration by forcing it reload it's spec
#[derive(clap::Args)]
#[clap(about = "Refresh the endpoint's configuration by forcing it reload it's spec")]
struct EndpointRefreshConfigurationArgs {
#[clap(help = "Postgres endpoint id")]
/// Postgres endpoint id
endpoint_id: String,
}
/// Stop an endpoint.
#[derive(clap::Args)]
#[clap(about = "Stop an endpoint")]
struct EndpointStopCmdArgs {
#[clap(help = "Postgres endpoint id")]
/// Postgres endpoint ID.
endpoint_id: String,
#[clap(
long,
help = "Also delete data directory (now optional, should be default in future)"
)]
/// Also delete data directory (now optional, should be default in future).
#[clap(long)]
destroy: bool,
#[clap(long, help = "Postgres shutdown mode")]
/// Postgres shutdown mode, passed to `pg_ctl -m <mode>`.
#[clap(long)]
#[clap(default_value = "fast")]
mode: EndpointTerminateMode,
}
/// Update the pageservers in the spec file of the compute endpoint
#[derive(clap::Args)]
#[clap(about = "Update the pageservers in the spec file of the compute endpoint")]
struct EndpointUpdatePageserversCmdArgs {
#[clap(help = "Postgres endpoint id")]
/// Postgres endpoint id
endpoint_id: String,
#[clap(short = 'p', long, help = "Specified pageserver id")]
/// Specified pageserver id
#[clap(short = 'p', long)]
pageserver_id: Option<NodeId>,
}
/// Generate a JWT for an endpoint.
#[derive(clap::Args)]
#[clap(about = "Generate a JWT for an endpoint")]
struct EndpointGenerateJwtCmdArgs {
#[clap(help = "Postgres endpoint id")]
/// Postgres endpoint ID.
endpoint_id: String,
#[clap(short = 's', long, help = "Scope to generate the JWT with", value_parser = ComputeClaimsScope::from_str)]
/// Scope to generate the JWT with.
#[clap(short = 's', long, value_parser = ComputeClaimsScope::from_str)]
scope: Option<ComputeClaimsScope>,
}
/// Manage neon_local branch name mappings.
#[derive(clap::Subcommand)]
#[clap(about = "Manage neon_local branch name mappings")]
enum MappingsCmd {
Map(MappingsMapCmdArgs),
}
/// Create new mapping which cannot exist already.
#[derive(clap::Args)]
#[clap(about = "Create new mapping which cannot exist already")]
struct MappingsMapCmdArgs {
#[clap(
long,
help = "Tenant id. Represented as a hexadecimal string 32 symbols length"
)]
/// Tenant ID, as a 32-byte hexadecimal string.
#[clap(long)]
tenant_id: TenantId,
#[clap(
long,
help = "Timeline id. Represented as a hexadecimal string 32 symbols length"
)]
/// Timeline ID, as a 32-byte hexadecimal string.
#[clap(long)]
timeline_id: TimelineId,
#[clap(long, help = "Branch name to give to the timeline")]
/// Branch name to give to the timeline.
#[clap(long)]
branch_name: String,
}
@@ -1049,7 +969,6 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
// User (likely interactive) did not provide a description of the environment, give them the default
NeonLocalInitConf {
control_plane_api: Some(DEFAULT_PAGESERVER_CONTROL_PLANE_API.parse().unwrap()),
auth_token_type: AuthType::NeonJWT,
broker: NeonBroker {
listen_addr: Some(DEFAULT_BROKER_ADDR.parse().unwrap()),
listen_https_addr: None,
@@ -1585,10 +1504,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
pageserver_conninfo.prefer_protocol = prefer_protocol;
let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
let auth_token = if matches!(
ps_conf.pg_auth_type,
AuthType::NeonJWT | AuthType::HadronJWT
) {
let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
let claims = Claims::new(Some(endpoint.tenant_id), Scope::Tenant);
Some(env.generate_auth_token(&claims)?)

View File

@@ -37,8 +37,18 @@
//! <other PostgreSQL files>
//! ```
//!
use std::collections::{BTreeMap, HashMap};
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
@@ -56,30 +66,20 @@ pub use compute_api::spec::{PageserverConnectionInfo, PageserverShardConnectionI
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse, RSAKeyParameters, RSAKeyType,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
};
use nix::sys::signal::{Signal, kill};
use pem::Pem;
use reqwest::header::CONTENT_TYPE;
use rsa::{RsaPublicKey, pkcs1::DecodeRsaPublicKey, traits::PublicKeyParts};
use safekeeper_api::PgMajorVersion;
use safekeeper_api::membership::SafekeeperGeneration;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use spki::der::Decode;
use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
use std::collections::{BTreeMap, HashMap};
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::debug;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use x509_parser::parse_x509_certificate;
use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT;
use postgres_connection::parse_host_port;
@@ -161,76 +161,23 @@ impl ComputeControlPlane {
.unwrap_or(self.base_port)
}
// BEGIN HADRON
/// Extract SubjectPublicKeyInfo from a PEM that can be either a X509 certificate or a public key
fn extract_spki_from_pem(pem: &Pem) -> Result<Vec<u8>> {
if pem.tag() == "CERTIFICATE" {
// Handle X509 certificate
let (_, cert) = parse_x509_certificate(pem.contents())?;
let public_key = cert.public_key();
Ok(public_key.subject_public_key.data.to_vec())
} else {
// Handle public key directly
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
Ok(spki.subject_public_key.raw_bytes().to_vec())
}
}
/// Create RSA JWK from certificate PEM
fn create_rsa_jwk_from_cert(pem: &Pem, key_hash: &[u8]) -> Result<Jwk> {
let public_key = Self::extract_spki_from_pem(pem)?;
// Extract RSA parameters (n, e) from RSA public key DER data
let rsa_key = RsaPublicKey::from_pkcs1_der(&public_key)?;
let n = rsa_key.n().to_bytes_be();
let e = rsa_key.e().to_bytes_be();
Ok(Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::RS256),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
x509_sha256_fingerprint: None::<String>,
},
algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
key_type: RSAKeyType::RSA,
n: URL_SAFE_NO_PAD.encode(n),
e: URL_SAFE_NO_PAD.encode(e),
}),
})
}
// END HADRON
/// Create a JSON Web Key Set. This ideally matches the way we create a JWKS
/// from the production control plane.
fn create_jwks_from_pem(pem: &Pem) -> Result<JwkSet> {
let public_key = Self::extract_spki_from_pem(pem)?;
let spki: SubjectPublicKeyInfoRef = SubjectPublicKeyInfo::from_der(pem.contents())?;
let public_key = spki.subject_public_key.raw_bytes();
let mut hasher = Sha256::new();
hasher.update(&public_key);
hasher.update(public_key);
let key_hash = hasher.finalize();
// BEGIN HADRON
if pem.tag() == "CERTIFICATE" {
// Assume RSA if we are parsing keys from a certificate.
let jwk = Self::create_rsa_jwk_from_cert(pem, &key_hash)?;
return Ok(JwkSet { keys: vec![jwk] });
}
// END HADRON
Ok(JwkSet {
keys: vec![Jwk {
common: CommonParameters {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::EdDSA),
key_id: Some(URL_SAFE_NO_PAD.encode(key_hash)),
key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
@@ -239,7 +186,7 @@ impl ComputeControlPlane {
algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters {
key_type: OctetKeyPairType::OctetKeyPair,
curve: EllipticCurve::Ed25519,
x: URL_SAFE_NO_PAD.encode(public_key),
x: BASE64_URL_SAFE_NO_PAD.encode(public_key),
}),
}],
})
@@ -294,7 +241,7 @@ impl ComputeControlPlane {
drop_subscriptions_before_start,
grpc,
reconfigure_concurrency: 1,
features: vec![],
features: vec![ComputeFeature::ActivityMonitorExperimental],
cluster: None,
compute_ctl_config: compute_ctl_config.clone(),
privileged_role_name: privileged_role_name.clone(),
@@ -316,7 +263,7 @@ impl ComputeControlPlane {
skip_pg_catalog_updates,
drop_subscriptions_before_start,
reconfigure_concurrency: 1,
features: vec![],
features: vec![ComputeFeature::ActivityMonitorExperimental],
cluster: None,
compute_ctl_config,
privileged_role_name,

View File

@@ -2,7 +2,6 @@ use crate::background_process::{self, start_process, stop_process};
use crate::local_env::LocalEnv;
use anyhow::{Context, Result};
use camino::Utf8PathBuf;
use postgres_backend::AuthType;
use std::io::Write;
use std::net::SocketAddr;
use std::time::Duration;
@@ -17,22 +16,15 @@ pub struct EndpointStorage {
pub data_dir: Utf8PathBuf,
pub pemfile: Utf8PathBuf,
pub addr: SocketAddr,
pub auth_type: AuthType,
}
impl EndpointStorage {
pub fn from_env(env: &LocalEnv) -> EndpointStorage {
let auth_type = match env.token_auth_type {
AuthType::HadronJWT => AuthType::HadronJWT,
AuthType::NeonJWT | AuthType::Trust => AuthType::NeonJWT,
};
EndpointStorage {
bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(),
pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(),
addr: env.endpoint_storage.listen_addr,
auth_type,
}
}
@@ -54,14 +46,12 @@ impl EndpointStorage {
pemfile: Utf8PathBuf,
local_path: Utf8PathBuf,
r#type: String,
auth_type: AuthType,
}
let cfg = Cfg {
listen: self.listen_addr(),
pemfile: parent.join(self.pemfile.clone()),
local_path: parent.join(ENDPOINT_STORAGE_REMOTE_STORAGE_DIR),
r#type: "LocalFs".to_string(),
auth_type: self.auth_type,
};
std::fs::create_dir_all(self.config_path().parent().unwrap())?;
std::fs::write(self.config_path(), serde_json::to_string(&cfg)?)

View File

@@ -18,7 +18,7 @@ use postgres_backend::AuthType;
use reqwest::{Certificate, Url};
use safekeeper_api::PgMajorVersion;
use serde::{Deserialize, Serialize};
use utils::auth::{encode_from_key_file, encode_hadron_token};
use utils::auth::encode_from_key_file;
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use crate::broker::StorageBroker;
@@ -60,9 +60,6 @@ pub struct LocalEnv {
// --tenant_id is not explicitly specified.
pub default_tenant_id: Option<TenantId>,
// The type of tokens to use for authentication in the test environment. Determines
// the type of key pairs and tokens generated in the test.
pub token_auth_type: AuthType,
// used to issue tokens during e.g pg start
pub private_key_path: PathBuf,
/// Path to environment's public key
@@ -108,7 +105,6 @@ pub struct OnDiskConfig {
pub pg_distrib_dir: PathBuf,
pub neon_distrib_dir: PathBuf,
pub default_tenant_id: Option<TenantId>,
pub token_auth_type: Option<AuthType>,
pub private_key_path: PathBuf,
pub public_key_path: PathBuf,
pub broker: NeonBroker,
@@ -157,7 +153,6 @@ pub struct NeonLocalInitConf {
pub control_plane_api: Option<Url>,
pub control_plane_hooks_api: Option<Url>,
pub generate_local_ssl_certs: bool,
pub auth_token_type: AuthType,
}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
@@ -379,7 +374,7 @@ pub struct SafekeeperConf {
pub sync: bool,
pub remote_storage: Option<String>,
pub backup_threads: Option<u32>,
pub auth_type: AuthType,
pub auth_enabled: bool,
pub listen_addr: Option<String>,
}
@@ -394,7 +389,7 @@ impl Default for SafekeeperConf {
sync: true,
remote_storage: None,
backup_threads: None,
auth_type: AuthType::Trust,
auth_enabled: false,
listen_addr: None,
}
}
@@ -668,7 +663,6 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id,
token_auth_type,
private_key_path,
public_key_path,
broker,
@@ -687,7 +681,6 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id,
token_auth_type: token_auth_type.unwrap_or(AuthType::NeonJWT),
private_key_path,
public_key_path,
broker,
@@ -803,7 +796,6 @@ impl LocalEnv {
pg_distrib_dir: self.pg_distrib_dir.clone(),
neon_distrib_dir: self.neon_distrib_dir.clone(),
default_tenant_id: self.default_tenant_id,
token_auth_type: Some(self.token_auth_type),
private_key_path: self.private_key_path.clone(),
public_key_path: self.public_key_path.clone(),
broker: self.broker.clone(),
@@ -833,18 +825,8 @@ impl LocalEnv {
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn generate_auth_token<S: Serialize>(&self, claims: &S) -> anyhow::Result<String> {
match self.token_auth_type {
AuthType::NeonJWT => {
let key_data = self.read_private_key()?;
encode_from_key_file(claims, &key_data)
}
AuthType::HadronJWT => {
let private_key_path = self.get_private_key_path();
let key_data = fs::read(private_key_path)?;
encode_hadron_token(claims, &key_data)
}
_ => panic!("unsupported token auth type {:?}", self.token_auth_type),
}
let key = self.read_private_key()?;
encode_from_key_file(claims, &key)
}
/// Get the path to the private key.
@@ -933,7 +915,6 @@ impl LocalEnv {
generate_local_ssl_certs,
control_plane_hooks_api,
endpoint_storage,
auth_token_type,
} = conf;
// Find postgres binaries.
@@ -962,7 +943,6 @@ impl LocalEnv {
generate_auth_keys(
base_path.join("auth_private_key.pem").as_path(),
base_path.join("auth_public_key.pem").as_path(),
auth_token_type,
)
.context("generate auth keys")?;
let private_key_path = PathBuf::from("auth_private_key.pem");
@@ -976,7 +956,6 @@ impl LocalEnv {
pg_distrib_dir,
neon_distrib_dir,
default_tenant_id: Some(default_tenant_id),
token_auth_type: auth_token_type,
private_key_path,
public_key_path,
broker,
@@ -1056,63 +1035,39 @@ pub fn base_path() -> PathBuf {
}
/// Generate a public/private key pair for JWT authentication
fn generate_auth_keys(
private_key_path: &Path,
public_key_path: &Path,
auth_type: AuthType,
) -> anyhow::Result<()> {
if auth_type == AuthType::NeonJWT {
// Generate the key pair
//
// openssl genpkey -algorithm ed25519 -out auth_private_key.pem
let keygen_output = Command::new("openssl")
.arg("genpkey")
.args(["-algorithm", "ed25519"])
.args(["-out", private_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
// Extract the public key from the private key file
//
// openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
let keygen_output = Command::new("openssl")
.arg("pkey")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-out", public_key_path.to_str().unwrap()])
.output()
.context("failed to extract public key from private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
} else if auth_type == AuthType::HadronJWT {
// Generate the RSA key pair. Note that the public key is embedded in an X509 certificate.
//
// openssl req -x509 -newkey rsa:4096 -keyout auth_private_key.pem -out auth_public_key.pem -nodes -subj "/CN=eng-brickstore@databricks.com"
let keygen_output = Command::new("openssl")
.arg("req")
.args(["-x509", "-newkey", "rsa:4096", "-sha256"])
.args(["-keyout", private_key_path.to_str().unwrap()])
.args(["-out", public_key_path.to_str().unwrap()])
.args(["-nodes"])
.args(["-subj", "/CN=eng-brickstore@databricks.com"])
.output()
.context("Failed to generate RSA key pair for Hadron token auth")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
fn generate_auth_keys(private_key_path: &Path, public_key_path: &Path) -> anyhow::Result<()> {
// Generate the key pair
//
// openssl genpkey -algorithm ed25519 -out auth_private_key.pem
let keygen_output = Command::new("openssl")
.arg("genpkey")
.args(["-algorithm", "ed25519"])
.args(["-out", private_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
// Extract the public key from the private key file
//
// openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
let keygen_output = Command::new("openssl")
.arg("pkey")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-out", public_key_path.to_str().unwrap()])
.output()
.context("failed to extract public key from private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
Ok(())

View File

@@ -73,7 +73,7 @@ impl PageServerNode {
{
match conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT | AuthType::HadronJWT => Some(
AuthType::NeonJWT => Some(
env.generate_auth_token(&Claims::new(None, Scope::PageServerApi))
.unwrap(),
),
@@ -117,10 +117,7 @@ impl PageServerNode {
// Storage controller uses the same auth as pageserver: if JWT is enabled
// for us, we will also need it to talk to them.
// Note: In Hadron the "control plane" is HCC. HCC does not require a token on the trusted port PS connects
// to, so we do not need to set any tokens when using HadronJWT. In the future we may consider using mTLS
// instead of JWT for HTTP auth.
if matches!(conf.http_auth_type, AuthType::NeonJWT | AuthType::HadronJWT) {
if matches!(conf.http_auth_type, AuthType::NeonJWT) {
let jwt_token = self
.env
.generate_auth_token(&Claims::new(None, Scope::GenerationsApi))
@@ -135,8 +132,7 @@ impl PageServerNode {
}
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.iter()
.any(|auth_type| *auth_type == AuthType::NeonJWT || *auth_type == AuthType::HadronJWT)
.contains(&AuthType::NeonJWT)
{
// Keys are generated in the toplevel repo dir, pageservers' workdirs
// are one level below that, so refer to keys with ../

View File

@@ -13,7 +13,6 @@ use std::{io, result};
use anyhow::Context;
use camino::Utf8PathBuf;
use postgres_backend::AuthType;
use postgres_connection::PgConnectionConfig;
use safekeeper_api::models::TimelineCreateRequest;
use safekeeper_client::mgmt_api;
@@ -111,7 +110,7 @@ impl SafekeeperNode {
}
// Generate a token file for authentication with other safekeepers
if self.conf.auth_type != AuthType::Trust {
if self.conf.auth_enabled {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::SafekeeperData))?;
@@ -157,7 +156,7 @@ impl SafekeeperNode {
"--id".to_owned(),
id_string,
"--listen-pg".to_owned(),
listen_pg.clone(),
listen_pg,
"--listen-http".to_owned(),
listen_http,
"--availability-zone".to_owned(),
@@ -187,11 +186,7 @@ impl SafekeeperNode {
}
let key_path = self.env.base_data_dir.join("auth_public_key.pem");
if self.conf.auth_type != AuthType::Trust {
args.extend([
"--token-auth-type".to_owned(),
self.conf.auth_type.to_string(),
]);
if self.conf.auth_enabled {
let key_path_string = key_path
.to_str()
.with_context(|| {
@@ -210,15 +205,6 @@ impl SafekeeperNode {
"--http-auth-public-key-path".to_owned(),
key_path_string.clone(),
]);
let token_path = self.datadir_path().join("peer_jwt_token");
let token_path_str = token_path
.to_str()
.with_context(|| {
format!("Token path {token_path:?} cannot be represented as a unicode string")
})?
.to_owned();
args.extend(["--auth-token-path".to_owned(), token_path_str]);
}
if let Some(https_port) = self.conf.https_port {
@@ -231,14 +217,26 @@ impl SafekeeperNode {
args.push(format!("--ssl-ca-file={}", ssl_ca_file.to_str().unwrap()));
}
if self.conf.auth_enabled {
let token_path = self.datadir_path().join("peer_jwt_token");
let token_path_str = token_path
.to_str()
.with_context(|| {
format!("Token path {token_path:?} cannot be represented as a unicode string")
})?
.to_owned();
args.extend(["--auth-token-path".to_owned(), token_path_str]);
}
args.extend_from_slice(extra_opts);
let env_variables = Vec::new();
background_process::start_process(
&format!("safekeeper-{id}"),
&datadir,
&self.env.safekeeper_bin(),
&args,
self.safekeeper_env_variables()?,
env_variables,
background_process::InitialPidFile::Expect(self.pid_file()),
retry_timeout,
|| async {
@@ -252,11 +250,6 @@ impl SafekeeperNode {
.await
}
fn safekeeper_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
// TODO: remove me
Ok(vec![])
}
///
/// Stop the server.
///

View File

@@ -30,14 +30,14 @@ use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tracing::instrument;
use url::Url;
use utils::auth::{Claims, Scope, encode_from_key_file, encode_hadron_token};
use utils::auth::{Claims, Scope, encode_from_key_file};
use utils::id::{NodeId, TenantId};
use whoami::username;
pub struct StorageController {
env: LocalEnv,
private_key: Option<StorageControllerPrivateKey>,
public_key: Option<StorageControllerPublicKey>,
private_key: Option<Pem>,
public_key: Option<Pem>,
client: reqwest::Client,
config: NeonStorageControllerConf,
@@ -108,25 +108,6 @@ pub struct InspectResponse {
pub attachment: Option<(u32, NodeId)>,
}
enum StorageControllerPublicKey {
RawPublicKey(Pem),
PublicKeyCertPath(Utf8PathBuf),
}
enum StorageControllerPrivateKey {
EdPrivateKey(Pem),
HadronPrivateKey(Utf8PathBuf, Vec<u8>),
}
impl StorageControllerPrivateKey {
pub fn encode_token(&self, claims: &Claims) -> anyhow::Result<String> {
match self {
Self::EdPrivateKey(key_data) => encode_from_key_file(claims, key_data),
Self::HadronPrivateKey(_, key_data) => encode_hadron_token(claims, key_data),
}
}
}
impl StorageController {
pub fn from_env(env: &LocalEnv) -> Self {
// Assume all pageservers have symmetric auth configuration: this service
@@ -171,30 +152,7 @@ impl StorageController {
)
.expect("Failed to parse PEM file")
};
(
Some(StorageControllerPrivateKey::EdPrivateKey(private_key)),
Some(StorageControllerPublicKey::RawPublicKey(public_key)),
)
}
AuthType::HadronJWT => {
let private_key_path = env.get_private_key_path();
let private_key =
fs::read(private_key_path.clone()).expect("failed to read private key");
// If pageserver auth is enabled, this implicitly enables auth for this service,
// using the same credentials.
let public_key_path =
camino::Utf8PathBuf::try_from(env.base_data_dir.join("auth_public_key.pem"))
.unwrap();
(
Some(StorageControllerPrivateKey::HadronPrivateKey(
camino::Utf8PathBuf::try_from(private_key_path).unwrap(),
private_key,
)),
Some(StorageControllerPublicKey::PublicKeyCertPath(
public_key_path,
)),
)
(Some(private_key), Some(public_key))
}
};
@@ -617,38 +575,23 @@ impl StorageController {
if let Some(private_key) = &self.private_key {
let claims = Claims::new(None, Scope::PageServerApi);
if let StorageControllerPrivateKey::HadronPrivateKey(key_path, _) = private_key {
args.push(format!("--private-key-path={key_path}"));
}
// We are setting all JWT tokens for Hadron as well in this test to avoid bifurcation between Neon and
// Hadron test cases. In production we do not need to set this as HTTP auth is not enabled on the
// pageserver. We use network segmentation to ensure that only trusted components can talk to
// pageserver's http port
let jwt_token = private_key.encode_token(&claims)?;
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
args.push(format!("--jwt-token={jwt_token}"));
let peer_claims = Claims::new(None, Scope::Admin);
let peer_jwt_token = private_key
.encode_token(&peer_claims)
let peer_jwt_token = encode_from_key_file(&peer_claims, private_key)
.expect("failed to generate jwt token");
args.push(format!("--peer-jwt-token={peer_jwt_token}"));
let claims = Claims::new(None, Scope::SafekeeperData);
let jwt_token = private_key
.encode_token(&claims)
.expect("failed to generate jwt token");
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
args.push(format!("--safekeeper-jwt-token={jwt_token}"));
}
if let Some(public_key) = &self.public_key {
match public_key {
StorageControllerPublicKey::RawPublicKey(public_key) => {
args.push(format!("--public-key=\"{public_key}\""));
}
StorageControllerPublicKey::PublicKeyCertPath(public_key_path) => {
args.push(format!("--public-key-cert-path={public_key_path}"));
}
}
args.push(format!("--public-key=\"{public_key}\""));
}
if let Some(control_plane_hooks_api) = &self.env.control_plane_hooks_api {
@@ -689,13 +632,7 @@ impl StorageController {
self.env.base_data_dir.display()
));
if self
.env
.safekeepers
.iter()
.any(|sk| sk.auth_type != AuthType::Trust)
&& self.private_key.is_none()
{
if self.env.safekeepers.iter().any(|sk| sk.auth_enabled) && self.private_key.is_none() {
anyhow::bail!("Safekeeper set up for auth but no private key specified");
}
@@ -910,7 +847,7 @@ impl StorageController {
println!("Getting claims for path {path}");
if let Some(required_claims) = Self::get_claims_for_path(&path)? {
println!("Got claims {required_claims:?} for path {path}");
let jwt_token = private_key.encode_token(&required_claims)?;
let jwt_token = encode_from_key_file(&required_claims, private_key)?;
builder = builder.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {jwt_token}"),

View File

@@ -303,6 +303,13 @@ enum Command {
#[arg(long, required = true, value_delimiter = ',')]
new_sk_set: Vec<NodeId>,
},
/// Abort ongoing safekeeper migration.
TimelineSafekeeperMigrateAbort {
#[arg(long)]
tenant_id: TenantId,
#[arg(long)]
timeline_id: TimelineId,
},
}
#[derive(Parser)]
@@ -1396,6 +1403,17 @@ async fn main() -> anyhow::Result<()> {
)
.await?;
}
Command::TimelineSafekeeperMigrateAbort {
tenant_id,
timeline_id,
} => {
let path =
format!("v1/tenant/{tenant_id}/timeline/{timeline_id}/safekeeper_migrate_abort");
storcon_client
.dispatch::<(), ()>(Method::POST, path, None)
.await?;
}
}
Ok(())

View File

@@ -46,7 +46,6 @@ allow = [
"ISC",
"MIT",
"MPL-2.0",
"OpenSSL",
"Unicode-3.0",
]
confidence-threshold = 0.8

View File

@@ -20,7 +20,6 @@ tokio.workspace = true
tracing.workspace = true
utils = { path = "../libs/utils", default-features = false }
workspace_hack.workspace = true
postgres_backend.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true
http-body-util.workspace = true

View File

@@ -206,16 +206,12 @@ mod tests {
use axum::{body::Body, extract::Request, response::Response};
use http_body_util::BodyExt;
use itertools::iproduct;
use jsonwebtoken::DecodingKey;
use std::env::var;
use std::sync::Arc;
use std::time::Duration;
use test_log::test as testlog;
use tower::{Service, util::ServiceExt};
use utils::{
auth::JwtAuth,
id::{TenantId, TimelineId},
};
use utils::id::{TenantId, TimelineId};
// see libs/remote_storage/tests/test_real_s3.rs
const REAL_S3_ENV: &str = "ENABLE_REAL_S3_REMOTE_STORAGE";
@@ -255,9 +251,7 @@ mod tests {
};
let proxy = Storage {
auth: JwtAuth::new(vec![
DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap(),
]),
auth: endpoint_storage::JwtAuth::new(TEST_PUB_KEY_ED25519).unwrap(),
storage,
cancel: cancel.clone(),
max_upload_file_limit: usize::MAX,
@@ -358,7 +352,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA);
let header = jsonwebtoken::Header::new(endpoint_storage::VALIDATION_ALGO);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}
@@ -507,7 +501,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::EdDSA);
let header = jsonwebtoken::Header::new(endpoint_storage::VALIDATION_ALGO);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}

View File

@@ -7,6 +7,7 @@ use axum::{RequestPartsExt, http::StatusCode, http::request::Parts};
use axum_extra::TypedHeader;
use axum_extra::headers::{Authorization, authorization::Bearer};
use camino::Utf8PathBuf;
use jsonwebtoken::{DecodingKey, Validation};
use remote_storage::{GenericRemoteStorage, RemotePath};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
@@ -14,9 +15,28 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use utils::auth::JwtAuth;
use utils::id::{EndpointId, TenantId, TimelineId};
// simplified version of utils::auth::JwtAuth
pub struct JwtAuth {
decoding_key: DecodingKey,
validation: Validation,
}
pub const VALIDATION_ALGO: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::EdDSA;
impl JwtAuth {
pub fn new(key: &[u8]) -> Result<Self> {
Ok(Self {
decoding_key: DecodingKey::from_ed_pem(key)?,
validation: Validation::new(VALIDATION_ALGO),
})
}
pub fn decode<T: serde::de::DeserializeOwned>(&self, token: &str) -> Result<T> {
Ok(jsonwebtoken::decode(token, &self.decoding_key, &self.validation).map(|t| t.claims)?)
}
}
fn normalize_key(key: &str) -> StdResult<Utf8PathBuf, String> {
let key = clean_utf8(&Utf8PathBuf::from(key));
if key.starts_with("..") || key == "." || key == "/" {
@@ -137,8 +157,7 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
let claims: EndpointStorageClaims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "decoding token"))?
.claims;
.map_err(|e| bad_request(e, "decoding token"))?;
// Read paths may have different endpoint ids. For readonly -> readwrite replica
// prewarming, endpoint must read other endpoint's data.
@@ -205,8 +224,7 @@ impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
let claims: DeletePrefixClaims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "invalid token"))?
.claims;
.map_err(|e| bad_request(e, "invalid token"))?;
let route = DeletePrefixClaims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,

View File

@@ -5,10 +5,8 @@
mod app;
use anyhow::Context;
use clap::Parser;
use postgres_backend::AuthType;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::auth::JwtAuth;
use utils::logging;
//see set()
@@ -20,10 +18,6 @@ const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
const fn default_auth_type() -> AuthType {
AuthType::NeonJWT
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
@@ -45,8 +39,6 @@ struct Config {
storage_kind: remote_storage::TypedRemoteStorageKind,
#[serde(default = "max_upload_file_limit")]
max_upload_file_limit: usize,
#[serde(default = "default_auth_type")]
auth_type: AuthType,
}
#[tokio::main]
@@ -69,15 +61,10 @@ async fn main() -> anyhow::Result<()> {
anyhow::bail!("Supply either config file path or --config=inline-config");
};
if config.auth_type == AuthType::Trust {
anyhow::bail!("Trust based auth is not supported");
}
let auth = match config.auth_type {
AuthType::NeonJWT => JwtAuth::from_key_path(&config.pemfile)?,
AuthType::HadronJWT => JwtAuth::from_cert_path(&config.pemfile)?,
AuthType::Trust => unreachable!(),
};
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
let auth = endpoint_storage::JwtAuth::new(&pemfile)?;
let listener = tokio::net::TcpListener::bind(config.listen).await.unwrap();
info!("listening on {}", listener.local_addr().unwrap());

View File

@@ -1,10 +1,9 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
use std::fmt::Display;
use chrono::{DateTime, Utc};
use jsonwebtoken::jwk::JwkSet;
use serde::{Deserialize, Serialize, Serializer};
use std::fmt::Display;
use crate::privilege::Privilege;
use crate::spec::{ComputeSpec, Database, ExtVersion, PgIdent, Role};
@@ -49,7 +48,7 @@ pub struct ExtensionInstallResponse {
/// Status of the LFC prewarm process. The same state machine is reused for
/// both autoprewarm (prewarm after compute/Postgres start using the previously
/// stored LFC state) and explicit prewarming via API.
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
#[derive(Serialize, Default, Debug, Clone)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcPrewarmState {
/// Default value when compute boots up.
@@ -59,7 +58,14 @@ pub enum LfcPrewarmState {
Prewarming,
/// We found requested LFC state in the endpoint storage and
/// completed prewarming successfully.
Completed,
Completed {
total: i32,
prewarmed: i32,
skipped: i32,
state_download_time_ms: u32,
uncompress_time_ms: u32,
prewarm_time_ms: u32,
},
/// Unexpected error happened during prewarming. Note, `Not Found 404`
/// response from the endpoint storage is explicitly excluded here
/// because it can normally happen on the first compute start,
@@ -84,7 +90,7 @@ impl Display for LfcPrewarmState {
match self {
LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"),
LfcPrewarmState::Prewarming => f.write_str("Prewarming"),
LfcPrewarmState::Completed => f.write_str("Completed"),
LfcPrewarmState::Completed { .. } => f.write_str("Completed"),
LfcPrewarmState::Skipped => f.write_str("Skipped"),
LfcPrewarmState::Failed { error } => write!(f, "Error({error})"),
LfcPrewarmState::Cancelled => f.write_str("Cancelled"),
@@ -92,26 +98,36 @@ impl Display for LfcPrewarmState {
}
}
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
#[derive(Serialize, Default, Debug, Clone)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcOffloadState {
#[default]
NotOffloaded,
Offloading,
Completed,
Completed {
state_query_time_ms: u32,
compress_time_ms: u32,
state_upload_time_ms: u32,
},
Failed {
error: String,
},
/// LFC state was empty so it wasn't offloaded
Skipped,
}
#[derive(Serialize, Debug, Clone, PartialEq)]
#[derive(Serialize, Debug, Clone)]
#[serde(tag = "status", rename_all = "snake_case")]
/// Response of /promote
pub enum PromoteState {
NotPromoted,
Completed,
Failed { error: String },
Completed {
lsn_wait_time_ms: u32,
pg_promote_time_ms: u32,
reconfigure_time_ms: u32,
},
Failed {
error: String,
},
}
#[derive(Deserialize, Default, Debug)]

View File

@@ -705,10 +705,8 @@ pub fn check_permission_with(
check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
) -> Result<(), ApiError> {
match req.context::<Claims>() {
Some(claims) => Ok(check_permission(&claims).map_err(|err| {
tracing::info!("Authorization error: {err}");
ApiError::Forbidden("JWT authentication error".to_string())
})?),
Some(claims) => Ok(check_permission(&claims)
.map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
None => Ok(()), // claims is None because auth is disabled
}
}

View File

@@ -194,10 +194,6 @@ pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
// Similar to above but uses Hadron JWT. Hadron JWTs are slightly different in that:
// 1. Decoding keys are loaded from PEM-encoded X509 certificates instead of plain key files.
// 2. Signature algorithm is RSA-based (may change in the future).
HadronJWT,
}
impl FromStr for AuthType {
@@ -207,7 +203,6 @@ impl FromStr for AuthType {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
"HadronJWT" => Ok(Self::HadronJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
@@ -218,7 +213,6 @@ impl fmt::Display for AuthType {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
AuthType::HadronJWT => "HadronJWT",
})
}
}
@@ -619,10 +613,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(matches!(
self.auth_type,
AuthType::NeonJWT | AuthType::HadronJWT
));
assert!(self.auth_type == AuthType::NeonJWT);
let (_, jwt_response) = m.split_last().context("protocol violation")?;
@@ -721,7 +712,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT | AuthType::HadronJWT => {
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;

View File

@@ -19,7 +19,6 @@ anyhow.workspace = true
bincode.workspace = true
bytes.workspace = true
camino.workspace = true
camino-tempfile.workspace = true
chrono.workspace = true
diatomic-waker.workspace = true
git-version.workspace = true
@@ -29,7 +28,6 @@ fail.workspace = true
futures = { workspace = true }
jsonwebtoken.workspace = true
nix = { workspace = true, features = ["ioctl"] }
oid-registry.workspace = true
once_cell.workspace = true
pem.workspace = true
pin-project-lite.workspace = true
@@ -50,12 +48,9 @@ tracing-utils.workspace = true
rand.workspace = true
scopeguard.workspace = true
uuid.workspace = true
rustls-pemfile.workspace = true
rustls-pki-types.workspace = true
strum.workspace = true
strum_macros.workspace = true
walkdir.workspace = true
x509-parser.workspace = true
pq_proto.workspace = true
postgres_connection.workspace = true
@@ -72,7 +67,6 @@ camino-tempfile.workspace = true
pprof.workspace = true
serde_assert.workspace = true
tokio = { workspace = true, features = ["test-util"] }
rcgen = { version = "=0.13.1", features = ["crypto", "aws_lc_rs"] }
[[bench]]
name = "benchmarks"

View File

@@ -1,9 +1,9 @@
// For details about authentication see docs/authentication.md
use std::borrow::Cow;
use std::fmt::Display;
use std::fs;
use std::sync::Arc;
use std::{borrow::Cow, io, path::Path};
use anyhow::Result;
use arc_swap::ArcSwap;
@@ -11,17 +11,14 @@ use camino::Utf8Path;
use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
};
use oid_registry::OID_PKCS1_RSAENCRYPTION;
use pem::Pem;
use rustls_pki_types::CertificateDer;
use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
use uuid::Uuid;
use crate::id::TenantId;
/// Signature algorithms to use. We allow EdDSA and RSA/SHA-256.
/// Algorithm to use. We require EdDSA.
const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
const HADRON_STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
#[serde(rename_all = "lowercase")]
@@ -98,14 +95,6 @@ impl Claims {
endpoint_id: None,
}
}
pub fn new_for_endpoint(endpoint_id: Uuid) -> Self {
Self {
tenant_id: None,
endpoint_id: Some(endpoint_id),
scope: Scope::TenantEndpoint,
}
}
}
pub struct SwappableJwtAuth(ArcSwap<JwtAuth>);
@@ -186,96 +175,6 @@ impl JwtAuth {
Ok(Self::new(decoding_keys))
}
// Helper function to parse a X509 certificate file and extract the RSA public keys from it as `DecodingKey`s.
// - `ceritificate_file_path`: the path to the certificate file. It must be a file, not a directory or anything else.
// Returns the successfully extracted decoding keys. Non-RSA keys and non-X509-parsable certificates are skipped.
// Multuple keys may be returned because a single file can contain multiple certificates.
fn extract_rsa_decoding_keys_from_certificate<P: AsRef<Path>>(
certificate_file_path: P,
) -> Result<Vec<DecodingKey>> {
let certs: io::Result<Vec<CertificateDer<'static>>> = rustls_pemfile::certs(
&mut io::BufReader::new(fs::File::open(certificate_file_path)?),
)
.collect();
Ok(certs?
.iter()
.filter_map(
|cert| match x509_parser::parse_x509_certificate(cert) {
Ok((_, cert)) => {
let public_key = cert.public_key();
// Note that we are just extracting the public key from the certificate, not the signature.
// So the algorithm is just the asymmetric crypto such as RSA, no hashes of or anything like
// that.
if *public_key.algorithm.oid() == OID_PKCS1_RSAENCRYPTION {
Some(DecodingKey::from_rsa_der(&public_key.subject_public_key.data))
} else {
tracing::warn!(
"Unsupported public key algorithm: {:?} found in certificate. Skipping.",
public_key.algorithm
);
None
}
}
Err(e) => {
tracing::warn!("Error parsing certificate: {}. Skipping.", e);
None
}
},
)
.collect())
}
/// Create a `JwtAuth` that can decode tokens using RSA public keys in X509 certificates from the given path.
/// - `cert_path`: the path to a directory or a file containing X509 certificates. If it is a directory, all files
/// under the first level of the directory will be inspected for certificates.
/// Returns the `JwtAuth` with the decoding keys extracted from the certificates, or error.
/// Used by Hadron.
pub fn from_cert_path(cert_path: &Utf8Path) -> Result<Self> {
tracing::info!(
"Loading public keys in certificates from path: {}",
cert_path
);
let mut decoding_keys = Vec::new();
let metadata = cert_path.metadata()?;
if metadata.is_dir() {
for entry in fs::read_dir(cert_path)? {
let path = entry?.path();
if !path.is_file() {
// Ignore directories (don't recurse)
continue;
}
decoding_keys.extend(
Self::extract_rsa_decoding_keys_from_certificate(path).unwrap_or_default(),
);
}
} else if metadata.is_file() {
decoding_keys.extend(
Self::extract_rsa_decoding_keys_from_certificate(cert_path).unwrap_or_default(),
);
} else {
anyhow::bail!("{cert_path} is neither a directory or a file")
}
if decoding_keys.is_empty() {
anyhow::bail!(
"Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected."
);
}
// Note that we need to create a `JwtAuth` with a different `validation` from the default one created by `new()` in this case
// because the `jsonwebtoken` crate requires that all algorithms in `validation.algorithms` belong to the same algorithm family
// (all RSA or all EdDSA).
let mut validation = Validation::default();
validation.algorithms = vec![HADRON_STORAGE_TOKEN_ALGORITHM];
validation.required_spec_claims = [].into();
Ok(Self {
validation,
decoding_keys,
})
}
pub fn from_key(key: String) -> Result<Self> {
Ok(Self::new(vec![DecodingKey::from_ed_pem(key.as_bytes())?]))
}
@@ -318,28 +217,8 @@ pub fn encode_from_key_file<S: Serialize>(claims: &S, pem: &Pem) -> Result<Strin
Ok(encode(&Header::new(STORAGE_TOKEN_ALGORITHM), claims, &key)?)
}
/// Encode (i.e., sign) a Hadron auth token with the given claims and RSA private key. This is used
/// by HCC to sign tokens when deploying compute or returning the compute spec. The resulting token
/// is used by the compute node to authenticate with HCC and PS/SK.
pub fn encode_hadron_token<S: Serialize>(claims: &S, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_rsa_pem(key_data)?;
encode_hadron_token_with_encoding_key(claims, &key)
}
pub fn encode_hadron_token_with_encoding_key<S: Serialize>(
claims: &S,
encoding_key: &EncodingKey,
) -> Result<String> {
Ok(encode(
&Header::new(HADRON_STORAGE_TOKEN_ALGORITHM),
claims,
encoding_key,
)?)
}
#[cfg(test)]
mod tests {
use io::Write;
use std::str::FromStr;
use super::*;
@@ -364,8 +243,8 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
fn test_decode() {
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
endpoint_id: None,
scope: Scope::Tenant,
endpoint_id: None,
};
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
@@ -393,8 +272,8 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
fn test_encode() {
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
endpoint_id: None,
scope: Scope::Tenant,
endpoint_id: None,
};
let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
@@ -408,72 +287,4 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
assert_eq!(decoded.claims, claims);
}
#[test]
fn test_decode_with_key_from_certificate() {
// Tests that we can sign (encode) a token with a RSA private key and verify (decode) it with the
// corresponding public key extracted from a certificate.
// Generate two RSA key pairs and create self-signed certificates with it.
let key_pair_1 = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap();
let key_pair_2 = rcgen::KeyPair::generate_for(&rcgen::PKCS_RSA_SHA256).unwrap();
let mut params = rcgen::CertificateParams::default();
params
.distinguished_name
.push(rcgen::DnType::CommonName, "eng-brickstore@databricks.com");
let cert_1 = params.clone().self_signed(&key_pair_1).unwrap();
let cert_2 = params.self_signed(&key_pair_2).unwrap();
// Write the certificates and keys to a temporary dir.
let dir = camino_tempfile::tempdir().unwrap();
{
fs::File::create(dir.path().join("cert_1.pem"))
.unwrap()
.write_all(cert_1.pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("key_1.pem"))
.unwrap()
.write_all(key_pair_1.serialize_pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("cert_2.pem"))
.unwrap()
.write_all(cert_2.pem().as_bytes())
.unwrap();
fs::File::create(dir.path().join("key_2.pem"))
.unwrap()
.write_all(key_pair_2.serialize_pem().as_bytes())
.unwrap();
}
// Instantiate a `JwtAuth` with the certificate path. The resulting `JwtAuth` should extract the RSA public
// keys out of the X509 certificates and use them as the decoding keys. Since we specified a directory, both
// X509 certificates will be loaded, but the private key files are skipped.
let auth = JwtAuth::from_cert_path(dir.path()).unwrap();
assert_eq!(auth.decoding_keys.len(), 2);
// Also create a `JwtAuth`, specifying a single certificate file for it to get the decoding key from.
let auth_cert_1 = JwtAuth::from_cert_path(&dir.path().join("cert_1.pem")).unwrap();
assert_eq!(auth_cert_1.decoding_keys.len(), 1);
// Encode tokens with some claims.
let claims = Claims {
tenant_id: Some(TenantId::generate()),
endpoint_id: None,
scope: Scope::Tenant,
};
let encoded_1 =
encode_hadron_token(&claims, key_pair_1.serialize_pem().as_bytes()).unwrap();
let encoded_2 =
encode_hadron_token(&claims, key_pair_2.serialize_pem().as_bytes()).unwrap();
// Verify that we can decode the token with matching decoding keys (decoding also verifies the signature).
assert_eq!(auth.decode::<Claims>(&encoded_1).unwrap().claims, claims);
assert_eq!(auth.decode::<Claims>(&encoded_2).unwrap().claims, claims);
assert_eq!(
auth_cert_1.decode::<Claims>(&encoded_1).unwrap().claims,
claims
);
// Verify that the token cannot be decoded with a mismatched decode key.
assert!(auth_cert_1.decode::<Claims>(&encoded_2).is_err());
}
}

View File

@@ -458,37 +458,25 @@ fn start_pageserver(
let http_auth;
let pg_auth;
let grpc_auth;
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.iter()
.any(|auth_type| *auth_type == AuthType::NeonJWT || *auth_type == AuthType::HadronJWT)
{
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type].contains(&AuthType::NeonJWT) {
// unwrap is ok because check is performed when creating config, so path is set and exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
info!("Loading public key(s) for verifying JWT tokens from {key_path:?}");
let use_hadron_jwt = conf.http_auth_type == AuthType::HadronJWT
|| conf.pg_auth_type == AuthType::HadronJWT
|| conf.grpc_auth_type == AuthType::HadronJWT;
let jwt_auth = if use_hadron_jwt {
// To validate Hadron JWTs we need to extract decoding keys from X509 certificates.
JwtAuth::from_cert_path(key_path)?
} else {
JwtAuth::from_key_path(key_path)?
};
let jwt_auth = JwtAuth::from_key_path(key_path)?;
let auth: Arc<SwappableJwtAuth> = Arc::new(SwappableJwtAuth::new(jwt_auth));
http_auth = match conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth.clone()),
AuthType::NeonJWT => Some(auth.clone()),
};
pg_auth = match conf.pg_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth.clone()),
AuthType::NeonJWT => Some(auth.clone()),
};
grpc_auth = match conf.grpc_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT | AuthType::HadronJWT => Some(auth),
AuthType::NeonJWT => Some(auth),
};
} else {
http_auth = None;

View File

@@ -629,13 +629,6 @@ impl PageServerConf {
}
};
let auth_types = [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type];
if auth_types.contains(&AuthType::NeonJWT) && auth_types.contains(&AuthType::HadronJWT) {
return Err(anyhow::anyhow!(
"Mixing neon and hadron style JWT tokens is not supported"
));
}
Ok(conf)
}

View File

@@ -44,7 +44,6 @@ use pageserver_api::models::{
TopTenantShardItem, TopTenantShardsRequest, TopTenantShardsResponse,
};
use pageserver_api::shard::{ShardCount, TenantShardId};
use postgres_backend::AuthType;
use postgres_ffi::PgMajorVersion;
use remote_storage::{DownloadError, GenericRemoteStorage, TimeTravelError};
use scopeguard::defer;
@@ -56,7 +55,6 @@ use tokio::time::Instant;
use tokio_util::io::StreamReader;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::auth::JwtAuth;
use utils::auth::SwappableJwtAuth;
use utils::generation::Generation;
use utils::id::{TenantId, TimelineId};
@@ -562,10 +560,6 @@ async fn reload_auth_validation_keys_handler(
request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
// Note to Bricksters: This API returns 400 if HTTP auth is not enabled. This is because `state.auth` is only
// determined by HTTP auth.
// TODO(william.huang): In practice both HTTP and PG auth point to the same SwappableJwtAuth object. Refactor
// this code so that we can swap out the underlying shared auth object even if HTTP auth is None.
check_permission(&request, None)?;
let config = get_config(&request);
let state = get_state(&request);
@@ -576,12 +570,7 @@ async fn reload_auth_validation_keys_handler(
let key_path = config.auth_validation_public_key_path.as_ref().unwrap();
info!("Reloading public key(s) for verifying JWT tokens from {key_path:?}");
let new_jwt_auth = if config.http_auth_type == AuthType::HadronJWT {
JwtAuth::from_cert_path(key_path)
} else {
JwtAuth::from_key_path(key_path)
};
match new_jwt_auth {
match utils::auth::JwtAuth::from_key_path(key_path) {
Ok(new_auth) => {
shared_auth.swap(new_auth);
json_response(StatusCode::OK, ())

View File

@@ -15,7 +15,6 @@ use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use http_utils::tls_certs::ReloadingCertificateResolver;
use metrics::set_build_info_metric;
use postgres_backend::AuthType;
use remote_storage::RemoteStorageConfig;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT,
@@ -110,15 +109,10 @@ struct Args {
/// Listen https endpoint for management and metrics in the form host:port.
#[arg(long, default_value = None)]
listen_https: Option<String>,
/// Advertised endpoint to PS for receiving/sending WAL in the form host:port. If not
/// Advertised endpoint for receiving/sending WAL in the form host:port. If not
/// specified, listen_pg is used to advertise instead.
#[arg(long, default_value = None)]
advertise_pg: Option<String>,
/// Advertised endpoint to compute for receiving/sending WAL in the form host:port.
/// Required if --hcc-base-url is specified.
// TODO(vlad): pull in hcc-base-url too
#[arg(long, default_value = None)]
advertise_pg_tenant_only: Option<String>,
/// Availability zone of the safekeeper.
#[arg(long)]
availability_zone: Option<String>,
@@ -170,12 +164,6 @@ struct Args {
/// WAL backup horizon.
#[arg(long)]
disable_wal_backup: bool,
/// Token authentication type. Allowed values are "NeonJWT" and "HadronJWT". Any specified value only takes effect if
/// --pg-auth-public-key-path, --pg-tenant-only-auth-public-key-path, or --http-auth-public-key-path is specified.
/// NeonJWT: Decoding keys are loaded from plain public key files in the specified key path.
/// HadronJWT: Decoding keys are loaded from X509 certificates in the specified key path.
#[arg(long, verbatim_doc_comment, default_value = "NeonJWT")]
token_auth_type: AuthType,
/// If given, enables auth on incoming connections to WAL service endpoint
/// (--listen-pg). Value specifies path to a .pem public key used for
/// validations of JWT tokens. Empty string is allowed and means disabling
@@ -373,19 +361,9 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading pg auth JWT key from {path}");
match args.token_auth_type {
AuthType::NeonJWT => Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
)),
AuthType::HadronJWT => Some(Arc::new(
JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
)),
_ => panic!(
"AuthType {auth_type} is not allowed when --pg-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
}
Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
))
}
};
let pg_tenant_only_auth = match args.pg_tenant_only_auth_public_key_path.as_ref() {
@@ -395,19 +373,9 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading pg tenant only auth JWT key from {path}");
match args.token_auth_type {
AuthType::NeonJWT => Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
)),
AuthType::HadronJWT => Some(Arc::new(
JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
)),
_ => panic!(
"AuthType {auth_type} is not allowed when --pg-tenant-only-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
}
Some(Arc::new(
JwtAuth::from_key_path(path).context("failed to load the auth key")?,
))
}
};
let http_auth = match args.http_auth_public_key_path.as_ref() {
@@ -417,17 +385,7 @@ async fn main() -> anyhow::Result<()> {
}
Some(path) => {
info!("loading http auth JWT key(s) from {path}");
let jwt_auth = match args.token_auth_type {
AuthType::NeonJWT => {
JwtAuth::from_key_path(path).context("failed to load the auth key")?
}
AuthType::HadronJWT => JwtAuth::from_cert_path(path)
.context("failed to load auth keys from certificates")?,
_ => panic!(
"AuthType {auth_type} is not allowed when --http-auth-public-key-path is specified",
auth_type = args.token_auth_type
),
};
let jwt_auth = JwtAuth::from_key_path(path).context("failed to load the auth key")?;
Some(Arc::new(SwappableJwtAuth::new(jwt_auth)))
}
};
@@ -476,7 +434,6 @@ async fn main() -> anyhow::Result<()> {
/* END_HADRON */
wal_backup_enabled: !args.disable_wal_backup,
backup_parallel_jobs: args.wal_backup_parallel_jobs,
auth_type: args.token_auth_type,
pg_auth,
pg_tenant_only_auth,
http_auth,
@@ -500,7 +457,7 @@ async fn main() -> anyhow::Result<()> {
enable_tls_wal_service_api: args.enable_tls_wal_service_api,
force_metric_collection_on_scrape: args.force_metric_collection_on_scrape,
/* BEGIN_HADRON */
advertise_pg_addr_tenant_only: args.advertise_pg_tenant_only,
advertise_pg_addr_tenant_only: None,
enable_pull_timeline_on_startup: args.enable_pull_timeline_on_startup,
hcc_base_url: None,
global_disk_check_interval: args.global_disk_check_interval,

View File

@@ -1,7 +1,6 @@
#![deny(clippy::undocumented_unsafe_blocks)]
extern crate hyper0 as hyper;
use postgres_backend::AuthType;
use std::time::Duration;
@@ -129,7 +128,6 @@ pub struct SafeKeeperConf {
/* END_HADRON */
pub backup_parallel_jobs: usize,
pub wal_backup_enabled: bool,
pub auth_type: AuthType,
pub pg_auth: Option<Arc<JwtAuth>>,
pub pg_tenant_only_auth: Option<Arc<JwtAuth>>,
pub http_auth: Option<Arc<SwappableJwtAuth>>,
@@ -175,7 +173,6 @@ impl SafeKeeperConf {
peer_recovery_enabled: true,
wal_backup_enabled: true,
backup_parallel_jobs: 1,
auth_type: AuthType::HadronJWT,
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,

View File

@@ -103,7 +103,7 @@ async fn handle_socket(
};
let auth_type = match auth_key {
None => AuthType::Trust,
Some(_) => conf.auth_type,
Some(_) => AuthType::NeonJWT,
};
let auth_pair = auth_key.map(|key| (allowed_auth_scope, key));
let mut conn_handler = SafekeeperPostgresHandler::new(

View File

@@ -14,7 +14,6 @@ use desim::network::TCP;
use desim::node_os::NodeOs;
use desim::proto::{AnyMessage, NetEvent, NodeEvent};
use http::Uri;
use postgres_backend::AuthType;
use safekeeper::SafeKeeperConf;
use safekeeper::safekeeper::{
ProposerAcceptorMessage, SK_PROTO_VERSION_3, SafeKeeper, UNKNOWN_SERVER_VERSION,
@@ -170,7 +169,6 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
availability_zone: None,
peer_recovery_enabled: false,
backup_parallel_jobs: 0,
auth_type: AuthType::NeonJWT,
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,

View File

@@ -31,7 +31,6 @@ humantime.workspace = true
humantime-serde.workspace = true
itertools.workspace = true
json-structural-diff.workspace = true
jsonwebtoken.workspace = true
lasso.workspace = true
once_cell.workspace = true
pageserver_api.workspace = true
@@ -75,4 +74,4 @@ http-utils = { path = "../libs/http-utils/" }
utils = { path = "../libs/utils/" }
metrics = { path = "../libs/metrics/" }
control_plane = { path = "../control_plane" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -9,6 +9,7 @@ pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), Au
Ok(())
}
#[allow(dead_code)]
pub fn check_endpoint_permission(claims: &Claims, endpoint_id: Uuid) -> Result<(), AuthError> {
if claims.scope != Scope::TenantEndpoint {
return Err(AuthError("Scope mismatch. Permission denied".into()));

View File

@@ -1,52 +0,0 @@
use anyhow::{Result, bail};
use camino::Utf8Path;
use jsonwebtoken::EncodingKey;
use std::fs;
use utils::{
auth::{Claims, Scope, encode_hadron_token_with_encoding_key},
id::TenantId,
};
use uuid::Uuid;
pub struct HadronTokenGenerator {
encoding_key: EncodingKey,
}
impl HadronTokenGenerator {
pub fn new(path: &Utf8Path) -> anyhow::Result<Self> {
let key_data = match fs::read(path) {
Ok(ok) => ok,
Err(e) => bail!("Error reading private key file {path:?}. Error: {e}"),
};
let encoding_key = match EncodingKey::from_rsa_pem(&key_data) {
Ok(ok) => ok,
Err(e) => {
bail!("Error reading private key file {path:?} as RSA private key. Error: {e}")
}
};
Ok(Self { encoding_key })
}
pub fn generate_tenant_scope_token(&self, tenant_id: TenantId) -> Result<String> {
let claims = Claims::new(Some(tenant_id), Scope::Tenant);
self.internal_encode_token(&claims)
}
pub fn generate_tenant_endpoint_scope_token(&self, endpoint_id: Uuid) -> Result<String> {
let claims = Claims::new_for_endpoint(endpoint_id);
self.internal_encode_token(&claims)
}
pub fn generate_ps_sk_auth_token(&self) -> Result<String> {
let claims = Claims {
tenant_id: None,
endpoint_id: None,
scope: Scope::SafekeeperData,
};
self.internal_encode_token(&claims)
}
fn internal_encode_token(&self, claims: &Claims) -> Result<String> {
encode_hadron_token_with_encoding_key(claims, &self.encoding_key)
}
}

View File

@@ -40,7 +40,6 @@ use tokio_util::sync::CancellationToken;
use tracing::warn;
use utils::auth::{Scope, SwappableJwtAuth};
use utils::id::{NodeId, TenantId, TimelineId};
use uuid::Uuid;
use crate::http;
use crate::metrics::{
@@ -645,6 +644,7 @@ async fn handle_tenant_timeline_safekeeper_migrate(
req: Request<Body>,
) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
// TODO(diko): it's not PS operation, there should be a different permission scope.
check_permissions(&req, Scope::PageServerApi)?;
maybe_rate_limit(&req, tenant_id).await;
@@ -666,6 +666,23 @@ async fn handle_tenant_timeline_safekeeper_migrate(
json_response(StatusCode::OK, ())
}
async fn handle_tenant_timeline_safekeeper_migrate_abort(
service: Arc<Service>,
req: Request<Body>,
) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let timeline_id: TimelineId = parse_request_param(&req, "timeline_id")?;
// TODO(diko): it's not PS operation, there should be a different permission scope.
check_permissions(&req, Scope::PageServerApi)?;
maybe_rate_limit(&req, tenant_id).await;
service
.tenant_timeline_safekeeper_migrate_abort(tenant_id, timeline_id)
.await?;
json_response(StatusCode::OK, ())
}
async fn handle_tenant_timeline_lsn_lease(
service: Arc<Service>,
req: Request<Body>,
@@ -1802,23 +1819,6 @@ fn check_permissions(request: &Request<Body>, required_scope: Scope) -> Result<(
}
})
}
/// Similar to `check_permissions()` above, but checks for TenantEndpoint scope specifically. Used by the compute spec-fetch API.
/// Access by Admin-scope tokens is also permitted.
/// TODO(william.huang): Merge with the previous function by refactoring `Scope` to make it carry the dependent arguments.
/// E.g., `Scope::TenantEndpoint(EndpointId)`, `Scope::Tenant(TenantId)`, etc.
#[allow(unused)]
fn check_endpoint_permission(request: &Request<Body>, endpoint_id: Uuid) -> Result<(), ApiError> {
check_permission_with(
request,
|claims| match crate::auth::check_endpoint_permission(claims, endpoint_id) {
Err(e) => match crate::auth::check_permission(claims, Scope::Admin) {
Ok(()) => Ok(()),
Err(_) => Err(e),
},
Ok(()) => Ok(()),
},
)
}
#[derive(Clone, Debug)]
struct RequestMeta {
@@ -2629,6 +2629,16 @@ pub fn make_router(
)
},
)
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/safekeeper_migrate_abort",
|r| {
tenant_service_handler(
r,
handle_tenant_timeline_safekeeper_migrate_abort,
RequestName("v1_tenant_timeline_safekeeper_migrate_abort"),
)
},
)
// LSN lease passthrough to all shards
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/lsn_lease",

View File

@@ -6,7 +6,6 @@ extern crate hyper0 as hyper;
mod auth;
mod background_node_operations;
mod compute_hook;
pub mod hadron_token;
pub mod hadron_utils;
mod heartbeater;
pub mod http;

View File

@@ -14,7 +14,6 @@ use metrics::BuildInfo;
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::config::PostHogConfig;
use reqwest::Certificate;
use storage_controller::hadron_token::HadronTokenGenerator;
use storage_controller::http::make_router;
use storage_controller::metrics::preinitialize_metrics;
use storage_controller::persistence::Persistence;
@@ -71,26 +70,10 @@ struct Cli {
#[arg(long)]
listen_https: Option<std::net::SocketAddr>,
/// PEM-encoded public key string for JWT authentication of clients.
/// Public key for JWT authentication of clients
#[arg(long)]
public_key: Option<String>,
/// Path to public key certificates used for JWT authentiation of clients.
/// Only one of `public_key` and `public_key_cert_path` should be set.
/// `public_key` or `public_key_cert_path` can point to either a file or a directory.
/// When pointed to a directory, public keys in all files in the first level of
/// the directory (i.e., no subdirectories) will be loaded.
#[arg(long)]
public_key_cert_path: Option<Utf8PathBuf>,
/// Path to the file containing the private key used to generate JWTs for client
/// authentication. The file should contain a single PEM-encoded private key.
/// The HCC uses this key to sign JWTs handed out to other components.
/// Note that unlike the `public_key` and `public_key_cert_path` args above,
/// `private_key_path` must specify a file path, not a directory.
#[arg(long)]
private_key_path: Option<Utf8PathBuf>,
/// Token for authenticating this service with the pageservers it controls
#[arg(long)]
jwt_token: Option<String>,
@@ -273,7 +256,6 @@ struct Secrets {
safekeeper_jwt_token: Option<String>,
control_plane_jwt_token: Option<String>,
peer_jwt_token: Option<String>,
token_generator: Option<HadronTokenGenerator>,
}
const POSTHOG_CONFIG_ENV: &str = "POSTHOG_CONFIG";
@@ -299,16 +281,7 @@ impl Secrets {
let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV) {
Some(v) => Some(JwtAuth::from_key(v).context("Loading public key")?),
None => {
if let Some(path) = args.public_key_cert_path.as_ref() {
Some(
JwtAuth::from_cert_path(path)
.context("Loading public key from certificates")?,
)
} else {
None
}
}
None => None,
};
let this = Self {
@@ -327,11 +300,6 @@ impl Secrets {
Self::CONTROL_PLANE_JWT_TOKEN_ENV,
),
peer_jwt_token: Self::load_secret(&args.peer_jwt_token, Self::PEER_JWT_TOKEN_ENV),
token_generator: args
.private_key_path
.as_ref()
.map(|path| HadronTokenGenerator::new(path))
.transpose()?,
};
Ok(this)
@@ -521,12 +489,12 @@ async fn async_main() -> anyhow::Result<()> {
let persistence = Arc::new(Persistence::new(secrets.database_url).await);
let service = Service::spawn(config, persistence.clone(), secrets.token_generator).await?;
let service = Service::spawn(config, persistence.clone()).await?;
let jwt_auth = secrets
let auth = secrets
.public_key
.map(|jwt_auth| Arc::new(SwappableJwtAuth::new(jwt_auth)));
let router = make_router(service.clone(), jwt_auth, build_info)
let router = make_router(service.clone(), auth, build_info)
.build()
.map_err(|err| anyhow!(err))?;
let http_service =

View File

@@ -4,7 +4,6 @@ pub(crate) mod safekeeper_reconciler;
mod safekeeper_service;
mod tenant_shard_iterator;
use crate::hadron_token::HadronTokenGenerator;
use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap, HashSet};
@@ -519,11 +518,6 @@ pub struct Service {
inner: Arc<std::sync::RwLock<ServiceState>>,
config: Config,
persistence: Arc<Persistence>,
// HadronTokenGenerator to generate (sign) JWTs during compute deployment and compute-spec generation.
#[allow(unused)]
token_generator: Option<HadronTokenGenerator>,
compute_hook: Arc<ComputeHook>,
result_tx: tokio::sync::mpsc::UnboundedSender<ReconcileResultRequest>,
@@ -1674,11 +1668,7 @@ impl Service {
}
}
pub async fn spawn(
config: Config,
persistence: Arc<Persistence>,
token_generator: Option<HadronTokenGenerator>,
) -> anyhow::Result<Arc<Self>> {
pub async fn spawn(config: Config, persistence: Arc<Persistence>) -> anyhow::Result<Arc<Self>> {
let (result_tx, result_rx) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, abort_rx) = tokio::sync::mpsc::unbounded_channel();
@@ -1935,7 +1925,6 @@ impl Service {
))),
config: config.clone(),
persistence,
token_generator,
compute_hook: Arc::new(ComputeHook::new(config.clone())?),
result_tx,
heartbeater_ps,

View File

@@ -1230,10 +1230,7 @@ impl Service {
}
// It it is the same new_sk_set, we can continue the migration (retry).
} else {
let prev_finished = timeline.cplane_notified_generation == timeline.generation
&& timeline.sk_set_notified_generation == timeline.generation;
if !prev_finished {
if !is_migration_finished(&timeline) {
// The previous migration is committed, but the finish step failed.
// Safekeepers/cplane might not know about the last membership configuration.
// Retry the finish step to ensure smooth migration.
@@ -1545,6 +1542,8 @@ impl Service {
timeline_id: TimelineId,
timeline: &TimelinePersistence,
) -> Result<(), ApiError> {
tracing::info!(generation=?timeline.generation, sk_set=?timeline.sk_set, new_sk_set=?timeline.new_sk_set, "retrying finish safekeeper migration");
if timeline.new_sk_set.is_some() {
// Logical error, should never happen.
return Err(ApiError::InternalServerError(anyhow::anyhow!(
@@ -1624,4 +1623,120 @@ impl Service {
Ok(wal_positions[quorum_size - 1])
}
/// Abort ongoing safekeeper migration.
pub(crate) async fn tenant_timeline_safekeeper_migrate_abort(
self: &Arc<Self>,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<(), ApiError> {
// TODO(diko): per-tenant lock is too wide. Consider introducing per-timeline locks.
let _tenant_lock = trace_shared_lock(
&self.tenant_op_locks,
tenant_id,
TenantOperations::TimelineSafekeeperMigrate,
)
.await;
// Fetch current timeline configuration from the configuration storage.
let timeline = self
.persistence
.get_timeline(tenant_id, timeline_id)
.await?;
let Some(timeline) = timeline else {
return Err(ApiError::NotFound(
anyhow::anyhow!(
"timeline {tenant_id}/{timeline_id} doesn't exist in timelines table"
)
.into(),
));
};
let mut generation = SafekeeperGeneration::new(timeline.generation as u32);
let Some(new_sk_set) = &timeline.new_sk_set else {
// No new_sk_set -> no active migration that we can abort.
tracing::info!("timeline has no active migration");
if !is_migration_finished(&timeline) {
// The last migration is committed, but the finish step failed.
// Safekeepers/cplane might not know about the last membership configuration.
// Retry the finish step to make the timeline state clean.
self.finish_safekeeper_migration_retry(tenant_id, timeline_id, &timeline)
.await?;
}
return Ok(());
};
tracing::info!(sk_set=?timeline.sk_set, ?new_sk_set, ?generation, "aborting timeline migration");
let cur_safekeepers = self.get_safekeepers(&timeline.sk_set)?;
let new_safekeepers = self.get_safekeepers(new_sk_set)?;
let cur_sk_member_set =
Self::make_member_set(&cur_safekeepers).map_err(ApiError::InternalServerError)?;
// Increment current generation and remove new_sk_set from the timeline to abort the migration.
generation = generation.next();
let mconf = membership::Configuration {
generation,
members: cur_sk_member_set,
new_members: None,
};
// Exclude safekeepers which were added during the current migration.
let cur_ids: HashSet<NodeId> = cur_safekeepers.iter().map(|sk| sk.get_id()).collect();
let exclude_safekeepers = new_safekeepers
.into_iter()
.filter(|sk| !cur_ids.contains(&sk.get_id()))
.collect::<Vec<_>>();
let exclude_requests = exclude_safekeepers
.iter()
.map(|sk| TimelinePendingOpPersistence {
sk_id: sk.skp.id,
tenant_id: tenant_id.to_string(),
timeline_id: timeline_id.to_string(),
generation: generation.into_inner() as i32,
op_kind: SafekeeperTimelineOpKind::Exclude,
})
.collect::<Vec<_>>();
let cur_sk_set = cur_safekeepers
.iter()
.map(|sk| sk.get_id())
.collect::<Vec<_>>();
// Persist new mconf and exclude requests.
self.persistence
.update_timeline_membership(
tenant_id,
timeline_id,
generation,
&cur_sk_set,
None,
&exclude_requests,
)
.await?;
// At this point we have already commited the abort, but still need to notify
// cplane/safekeepers with the new mconf. That's what finish_safekeeper_migration does.
self.finish_safekeeper_migration(
tenant_id,
timeline_id,
&cur_safekeepers,
&mconf,
&exclude_safekeepers,
)
.await?;
Ok(())
}
}
fn is_migration_finished(timeline: &TimelinePersistence) -> bool {
timeline.cplane_notified_generation == timeline.generation
&& timeline.sk_set_notified_generation == timeline.generation
}

View File

@@ -13,11 +13,10 @@ if TYPE_CHECKING:
@dataclass
class AuthKeys:
priv: str
algorithm: str
def generate_token(self, *, scope: TokenScope, **token_data: Any) -> str:
token_data = {key: str(val) for key, val in token_data.items()}
token = jwt.encode({"scope": scope, **token_data}, self.priv, algorithm=self.algorithm)
token = jwt.encode({"scope": scope, **token_data}, self.priv, algorithm="EdDSA")
# cast(Any, self.priv)
# jwt.encode can return 'bytes' or 'str', depending on Python version or type
@@ -47,4 +46,3 @@ class TokenScope(StrEnum):
TENANT = "tenant"
SCRUBBER = "scrubber"
INFRA = "infra"
TENANT_ENDPOINT = "tenantendpoint"

View File

@@ -28,15 +28,11 @@ import asyncpg
import backoff
import boto3
import httpx
import jwt
import psycopg2
import psycopg2.sql
import pytest
import requests
import toml
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from jwcrypto import jwk
# Type-related stuff
@@ -406,15 +402,6 @@ class PageserverImportConfig:
return ("timeline_import_config", value)
@dataclass
class HadronTokenDecoder:
public_key: str
algorithm: str
def decode_token(self, token: str) -> dict[str, Any]:
return jwt.decode(token, self.public_key, algorithms=[self.algorithm])
class NeonEnvBuilder:
"""
Builder object to create a Neon runtime environment
@@ -485,7 +472,6 @@ class NeonEnvBuilder:
self.safekeepers_id_start = safekeepers_id_start
self.safekeepers_enable_fsync = safekeepers_enable_fsync
self.auth_enabled = auth_enabled
self.use_hadron_auth_tokens = False
self.default_branch_name = default_branch_name
self.env: NeonEnv | None = None
self.keep_remote_storage_contents: bool = True
@@ -1135,11 +1121,6 @@ class NeonEnv:
self.repo_dir.joinpath("rootCA.crt") if self.generate_local_ssl_certs else None
)
# The auth token type used in the test environment. neon_local is instruted to generate key pairs
# according to the auth token type. The keys are always generated but are only used if
# config.auth_enabled == True.
self.auth_token_type: str = "HadronJWT" if config.use_hadron_auth_tokens else "NeonJWT"
neon_local_env_vars = {}
if self.rust_log_override is not None:
neon_local_env_vars["RUST_LOG"] = self.rust_log_override
@@ -1217,7 +1198,6 @@ class NeonEnv:
"listen_addr": f"127.0.0.1:{self.port_distributor.get_port()}",
},
"generate_local_ssl_certs": self.generate_local_ssl_certs,
"auth_token_type": self.auth_token_type,
}
if config.use_https_storage_broker_api:
@@ -1265,9 +1245,9 @@ class NeonEnv:
)
# Create config for pageserver
http_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
pg_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
grpc_auth_type = self.auth_token_type if config.auth_enabled else "Trust"
http_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
pg_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
grpc_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
for ps_id in range(
self.BASE_PAGESERVER_ID, self.BASE_PAGESERVER_ID + config.num_pageservers
):
@@ -1405,8 +1385,9 @@ class NeonEnv:
"https_port": port.https,
"sync": config.safekeepers_enable_fsync,
"use_https_safekeeper_api": config.use_https_safekeeper_api,
"auth_type": self.auth_token_type if config.auth_enabled else "Trust",
}
if config.auth_enabled:
sk_cfg["auth_enabled"] = True
if self.safekeepers_remote_storage is not None:
sk_cfg["remote_storage"] = (
self.safekeepers_remote_storage.to_toml_inline_table().strip()
@@ -1597,66 +1578,29 @@ class NeonEnv:
@cached_property
def auth_keys(self) -> AuthKeys:
priv = (Path(self.repo_dir) / "auth_private_key.pem").read_text()
algorithm = "EdDSA" if self.auth_token_type == "NeonJWT" else "RS256"
return AuthKeys(priv=priv, algorithm=algorithm)
@cached_property
def hadron_token_decoder(self) -> HadronTokenDecoder:
cert = (Path(self.repo_dir) / "auth_public_key.pem").read_text()
x509_cert = x509.load_pem_x509_certificate(cert.encode(), default_backend())
pem_public_key = (
x509_cert.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
.decode()
)
return HadronTokenDecoder(public_key=pem_public_key, algorithm="RS256")
return AuthKeys(priv=priv)
def regenerate_keys_at(self, privkey_path: Path, pubkey_path: Path):
if self.auth_token_type == "NeonJWT":
# compare generate_auth_keys() in local_env.rs
subprocess.run(
["openssl", "genpkey", "-algorithm", "ed25519", "-out", privkey_path],
cwd=self.repo_dir,
check=True,
)
# compare generate_auth_keys() in local_env.rs
subprocess.run(
["openssl", "genpkey", "-algorithm", "ed25519", "-out", privkey_path],
cwd=self.repo_dir,
check=True,
)
subprocess.run(
[
"openssl",
"pkey",
"-in",
privkey_path,
"-pubout",
"-out",
pubkey_path,
],
cwd=self.repo_dir,
check=True,
)
elif self.auth_token_type == "HadronJWT":
# compare generate_auth_keys() in local_env.rs
subprocess.run(
[
"openssl",
"req",
"-x509",
"-newkey",
"rsa:4096",
"-sha256",
"-keyout",
privkey_path,
"-out",
pubkey_path,
"-nodes",
"-subj",
"/CN=eng-brickstore@databricks.com",
],
cwd=self.repo_dir,
check=True,
)
subprocess.run(
[
"openssl",
"pkey",
"-in",
privkey_path,
"-pubout",
"-out",
pubkey_path,
],
cwd=self.repo_dir,
check=True,
)
del self.auth_keys
def generate_endpoint_id(self) -> str:
@@ -2077,10 +2021,10 @@ class NeonStorageController(MetricsGetter, LogUtils):
return resp
def headers(self, scope: TokenScope | None, **token_data: Any) -> dict[str, str]:
def headers(self, scope: TokenScope | None) -> dict[str, str]:
headers = {}
if self.auth_enabled and scope is not None:
jwt_token = self.env.auth_keys.generate_token(scope=scope, **token_data)
jwt_token = self.env.auth_keys.generate_token(scope=scope)
headers["Authorization"] = f"Bearer {jwt_token}"
return headers
@@ -2379,6 +2323,19 @@ class NeonStorageController(MetricsGetter, LogUtils):
response.raise_for_status()
log.info(f"migrate_safekeepers success: {response.json()}")
def abort_safekeeper_migration(
self,
tenant_id: TenantId,
timeline_id: TimelineId,
):
response = self.request(
"POST",
f"{self.api}/v1/tenant/{tenant_id}/timeline/{timeline_id}/safekeeper_migrate_abort",
headers=self.headers(TokenScope.PAGE_SERVER_API),
)
response.raise_for_status()
log.info(f"abort_safekeeper_migration success: {response.json()}")
def locate(self, tenant_id: TenantId) -> list[dict[str, Any]]:
"""
:return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr": str, "listen_http_port": int}

View File

@@ -32,11 +32,8 @@ def assert_client_not_authorized(env: NeonEnv, http_client: PageserverHttpClient
assert_client_authorized(env, http_client)
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_auth(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
def test_pageserver_auth(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
ps = env.pageserver
@@ -75,10 +72,8 @@ def test_pageserver_auth(neon_env_builder: NeonEnvBuilder, use_hadron_auth_token
env.pageserver.tenant_create(TenantId.generate(), auth_token=tenant_token)
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()
@@ -96,10 +91,8 @@ def test_compute_auth_to_pageserver(neon_env_builder: NeonEnvBuilder, use_hadron
assert cur.fetchone() == (5000050000,)
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(
[".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"]
@@ -152,10 +145,8 @@ def test_pageserver_multiple_keys(neon_env_builder: NeonEnvBuilder, use_hadron_a
assert_client_authorized(env, pageserver_http_client_new)
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder, use_hadron_auth_tokens: bool):
def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(
[".*Authentication error: InvalidSignature.*", ".*Unauthorized: malformed jwt token.*"]
@@ -192,12 +183,7 @@ def test_pageserver_key_reload(neon_env_builder: NeonEnvBuilder, use_hadron_auth
@pytest.mark.parametrize("auth_enabled", [False, True])
@pytest.mark.parametrize("use_hadron_auth_tokens", [True, False])
def test_auth_failures(
neon_env_builder: NeonEnvBuilder, auth_enabled: bool, use_hadron_auth_tokens: bool
):
neon_env_builder.auth_enabled = auth_enabled
neon_env_builder.use_hadron_auth_tokens = use_hadron_auth_tokens
def test_auth_failures(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
neon_env_builder.auth_enabled = auth_enabled
env = neon_env_builder.init_start()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING
from fixtures.metrics import parse_metrics
@@ -9,13 +10,13 @@ if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
def test_compute_monitor(neon_simple_env: NeonEnv):
def test_compute_monitor_downtime_calculation(neon_simple_env: NeonEnv):
"""
Test that compute_ctl can detect Postgres going down (unresponsive) and
reconnect when it comes back online. Also check that the downtime metrics
are properly emitted.
"""
TEST_DB = "test_compute_monitor"
TEST_DB = "test_compute_monitor_downtime_calculation"
env = neon_simple_env
endpoint = env.endpoints.create_start("main")
@@ -68,3 +69,56 @@ def test_compute_monitor(neon_simple_env: NeonEnv):
# Just a sanity check that we log the downtime info
endpoint.log_contains("downtime_info")
def test_compute_monitor_activity(neon_simple_env: NeonEnv):
"""
Test compute monitor correctly detects user activity inside Postgres
and updates last_active timestamp in the /status response.
"""
TEST_DB = "test_compute_monitor_activity_db"
env = neon_simple_env
endpoint = env.endpoints.create_start("main")
with endpoint.cursor() as cursor:
# Create a new database because `postgres` DB is excluded
# from activity monitoring.
cursor.execute(f"CREATE DATABASE {TEST_DB}")
client = endpoint.http_client()
prev_last_active = None
def check_last_active():
nonlocal prev_last_active
with endpoint.cursor(dbname=TEST_DB) as cursor:
# Execute some dummy query to generate 'activity'.
cursor.execute("SELECT * FROM generate_series(1, 10000)")
status = client.status()
assert status["last_active"] is not None
prev_last_active = status["last_active"]
wait_until(check_last_active)
assert prev_last_active is not None
# Sleep for everything to settle down. It's not strictly necessary,
# but should still remove any potential noise and/or prevent test from passing
# even if compute monitor is not working.
time.sleep(3)
with endpoint.cursor(dbname=TEST_DB) as cursor:
cursor.execute("SELECT * FROM generate_series(1, 10000)")
def check_last_active_updated():
nonlocal prev_last_active
status = client.status()
assert status["last_active"] is not None
assert status["last_active"] != prev_last_active
assert status["last_active"] > prev_last_active
wait_until(check_last_active_updated)

View File

@@ -145,6 +145,7 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod):
stop_and_check_lsn(secondary, None)
if method == PromoteMethod.COMPUTE_CTL:
log.info("Restarting primary to check new config")
secondary.stop()
# In production, compute ultimately receives new compute spec from cplane.
secondary.respec(mode="Primary")

View File

@@ -460,3 +460,91 @@ def test_pull_from_most_advanced_sk(neon_env_builder: NeonEnvBuilder):
ep.start(safekeeper_generation=5, safekeepers=new_sk_set2)
assert ep.safe_psql("SELECT * FROM t") == [(0,), (1,)]
def test_abort_safekeeper_migration(neon_env_builder: NeonEnvBuilder):
"""
Test that safekeeper migration can be aborted.
1. Insert failpoints and ensure the abort successfully reverts the timeline state.
2. Check that endpoint is operational after the abort.
"""
neon_env_builder.num_safekeepers = 2
neon_env_builder.storage_controller_config = {
"timelines_onto_safekeepers": True,
"timeline_safekeeper_count": 1,
}
env = neon_env_builder.init_start()
env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS)
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert len(mconf["sk_set"]) == 1
cur_sk = mconf["sk_set"][0]
cur_gen = 1
ep = env.endpoints.create("main", tenant_id=env.initial_tenant)
ep.start(safekeeper_generation=1, safekeepers=mconf["sk_set"])
ep.safe_psql("CREATE EXTENSION neon_test_utils;")
ep.safe_psql("CREATE TABLE t(a int)")
ep.safe_psql("INSERT INTO t VALUES (1)")
another_sk = [sk.id for sk in env.safekeepers if sk.id != cur_sk][0]
failpoints = [
"sk-migration-after-step-3",
"sk-migration-after-step-4",
"sk-migration-after-step-5",
"sk-migration-after-step-7",
]
for fp in failpoints:
env.storage_controller.configure_failpoints((fp, "return(1)"))
with pytest.raises(StorageControllerApiException, match=f"failpoint {fp}"):
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, [another_sk]
)
cur_gen += 1
env.storage_controller.configure_failpoints((fp, "off"))
# We should have a joint mconf after the failure.
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert mconf["generation"] == cur_gen
assert mconf["sk_set"] == [cur_sk]
assert mconf["new_sk_set"] == [another_sk]
env.storage_controller.abort_safekeeper_migration(env.initial_tenant, env.initial_timeline)
cur_gen += 1
# Abort should revert the timeline to the previous sk_set and increment the generation.
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert mconf["generation"] == cur_gen
assert mconf["sk_set"] == [cur_sk]
assert mconf["new_sk_set"] is None
assert ep.safe_psql("SHOW neon.safekeepers")[0][0].startswith(f"g#{cur_gen}:")
ep.safe_psql(f"INSERT INTO t VALUES ({cur_gen})")
# After step-8 the final mconf is committed and the migration is not abortable anymore.
# So the abort should not abort anything.
env.storage_controller.configure_failpoints(("sk-migration-after-step-8", "return(1)"))
with pytest.raises(StorageControllerApiException, match="failpoint sk-migration-after-step-8"):
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, [another_sk]
)
cur_gen += 2
env.storage_controller.configure_failpoints((fp, "off"))
env.storage_controller.abort_safekeeper_migration(env.initial_tenant, env.initial_timeline)
# The migration is fully committed, no abort should have been performed.
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert mconf["generation"] == cur_gen
assert mconf["sk_set"] == [another_sk]
assert mconf["new_sk_set"] is None
ep.safe_psql(f"INSERT INTO t VALUES ({cur_gen})")
ep.clear_buffers()
assert ep.safe_psql("SELECT * FROM t") == [(i + 1,) for i in range(cur_gen) if i % 2 == 0]

View File

@@ -1406,9 +1406,6 @@ def test_storage_controller_s3_time_travel_recovery(
def test_storage_controller_auth(neon_env_builder: NeonEnvBuilder):
neon_env_builder.auth_enabled = True
env = neon_env_builder.init_start()
assert env.auth_token_type == "NeonJWT"
svc = env.storage_controller
api = env.storage_controller_api

View File

@@ -78,7 +78,6 @@ parquet = { version = "53", default-features = false, features = ["zstd"] }
portable-atomic = { version = "1", features = ["require-cas"] }
prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] }
rand = { version = "0.9" }
rcgen = { version = "0.13", features = ["aws_lc_rs"] }
regex = { version = "1" }
regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }
regex-syntax = { version = "0.8" }
@@ -127,7 +126,6 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
clap = { version = "4", features = ["derive", "env", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
displaydoc = { version = "0.2" }
either = { version = "1" }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
half = { version = "2", default-features = false, features = ["num-traits"] }