Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Chi Z
7bd3e7a803 feat(scrubber): more parallelism for metadata check
Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-05-06 17:28:13 +08:00
70 changed files with 793 additions and 1441 deletions

View File

@@ -53,77 +53,6 @@ concurrency:
cancel-in-progress: true
jobs:
cleanup:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: ghcr.io/neondatabase/build-tools:pinned-bookworm
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
env:
ORG_ID: org-solitary-dew-09443886
LIMIT: 100
SEARCH: "GITHUB_RUN_ID="
BASE_URL: https://console-stage.neon.build/api/v2
DRY_RUN: "false" # Set to "true" to just test out the workflow
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cleanup inactive Neon projects left over from prior runs
env:
API_KEY: ${{ secrets.NEON_STAGING_API_KEY }}
run: |
set -euo pipefail
NOW=$(date -u +%s)
DAYS_AGO=$((NOW - 5 * 86400))
REQUEST_URL="$BASE_URL/projects?limit=$LIMIT&search=$(printf '%s' "$SEARCH" | jq -sRr @uri)&org_id=$ORG_ID"
echo "Requesting project list from:"
echo "$REQUEST_URL"
response=$(curl -s -X GET "$REQUEST_URL" \
--header "Accept: application/json" \
--header "Content-Type: application/json" \
--header "Authorization: Bearer ${API_KEY}" )
echo "Response:"
echo "$response" | jq .
projects_to_delete=$(echo "$response" | jq --argjson cutoff "$DAYS_AGO" '
.projects[]
| select(.compute_last_active_at != null)
| select((.compute_last_active_at | fromdateiso8601) < $cutoff)
| {id, name, compute_last_active_at}
')
if [ -z "$projects_to_delete" ]; then
echo "No projects eligible for deletion."
exit 0
fi
echo "Projects that will be deleted:"
echo "$projects_to_delete" | jq -r '.id'
if [ "$DRY_RUN" = "false" ]; then
echo "$projects_to_delete" | jq -r '.id' | while read -r project_id; do
echo "Deleting project: $project_id"
curl -s -X DELETE "$BASE_URL/projects/$project_id" \
--header "Accept: application/json" \
--header "Content-Type: application/json" \
--header "Authorization: Bearer ${API_KEY}"
done
else
echo "Dry run enabled — no projects were deleted."
fi
bench:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
permissions:

2
Cargo.lock generated
View File

@@ -1284,7 +1284,6 @@ name = "compute_tools"
version = "0.1.0"
dependencies = [
"anyhow",
"async-compression",
"aws-config",
"aws-sdk-kms",
"aws-sdk-s3",
@@ -1421,7 +1420,6 @@ dependencies = [
"clap",
"comfy-table",
"compute_api",
"endpoint_storage",
"futures",
"http-utils",
"humantime",

View File

@@ -243,7 +243,6 @@ azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rus
## Local libraries
compute_api = { version = "0.1", path = "./libs/compute_api/" }
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
pageserver = { path = "./pageserver" }

View File

@@ -1084,12 +1084,23 @@ RUN cargo install --locked --version 0.12.9 cargo-pgrx && \
/bin/bash -c 'cargo pgrx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
#########################################################################################
#
# Layer "rust extensions pgrx14"
#
# Version 14 is now required by a few
#########################################################################################
FROM pg-build-nonroot-with-cargo AS rust-extensions-build-pgrx14
ARG PG_VERSION
RUN cargo install --locked --version 0.14.1 cargo-pgrx && \
/bin/bash -c 'cargo pgrx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
#########################################################################################
#
# Layer "rust extensions pgrx14"
#
# Version 14 is now required by a few
# This layer should be used as a base for new pgrx extensions,
# and eventually get merged with `rust-extensions-build`
#
@@ -1322,8 +1333,8 @@ ARG PG_VERSION
# Do not update without approve from proxy team
# Make sure the version is reflected in proxy/src/serverless/local_conn_pool.rs
WORKDIR /ext-src
RUN wget https://github.com/neondatabase/pg_session_jwt/archive/refs/tags/v0.3.1.tar.gz -O pg_session_jwt.tar.gz && \
echo "62fec9e472cb805c53ba24a0765afdb8ea2720cfc03ae7813e61687b36d1b0ad pg_session_jwt.tar.gz" | sha256sum --check && \
RUN wget https://github.com/neondatabase/pg_session_jwt/archive/refs/tags/v0.3.0.tar.gz -O pg_session_jwt.tar.gz && \
echo "19be2dc0b3834d643706ed430af998bb4c2cdf24b3c45e7b102bb3a550e8660c pg_session_jwt.tar.gz" | sha256sum --check && \
mkdir pg_session_jwt-src && cd pg_session_jwt-src && tar xzf ../pg_session_jwt.tar.gz --strip-components=1 -C . && \
sed -i 's/pgrx = "0.12.6"/pgrx = { version = "0.12.9", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
sed -i 's/version = "0.12.6"/version = "0.12.9"/g' pgrx-tests/Cargo.toml && \
@@ -1351,8 +1362,7 @@ COPY compute/patches/anon_v2.patch .
# This is an experimental extension, never got to real production.
# !Do not remove! It can be present in shared_preload_libraries and compute will fail to start if library is not found.
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/2.1.0/postgresql_anonymizer-latest.tar.gz -O pg_anon.tar.gz && \
echo "48e7f5ae2f1ca516df3da86c5c739d48dd780a4e885705704ccaad0faa89d6c0 pg_anon.tar.gz" | sha256sum --check && \
RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/latest/postgresql_anonymizer-latest.tar.gz -O pg_anon.tar.gz && \
mkdir pg_anon-src && cd pg_anon-src && tar xzf ../pg_anon.tar.gz --strip-components=1 -C . && \
find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt && \
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "=0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \

View File

@@ -10,7 +10,6 @@ default = []
testing = ["fail/failpoints"]
[dependencies]
async-compression.workspace = true
base64.workspace = true
aws-config.workspace = true
aws-sdk-s3.workspace = true

View File

@@ -1,10 +1,17 @@
use std::collections::HashMap;
use std::os::unix::fs::{PermissionsExt, symlink};
use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use compute_api::privilege::Privilege;
use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeMetrics, ComputeStatus, LfcOffloadState,
LfcPrewarmState,
};
use compute_api::responses::{ComputeConfig, ComputeCtlConfig, ComputeMetrics, ComputeStatus};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PgIdent,
};
@@ -18,16 +25,6 @@ use postgres;
use postgres::NoTls;
use postgres::error::SqlState;
use remote_storage::{DownloadError, RemotePath};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::os::unix::fs::{PermissionsExt, symlink};
use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::spawn;
use tracing::{Instrument, debug, error, info, instrument, warn};
use utils::id::{TenantId, TimelineId};
@@ -153,9 +150,6 @@ pub struct ComputeState {
/// set up the span relationship ourselves.
pub startup_span: Option<tracing::span::Span>,
pub lfc_prewarm_state: LfcPrewarmState,
pub lfc_offload_state: LfcOffloadState,
pub metrics: ComputeMetrics,
}
@@ -169,8 +163,6 @@ impl ComputeState {
pspec: None,
startup_span: None,
metrics: ComputeMetrics::default(),
lfc_prewarm_state: LfcPrewarmState::default(),
lfc_offload_state: LfcOffloadState::default(),
}
}
@@ -206,8 +198,6 @@ pub struct ParsedSpec {
pub pageserver_connstr: String,
pub safekeeper_connstrings: Vec<String>,
pub storage_auth_token: Option<String>,
pub endpoint_storage_addr: Option<SocketAddr>,
pub endpoint_storage_token: Option<String>,
}
impl TryFrom<ComputeSpec> for ParsedSpec {
@@ -261,18 +251,6 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
.or(Err("invalid timeline id"))?
};
let endpoint_storage_addr: Option<SocketAddr> = spec
.endpoint_storage_addr
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_addr"))
.unwrap_or_default()
.parse()
.ok();
let endpoint_storage_token = spec
.endpoint_storage_token
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_token"));
Ok(ParsedSpec {
spec,
pageserver_connstr,
@@ -280,8 +258,6 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
storage_auth_token,
tenant_id,
timeline_id,
endpoint_storage_addr,
endpoint_storage_token,
})
}
}
@@ -760,9 +736,6 @@ impl ComputeNode {
// Log metrics so that we can search for slow operations in logs
info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished");
if pspec.spec.prewarm_lfc_on_startup {
self.prewarm_lfc();
}
Ok(())
}

View File

@@ -1,202 +0,0 @@
use crate::compute::ComputeNode;
use anyhow::{Context, Result, bail};
use async_compression::tokio::bufread::{ZstdDecoder, ZstdEncoder};
use compute_api::responses::LfcOffloadState;
use compute_api::responses::LfcPrewarmState;
use http::StatusCode;
use reqwest::Client;
use std::sync::Arc;
use tokio::{io::AsyncReadExt, spawn};
use tracing::{error, info};
#[derive(serde::Serialize, Default)]
pub struct LfcPrewarmStateWithProgress {
#[serde(flatten)]
base: LfcPrewarmState,
total: i32,
prewarmed: i32,
skipped: i32,
}
/// A pair of url and a token to query endpoint storage for LFC prewarm-related tasks
struct EndpointStoragePair {
url: String,
token: String,
}
const KEY: &str = "lfc_state";
impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair {
type Error = anyhow::Error;
fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> {
let Some(ref endpoint_id) = pspec.spec.endpoint_id else {
bail!("pspec.endpoint_id missing")
};
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
bail!("pspec.endpoint_storage_addr missing")
};
let tenant_id = pspec.tenant_id;
let timeline_id = pspec.timeline_id;
let url = format!("http://{base_uri}/{tenant_id}/{timeline_id}/{endpoint_id}/{KEY}");
let Some(ref token) = pspec.endpoint_storage_token else {
bail!("pspec.endpoint_storage_token missing")
};
let token = token.clone();
Ok(EndpointStoragePair { url, token })
}
}
impl ComputeNode {
// If prewarm failed, we want to get overall number of segments as well as done ones.
// However, this function should be reliable even if querying postgres failed.
pub async fn lfc_prewarm_state(&self) -> LfcPrewarmStateWithProgress {
info!("requesting LFC prewarm state from postgres");
let mut state = LfcPrewarmStateWithProgress::default();
{
state.base = self.state.lock().unwrap().lfc_prewarm_state.clone();
}
let client = match ComputeNode::get_maintenance_client(&self.tokio_conn_conf).await {
Ok(client) => client,
Err(err) => {
error!(%err, "connecting to postgres");
return state;
}
};
let row = match client
.query_one("select * from get_prewarm_info()", &[])
.await
{
Ok(row) => row,
Err(err) => {
error!(%err, "querying LFC prewarm status");
return state;
}
};
state.total = row.try_get(0).unwrap_or_default();
state.prewarmed = row.try_get(1).unwrap_or_default();
state.skipped = row.try_get(2).unwrap_or_default();
state
}
pub fn lfc_offload_state(&self) -> LfcOffloadState {
self.state.lock().unwrap().lfc_offload_state.clone()
}
/// Returns false if there is a prewarm request ongoing, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_PREWARM_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
if let LfcPrewarmState::Prewarming =
std::mem::replace(state, LfcPrewarmState::Prewarming)
{
return false;
}
}
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.prewarm_impl().await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
error!(%err);
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed {
error: err.to_string(),
};
});
true
}
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
let state = self.state.lock().unwrap();
state.pspec.as_ref().unwrap().try_into()
}
async fn prewarm_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
let res = request.send().await.context("querying endpoint storage")?;
let status = res.status();
if status != StatusCode::OK {
bail!("{status} querying endpoint storage")
}
let mut uncompressed = Vec::new();
let lfc_state = res
.bytes()
.await
.context("getting request body from endpoint storage")?;
ZstdDecoder::new(lfc_state.iter().as_slice())
.read_to_end(&mut uncompressed)
.await
.context("decoding LFC state")?;
let uncompressed_len = uncompressed.len();
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres");
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
.context("connecting to postgres")?
.query_one("select prewarm_local_cache($1)", &[&uncompressed])
.await
.context("loading LFC state into postgres")
.map(|_| ())
}
/// Returns false if there is an offload request ongoing, true otherwise
pub fn offload_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_OFFLOAD_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_offload_state;
if let LfcOffloadState::Offloading =
std::mem::replace(state, LfcOffloadState::Offloading)
{
return false;
}
}
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.offload_lfc_impl().await else {
cloned.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
error!(%err);
cloned.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
});
true
}
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
.context("connecting to postgres")?
.query_one("select get_local_cache_state()", &[])
.await
.context("querying LFC state")?
.try_get::<usize, &[u8]>(0)
.context("deserializing LFC state")
.map(ZstdEncoder::new)?
.read_to_end(&mut compressed)
.await
.context("compressing LFC state")?;
let compressed_len = compressed.len();
info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");
let request = Client::new().put(url).bearer_auth(token).body(compressed);
match request.send().await {
Ok(res) if res.status() == StatusCode::OK => Ok(()),
Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()),
Err(err) => Err(err).context("writing to endpoint storage"),
}
}
}

View File

@@ -223,9 +223,6 @@ pub fn write_postgres_conf(
// TODO: tune this after performance testing
writeln!(file, "pgaudit.log_rotation_age=5")?;
// Enable audit logs for pg_session_jwt extension
writeln!(file, "pg_session_jwt.audit_log=on")?;
// Add audit shared_preload_libraries, if they are not present.
//
// The caller who sets the flag is responsible for ensuring that the necessary

View File

@@ -1,10 +1,12 @@
use std::collections::HashSet;
use anyhow::{Result, anyhow};
use axum::{RequestExt, body::Body};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use compute_api::requests::{COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope};
use compute_api::requests::ComputeClaims;
use futures::future::BoxFuture;
use http::{Request, Response, StatusCode};
use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet};
@@ -23,14 +25,13 @@ pub(in crate::http) struct Authorize {
impl Authorize {
pub fn new(compute_id: String, jwks: JwkSet) -> Self {
let mut validation = Validation::new(Algorithm::EdDSA);
// Nothing is currently required
validation.required_spec_claims = HashSet::new();
validation.validate_exp = true;
// Unused by the control plane
validation.validate_nbf = false;
// Unused by the control plane
validation.validate_aud = false;
validation.set_audience(&[COMPUTE_AUDIENCE]);
// Nothing is currently required
validation.set_required_spec_claims(&[] as &[&str; 0]);
// Unused by the control plane
validation.validate_nbf = false;
Self {
compute_id,
@@ -63,47 +64,11 @@ impl AsyncAuthorizeRequest<Body> for Authorize {
Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)),
};
match data.claims.scope {
// TODO: We should validate audience for every token, but
// instead of this ad-hoc validation, we should turn
// [`Validation::validate_aud`] on. This is merely a stopgap
// while we roll out `aud` deployment. We return a 401
// Unauthorized because when we eventually do use
// [`Validation`], we will hit the above `Err` match arm which
// returns 401 Unauthorized.
Some(ComputeClaimsScope::Admin) => {
let Some(ref audience) = data.claims.audience else {
return Err(JsonResponse::error(
StatusCode::UNAUTHORIZED,
"missing audience in authorization token claims",
));
};
if !audience.iter().any(|a| a == COMPUTE_AUDIENCE) {
return Err(JsonResponse::error(
StatusCode::UNAUTHORIZED,
"invalid audience in authorization token claims",
));
}
}
// If the scope is not [`ComputeClaimsScope::Admin`], then we
// must validate the compute_id
_ => {
let Some(ref claimed_compute_id) = data.claims.compute_id else {
return Err(JsonResponse::error(
StatusCode::FORBIDDEN,
"missing compute_id in authorization token claims",
));
};
if *claimed_compute_id != compute_id {
return Err(JsonResponse::error(
StatusCode::FORBIDDEN,
"invalid compute ID in authorization token claims",
));
}
}
if data.claims.compute_id != compute_id {
return Err(JsonResponse::error(
StatusCode::UNAUTHORIZED,
"invalid compute ID in authorization token claims",
));
}
// Make claims available to any subsequent middleware or request

View File

@@ -1,39 +0,0 @@
use crate::compute_prewarm::LfcPrewarmStateWithProgress;
use crate::http::JsonResponse;
use axum::response::{IntoResponse, Response};
use axum::{Json, http::StatusCode};
use compute_api::responses::LfcOffloadState;
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
pub(in crate::http) async fn prewarm_state(compute: Compute) -> Json<LfcPrewarmStateWithProgress> {
Json(compute.lfc_prewarm_state().await)
}
// Following functions are marked async for axum, as it's more convenient than wrapping these
// in async lambdas at call site
pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadState> {
Json(compute.lfc_offload_state())
}
pub(in crate::http) async fn prewarm(compute: Compute) -> Response {
if compute.prewarm_lfc() {
StatusCode::ACCEPTED.into_response()
} else {
JsonResponse::error(
StatusCode::TOO_MANY_REQUESTS,
"Multiple requests for prewarm are not allowed",
)
}
}
pub(in crate::http) async fn offload(compute: Compute) -> Response {
if compute.offload_lfc() {
StatusCode::ACCEPTED.into_response()
} else {
JsonResponse::error(
StatusCode::TOO_MANY_REQUESTS,
"Multiple requests for prewarm offload are not allowed",
)
}
}

View File

@@ -11,7 +11,6 @@ pub(in crate::http) mod extensions;
pub(in crate::http) mod failpoints;
pub(in crate::http) mod grants;
pub(in crate::http) mod insights;
pub(in crate::http) mod lfc;
pub(in crate::http) mod metrics;
pub(in crate::http) mod metrics_json;
pub(in crate::http) mod status;

View File

@@ -23,7 +23,7 @@ use super::{
middleware::authorize::Authorize,
routes::{
check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
grants, insights, lfc, metrics, metrics_json, status, terminate,
grants, insights, metrics, metrics_json, status, terminate,
},
};
use crate::compute::ComputeNode;
@@ -85,8 +85,6 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
Router::<Arc<ComputeNode>>::new().route("/metrics", get(metrics::get_metrics));
let authenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))
.route("/lfc/offload", get(lfc::offload_state).post(lfc::offload))
.route("/check_writability", post(check_writability::is_writable))
.route("/configure", post(configure::configure))
.route("/database_schema", get(database_schema::get_schema_dump))

View File

@@ -11,7 +11,6 @@ pub mod http;
pub mod logger;
pub mod catalog;
pub mod compute;
pub mod compute_prewarm;
pub mod disk_quota;
pub mod extension_server;
pub mod installed_extensions;

View File

@@ -1,7 +1,7 @@
use metrics::core::{AtomicF64, AtomicU64, Collector, GenericCounter, GenericGauge};
use metrics::proto::MetricFamily;
use metrics::{
IntCounter, IntCounterVec, IntGaugeVec, UIntGaugeVec, register_gauge, register_int_counter,
IntCounterVec, IntGaugeVec, UIntGaugeVec, register_gauge, register_int_counter,
register_int_counter_vec, register_int_gauge_vec, register_uint_gauge_vec,
};
use once_cell::sync::Lazy;
@@ -97,24 +97,6 @@ pub(crate) static PG_TOTAL_DOWNTIME_MS: Lazy<GenericCounter<AtomicU64>> = Lazy::
.expect("failed to define a metric")
});
/// Needed as neon.file_cache_prewarm_batch == 0 doesn't mean we never tried to prewarm.
/// On the other hand, LFC_PREWARMED_PAGES is excessive as we can GET /lfc/prewarm
pub(crate) static LFC_PREWARM_REQUESTS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"compute_ctl_lfc_prewarm_requests_total",
"Total number of LFC prewarm requests made by compute_ctl",
)
.expect("failed to define a metric")
});
pub(crate) static LFC_OFFLOAD_REQUESTS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"compute_ctl_lfc_offload_requests_total",
"Total number of LFC offload requests made by compute_ctl",
)
.expect("failed to define a metric")
});
pub fn collect() -> Vec<MetricFamily> {
let mut metrics = COMPUTE_CTL_UP.collect();
metrics.extend(INSTALLED_EXTENSIONS.collect());
@@ -124,7 +106,5 @@ pub fn collect() -> Vec<MetricFamily> {
metrics.extend(AUDIT_LOG_DIR_SIZE.collect());
metrics.extend(PG_CURR_DOWNTIME_MS.collect());
metrics.extend(PG_TOTAL_DOWNTIME_MS.collect());
metrics.extend(LFC_PREWARM_REQUESTS.collect());
metrics.extend(LFC_OFFLOAD_REQUESTS.collect());
metrics
}

View File

@@ -30,7 +30,6 @@ mod pg_helpers_tests {
r#"fsync = off
wal_level = logical
hot_standby = on
prewarm_lfc_on_startup = off
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on

View File

@@ -41,7 +41,7 @@ storage_broker.workspace = true
http-utils.workspace = true
utils.workspace = true
whoami.workspace = true
endpoint_storage.workspace = true
compute_api.workspace = true
workspace_hack.workspace = true
tracing.workspace = true

View File

@@ -16,11 +16,10 @@ use std::time::Duration;
use anyhow::{Context, Result, anyhow, bail};
use clap::Parser;
use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::ComputeMode;
use control_plane::broker::StorageBroker;
use control_plane::endpoint::ComputeControlPlane;
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_PORT, EndpointStorage};
use control_plane::local_env;
use control_plane::local_env::{
EndpointStorageConf, InitForceMode, LocalEnv, NeonBroker, NeonLocalInitConf,
@@ -706,9 +705,6 @@ struct EndpointStopCmdArgs {
struct EndpointGenerateJwtCmdArgs {
#[clap(help = "Postgres endpoint id")]
endpoint_id: String,
#[clap(short = 's', long, help = "Scope to generate the JWT with", value_parser = ComputeClaimsScope::from_str)]
scope: Option<ComputeClaimsScope>,
}
#[derive(clap::Subcommand)]
@@ -1022,7 +1018,7 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
})
.collect(),
endpoint_storage: EndpointStorageConf {
listen_addr: ENDPOINT_STORAGE_DEFAULT_ADDR,
port: ENDPOINT_STORAGE_DEFAULT_PORT,
},
pg_distrib_dir: None,
neon_distrib_dir: None,
@@ -1488,25 +1484,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
None
};
let exp = (std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?
+ Duration::from_secs(86400))
.as_secs();
let claims = endpoint_storage::claims::EndpointStorageClaims {
tenant_id: endpoint.tenant_id,
timeline_id: endpoint.timeline_id,
endpoint_id: endpoint_id.to_string(),
exp,
};
let endpoint_storage_token = env.generate_auth_token(&claims)?;
let endpoint_storage_addr = env.endpoint_storage.listen_addr.to_string();
println!("Starting existing endpoint {endpoint_id}...");
endpoint
.start(
&auth_token,
endpoint_storage_token,
endpoint_storage_addr,
safekeepers_generation,
safekeepers,
pageservers,
@@ -1559,16 +1540,12 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
endpoint.stop(&args.mode, args.destroy)?;
}
EndpointCmd::GenerateJwt(args) => {
let endpoint = {
let endpoint_id = &args.endpoint_id;
cplane
.endpoints
.get(endpoint_id)
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?
};
let jwt = endpoint.generate_jwt(args.scope)?;
let endpoint_id = &args.endpoint_id;
let endpoint = cplane
.endpoints
.get(endpoint_id)
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let jwt = endpoint.generate_jwt()?;
print!("{jwt}");
}

View File

@@ -45,9 +45,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
use compute_api::requests::{ComputeClaims, ConfigurationRequest};
use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TlsConfig,
};
@@ -632,17 +630,9 @@ impl Endpoint {
}
/// Generate a JWT with the correct claims.
pub fn generate_jwt(&self, scope: Option<ComputeClaimsScope>) -> Result<String> {
pub fn generate_jwt(&self) -> Result<String> {
self.env.generate_auth_token(&ComputeClaims {
audience: match scope {
Some(ComputeClaimsScope::Admin) => Some(vec![COMPUTE_AUDIENCE.to_owned()]),
_ => None,
},
compute_id: match scope {
Some(ComputeClaimsScope::Admin) => None,
_ => Some(self.endpoint_id.clone()),
},
scope,
compute_id: self.endpoint_id.clone(),
})
}
@@ -650,8 +640,6 @@ impl Endpoint {
pub async fn start(
&self,
auth_token: &Option<String>,
endpoint_storage_token: String,
endpoint_storage_addr: String,
safekeepers_generation: Option<SafekeeperGeneration>,
safekeepers: Vec<NodeId>,
pageservers: Vec<(Host, u16)>,
@@ -745,9 +733,6 @@ impl Endpoint {
drop_subscriptions_before_start: self.drop_subscriptions_before_start,
audit_log_level: ComputeAudit::Disabled,
logs_export_host: None::<String>,
endpoint_storage_addr: Some(endpoint_storage_addr),
endpoint_storage_token: Some(endpoint_storage_token),
prewarm_lfc_on_startup: false,
};
// this strange code is needed to support respec() in tests
@@ -918,7 +903,7 @@ impl Endpoint {
self.external_http_address.port()
),
)
.bearer_auth(self.generate_jwt(None::<ComputeClaimsScope>)?)
.bearer_auth(self.generate_jwt()?)
.send()
.await?;
@@ -995,7 +980,7 @@ impl Endpoint {
self.external_http_address.port()
))
.header(CONTENT_TYPE.as_str(), "application/json")
.bearer_auth(self.generate_jwt(None::<ComputeClaimsScope>)?)
.bearer_auth(self.generate_jwt()?)
.body(
serde_json::to_string(&ConfigurationRequest {
spec,

View File

@@ -3,19 +3,17 @@ use crate::local_env::LocalEnv;
use anyhow::{Context, Result};
use camino::Utf8PathBuf;
use std::io::Write;
use std::net::SocketAddr;
use std::time::Duration;
/// Directory within .neon which will be used by default for LocalFs remote storage.
pub const ENDPOINT_STORAGE_REMOTE_STORAGE_DIR: &str = "local_fs_remote_storage/endpoint_storage";
pub const ENDPOINT_STORAGE_DEFAULT_ADDR: SocketAddr =
SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), 9993);
pub const ENDPOINT_STORAGE_DEFAULT_PORT: u16 = 9993;
pub struct EndpointStorage {
pub bin: Utf8PathBuf,
pub data_dir: Utf8PathBuf,
pub pemfile: Utf8PathBuf,
pub addr: SocketAddr,
pub port: u16,
}
impl EndpointStorage {
@@ -24,7 +22,7 @@ impl EndpointStorage {
bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(),
pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(),
addr: env.endpoint_storage.listen_addr,
port: env.endpoint_storage.port,
}
}
@@ -33,7 +31,7 @@ impl EndpointStorage {
}
fn listen_addr(&self) -> Utf8PathBuf {
format!("{}:{}", self.addr.ip(), self.addr.port()).into()
format!("127.0.0.1:{}", self.port).into()
}
pub fn init(&self) -> Result<()> {

View File

@@ -20,9 +20,7 @@ use utils::auth::encode_from_key_file;
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use crate::broker::StorageBroker;
use crate::endpoint_storage::{
ENDPOINT_STORAGE_DEFAULT_ADDR, ENDPOINT_STORAGE_REMOTE_STORAGE_DIR, EndpointStorage,
};
use crate::endpoint_storage::{ENDPOINT_STORAGE_REMOTE_STORAGE_DIR, EndpointStorage};
use crate::pageserver::{PAGESERVER_REMOTE_STORAGE_DIR, PageServerNode};
use crate::safekeeper::SafekeeperNode;
@@ -153,10 +151,10 @@ pub struct NeonLocalInitConf {
pub generate_local_ssl_certs: bool,
}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
#[derive(Serialize, Default, Deserialize, PartialEq, Eq, Clone, Debug)]
#[serde(default)]
pub struct EndpointStorageConf {
pub listen_addr: SocketAddr,
pub port: u16,
}
/// Broker config for cluster internal communication.
@@ -243,14 +241,6 @@ impl Default for NeonStorageControllerConf {
}
}
impl Default for EndpointStorageConf {
fn default() -> Self {
Self {
listen_addr: ENDPOINT_STORAGE_DEFAULT_ADDR,
}
}
}
impl NeonBroker {
pub fn client_url(&self) -> Url {
let url = if let Some(addr) = self.listen_https_addr {

View File

@@ -12,7 +12,6 @@ ERROR: invalid JWT encoding
-- Test creating a session with an expired JWT
SELECT auth.jwt_session_init('eyJhbGciOiJFZERTQSJ9.eyJleHAiOjE3NDI1NjQ0MzIsImlhdCI6MTc0MjU2NDI1MiwianRpIjo0MjQyNDIsInN1YiI6InVzZXIxMjMifQ.A6FwKuaSduHB9O7Gz37g0uoD_U9qVS0JNtT7YABGVgB7HUD1AMFc9DeyhNntWBqncg8k5brv-hrNTuUh5JYMAw');
ERROR: Token used after it has expired
DETAIL: exp=1742564432
-- Test creating a session with a valid JWT
SELECT auth.jwt_session_init('eyJhbGciOiJFZERTQSJ9.eyJleHAiOjQ4OTYxNjQyNTIsImlhdCI6MTc0MjU2NDI1MiwianRpIjo0MzQzNDMsInN1YiI6InVzZXIxMjMifQ.2TXVgjb6JSUq6_adlvp-m_SdOxZSyGS30RS9TLB0xu2N83dMSs2NybwE1NMU8Fb0tcAZR_ET7M2rSxbTrphfCg');
jwt_session_init

View File

@@ -343,7 +343,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
TimelineId::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";
fn token() -> String {
let claims = endpoint_storage::claims::EndpointStorageClaims {
let claims = endpoint_storage::Claims {
tenant_id: TENANT_ID,
timeline_id: TIMELINE_ID,
endpoint_id: ENDPOINT_ID.into(),
@@ -489,8 +489,16 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
}
fn delete_prefix_token(uri: &str) -> String {
use serde::Serialize;
let parts = uri.split("/").collect::<Vec<&str>>();
let claims = endpoint_storage::claims::DeletePrefixClaims {
#[derive(Serialize)]
struct PrefixClaims {
tenant_id: TenantId,
timeline_id: Option<TimelineId>,
endpoint_id: Option<endpoint_storage::EndpointId>,
exp: u64,
}
let claims = PrefixClaims {
tenant_id: parts.get(1).map(|c| c.parse().unwrap()).unwrap(),
timeline_id: parts.get(2).map(|c| c.parse().unwrap()),
endpoint_id: parts.get(3).map(ToString::to_string),

View File

@@ -1,52 +0,0 @@
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use utils::id::{EndpointId, TenantId, TimelineId};
/// Claims to add, remove, or retrieve endpoint data. Used by compute_ctl
#[derive(Deserialize, Serialize, PartialEq)]
pub struct EndpointStorageClaims {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub endpoint_id: EndpointId,
pub exp: u64,
}
/// Claims to remove tenant, timeline, or endpoint data. Used by control plane
#[derive(Deserialize, Serialize, PartialEq)]
pub struct DeletePrefixClaims {
pub tenant_id: TenantId,
/// None when tenant is deleted (endpoint_id is also None in this case)
pub timeline_id: Option<TimelineId>,
/// None when timeline is deleted
pub endpoint_id: Option<EndpointId>,
pub exp: u64,
}
impl Display for EndpointStorageClaims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"EndpointClaims(tenant_id={} timeline_id={} endpoint_id={} exp={})",
self.tenant_id, self.timeline_id, self.endpoint_id, self.exp
)
}
}
impl Display for DeletePrefixClaims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"DeletePrefixClaims(tenant_id={} timeline_id={} endpoint_id={}, exp={})",
self.tenant_id,
self.timeline_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string()),
self.endpoint_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string()),
self.exp
)
}
}

View File

@@ -1,5 +1,3 @@
pub mod claims;
use crate::claims::{DeletePrefixClaims, EndpointStorageClaims};
use anyhow::Result;
use axum::extract::{FromRequestParts, Path};
use axum::response::{IntoResponse, Response};
@@ -15,7 +13,7 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use utils::id::{EndpointId, TenantId, TimelineId};
use utils::id::{TenantId, TimelineId};
// simplified version of utils::auth::JwtAuth
pub struct JwtAuth {
@@ -81,6 +79,26 @@ pub struct Storage {
pub max_upload_file_limit: usize,
}
pub type EndpointId = String; // If needed, reuse small string from proxy/src/types.rc
#[derive(Deserialize, Serialize, PartialEq)]
pub struct Claims {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub endpoint_id: EndpointId,
pub exp: u64,
}
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Claims(tenant_id {} timeline_id {} endpoint_id {} exp {})",
self.tenant_id, self.timeline_id, self.endpoint_id, self.exp
)
}
}
#[derive(Deserialize, Serialize)]
struct KeyRequest {
tenant_id: TenantId,
@@ -89,13 +107,6 @@ struct KeyRequest {
path: String,
}
#[derive(Deserialize, Serialize, PartialEq)]
struct PrefixKeyRequest {
tenant_id: TenantId,
timeline_id: Option<TimelineId>,
endpoint_id: Option<EndpointId>,
}
#[derive(Debug, PartialEq)]
pub struct S3Path {
pub path: RemotePath,
@@ -154,7 +165,7 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| bad_request(e, "invalid token"))?;
let claims: EndpointStorageClaims = state
let claims: Claims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "decoding token"))?;
@@ -167,7 +178,7 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
path.endpoint_id.clone()
};
let route = EndpointStorageClaims {
let route = Claims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,
endpoint_id,
@@ -182,13 +193,38 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
}
}
#[derive(Deserialize, Serialize, PartialEq)]
pub struct PrefixKeyPath {
pub tenant_id: TenantId,
pub timeline_id: Option<TimelineId>,
pub endpoint_id: Option<EndpointId>,
}
impl Display for PrefixKeyPath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PrefixKeyPath(tenant_id {} timeline_id {} endpoint_id {})",
self.tenant_id,
self.timeline_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string()),
self.endpoint_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string())
)
}
}
#[derive(Debug, PartialEq)]
pub struct PrefixS3Path {
pub path: RemotePath,
}
impl From<&DeletePrefixClaims> for PrefixS3Path {
fn from(path: &DeletePrefixClaims) -> Self {
impl From<&PrefixKeyPath> for PrefixS3Path {
fn from(path: &PrefixKeyPath) -> Self {
let timeline_id = path
.timeline_id
.as_ref()
@@ -214,27 +250,21 @@ impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
state: &Arc<Storage>,
) -> Result<Self, Self::Rejection> {
let Path(path) = parts
.extract::<Path<PrefixKeyRequest>>()
.extract::<Path<PrefixKeyPath>>()
.await
.map_err(|e| bad_request(e, "invalid route"))?;
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| bad_request(e, "invalid token"))?;
let claims: DeletePrefixClaims = state
let claims: PrefixKeyPath = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "invalid token"))?;
let route = DeletePrefixClaims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,
endpoint_id: path.endpoint_id,
exp: claims.exp,
};
if route != claims {
return Err(unauthorized(route, claims));
if path != claims {
return Err(unauthorized(path, claims));
}
Ok((&route).into())
Ok((&path).into())
}
}
@@ -267,7 +297,7 @@ mod tests {
#[test]
fn s3_path() {
let auth = EndpointStorageClaims {
let auth = Claims {
tenant_id: TENANT_ID,
timeline_id: TIMELINE_ID,
endpoint_id: ENDPOINT_ID.into(),
@@ -297,11 +327,10 @@ mod tests {
#[test]
fn prefix_s3_path() {
let mut path = DeletePrefixClaims {
let mut path = PrefixKeyPath {
tenant_id: TENANT_ID,
timeline_id: None,
endpoint_id: None,
exp: 0,
};
let prefix_path = |s: String| RemotePath::from_string(&s).unwrap();
assert_eq!(

View File

@@ -1,58 +1,16 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
use std::str::FromStr;
use serde::{Deserialize, Serialize};
use crate::privilege::Privilege;
use crate::responses::ComputeCtlConfig;
use crate::spec::{ComputeSpec, ExtVersion, PgIdent};
/// The value to place in the [`ComputeClaims::audience`] claim.
pub static COMPUTE_AUDIENCE: &str = "compute";
/// Available scopes for a compute's JWT.
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ComputeClaimsScope {
/// An admin-scoped token allows access to all of `compute_ctl`'s authorized
/// facilities.
Admin,
}
impl FromStr for ComputeClaimsScope {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"admin" => Ok(ComputeClaimsScope::Admin),
_ => Err(anyhow::anyhow!("invalid compute claims scope \"{s}\"")),
}
}
}
/// When making requests to the `compute_ctl` external HTTP server, the client
/// must specify a set of claims in `Authorization` header JWTs such that
/// `compute_ctl` can authorize the request.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename = "snake_case")]
pub struct ComputeClaims {
/// The compute ID that will validate the token. The only case in which this
/// can be [`None`] is if [`Self::scope`] is
/// [`ComputeClaimsScope::Admin`].
pub compute_id: Option<String>,
/// The scope of what the token authorizes.
pub scope: Option<ComputeClaimsScope>,
/// The recipient the token is intended for.
///
/// See [RFC 7519](https://www.rfc-editor.org/rfc/rfc7519#section-4.1.3) for
/// more information.
///
/// TODO: Remove the [`Option`] wrapper when control plane learns to send
/// the claim.
#[serde(rename = "aud")]
pub audience: Option<Vec<String>>,
pub compute_id: String,
}
/// Request of the /configure API

View File

@@ -46,30 +46,6 @@ pub struct ExtensionInstallResponse {
pub version: ExtVersion,
}
#[derive(Serialize, Default, Debug, Clone)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcPrewarmState {
#[default]
NotPrewarmed,
Prewarming,
Completed,
Failed {
error: String,
},
}
#[derive(Serialize, Default, Debug, Clone)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcOffloadState {
#[default]
NotOffloaded,
Offloading,
Completed,
Failed {
error: String,
},
}
/// Response of the /status API
#[derive(Serialize, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]

View File

@@ -172,15 +172,6 @@ pub struct ComputeSpec {
/// Hostname and the port of the otel collector. Leave empty to disable Postgres logs forwarding.
/// Example: config-shy-breeze-123-collector-monitoring.neon-telemetry.svc.cluster.local:10514
pub logs_export_host: Option<String>,
/// Address of endpoint storage service
pub endpoint_storage_addr: Option<String>,
/// JWT for authorizing requests to endpoint storage service
pub endpoint_storage_token: Option<String>,
/// If true, download LFC state from endpoint_storage and pass it to Postgres on startup
#[serde(default)]
pub prewarm_lfc_on_startup: bool,
}
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.

View File

@@ -84,11 +84,6 @@
"value": "on",
"vartype": "bool"
},
{
"name": "prewarm_lfc_on_startup",
"value": "off",
"vartype": "bool"
},
{
"name": "neon.safekeepers",
"value": "127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501",

View File

@@ -16,7 +16,6 @@ pub struct Collector {
const NMETRICS: usize = 2;
static CLK_TCK_F64: Lazy<f64> = Lazy::new(|| {
// SAFETY: libc::sysconf is safe, it merely returns a value.
let long = unsafe { libc::sysconf(libc::_SC_CLK_TCK) };
if long == -1 {
panic!("sysconf(_SC_CLK_TCK) failed");

View File

@@ -43,21 +43,6 @@ pub struct NodeMetadata {
pub other: HashMap<String, serde_json::Value>,
}
/// PostHog integration config
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PostHogConfig {
/// PostHog project ID
project_id: String,
/// Server-side (private) API key
server_api_key: String,
/// Client-side (public) API key
client_api_key: String,
/// Private API URL
private_api_url: String,
/// Public API URL
public_api_url: String,
}
/// `pageserver.toml`
///
/// We use serde derive with `#[serde(default)]` to generate a deserializer
@@ -197,8 +182,6 @@ pub struct ConfigToml {
pub tracing: Option<Tracing>,
pub enable_tls_page_service_api: bool,
pub dev_mode: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub posthog_config: Option<PostHogConfig>,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -676,7 +659,6 @@ impl Default for ConfigToml {
tracing: None,
enable_tls_page_service_api: false,
dev_mode: false,
posthog_config: None,
}
}
}

View File

@@ -295,9 +295,6 @@ pub struct TenantId(Id);
id_newtype!(TenantId);
/// If needed, reuse small string from proxy/src/types.rc
pub type EndpointId = String;
// A pair uniquely identifying Neon instance.
#[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TenantTimelineId {

View File

@@ -14,7 +14,7 @@ use std::time::Duration;
use anyhow::{Context, bail, ensure};
use camino::{Utf8Path, Utf8PathBuf};
use once_cell::sync::OnceCell;
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig};
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes};
use pageserver_api::models::ImageCompressionAlgorithm;
use pageserver_api::shard::TenantShardId;
use pem::Pem;
@@ -230,9 +230,6 @@ pub struct PageServerConf {
/// such as authentication requirements for HTTP and PostgreSQL APIs.
/// This is insecure and should only be used in development environments.
pub dev_mode: bool,
/// PostHog integration config
pub posthog_config: Option<PostHogConfig>,
}
/// Token for authentication to safekeepers
@@ -407,7 +404,6 @@ impl PageServerConf {
tracing,
enable_tls_page_service_api,
dev_mode,
posthog_config,
} = config_toml;
let mut conf = PageServerConf {
@@ -517,7 +513,6 @@ impl PageServerConf {
}
None => Vec::new(),
},
posthog_config,
};
// ------------------------------------------------------------

View File

@@ -1441,6 +1441,14 @@ impl DeltaLayerInner {
offset
}
pub fn iter<'a>(&'a self, ctx: &'a RequestContext) -> DeltaLayerIterator<'a> {
self.iter_with_options(
ctx,
1024 * 8192, // The default value. Unit tests might use a different value. 1024 * 8K = 8MB buffer.
1024, // The default value. Unit tests might use a different value
)
}
pub fn iter_with_options<'a>(
&'a self,
ctx: &'a RequestContext,
@@ -1626,6 +1634,7 @@ pub(crate) mod test {
use crate::tenant::disk_btree::tests::TestDisk;
use crate::tenant::harness::{TIMELINE_ID, TenantHarness};
use crate::tenant::storage_layer::{Layer, ResidentLayer};
use crate::tenant::vectored_blob_io::StreamingVectoredReadPlanner;
use crate::tenant::{TenantShard, Timeline};
/// Construct an index for a fictional delta layer and and then
@@ -2302,7 +2311,8 @@ pub(crate) mod test {
for batch_size in [1, 2, 4, 8, 3, 7, 13] {
println!("running with batch_size={batch_size} max_read_size={max_read_size}");
// Test if the batch size is correctly determined
let mut iter = delta_layer.iter_with_options(&ctx, max_read_size, batch_size);
let mut iter = delta_layer.iter(&ctx);
iter.planner = StreamingVectoredReadPlanner::new(max_read_size, batch_size);
let mut num_items = 0;
for _ in 0..3 {
iter.next_batch().await.unwrap();
@@ -2319,7 +2329,8 @@ pub(crate) mod test {
iter.key_values_batch.clear();
}
// Test if the result is correct
let mut iter = delta_layer.iter_with_options(&ctx, max_read_size, batch_size);
let mut iter = delta_layer.iter(&ctx);
iter.planner = StreamingVectoredReadPlanner::new(max_read_size, batch_size);
assert_delta_iter_equal(&mut iter, &test_deltas).await;
}
}

View File

@@ -157,7 +157,7 @@ mod tests {
.await
.unwrap();
let merge_iter = MergeIterator::create_for_testing(
let merge_iter = MergeIterator::create(
&[resident_layer_1.get_as_delta(&ctx).await.unwrap()],
&[],
&ctx,
@@ -182,7 +182,7 @@ mod tests {
result.extend(test_deltas1[90..100].iter().cloned());
assert_filter_iter_equal(&mut filter_iter, &result).await;
let merge_iter = MergeIterator::create_for_testing(
let merge_iter = MergeIterator::create(
&[resident_layer_1.get_as_delta(&ctx).await.unwrap()],
&[],
&ctx,

View File

@@ -684,6 +684,14 @@ impl ImageLayerInner {
}
}
pub(crate) fn iter<'a>(&'a self, ctx: &'a RequestContext) -> ImageLayerIterator<'a> {
self.iter_with_options(
ctx,
1024 * 8192, // The default value. Unit tests might use a different value. 1024 * 8K = 8MB buffer.
1024, // The default value. Unit tests might use a different value
)
}
pub(crate) fn iter_with_options<'a>(
&'a self,
ctx: &'a RequestContext,
@@ -1232,6 +1240,7 @@ mod test {
use crate::context::RequestContext;
use crate::tenant::harness::{TIMELINE_ID, TenantHarness};
use crate::tenant::storage_layer::{Layer, ResidentLayer};
use crate::tenant::vectored_blob_io::StreamingVectoredReadPlanner;
use crate::tenant::{TenantShard, Timeline};
#[tokio::test]
@@ -1498,7 +1507,8 @@ mod test {
for batch_size in [1, 2, 4, 8, 3, 7, 13] {
println!("running with batch_size={batch_size} max_read_size={max_read_size}");
// Test if the batch size is correctly determined
let mut iter = img_layer.iter_with_options(&ctx, max_read_size, batch_size);
let mut iter = img_layer.iter(&ctx);
iter.planner = StreamingVectoredReadPlanner::new(max_read_size, batch_size);
let mut num_items = 0;
for _ in 0..3 {
iter.next_batch().await.unwrap();
@@ -1515,7 +1525,8 @@ mod test {
iter.key_values_batch.clear();
}
// Test if the result is correct
let mut iter = img_layer.iter_with_options(&ctx, max_read_size, batch_size);
let mut iter = img_layer.iter(&ctx);
iter.planner = StreamingVectoredReadPlanner::new(max_read_size, batch_size);
assert_img_iter_equal(&mut iter, &test_imgs, Lsn(0x10)).await;
}
}

View File

@@ -19,6 +19,14 @@ pub(crate) enum LayerRef<'a> {
}
impl<'a> LayerRef<'a> {
#[allow(dead_code)]
fn iter(self, ctx: &'a RequestContext) -> LayerIterRef<'a> {
match self {
Self::Image(x) => LayerIterRef::Image(x.iter(ctx)),
Self::Delta(x) => LayerIterRef::Delta(x.iter(ctx)),
}
}
fn iter_with_options(
self,
ctx: &'a RequestContext,
@@ -314,28 +322,6 @@ impl MergeIteratorItem for ((Key, Lsn, Value), Arc<PersistentLayerKey>) {
}
impl<'a> MergeIterator<'a> {
#[cfg(test)]
pub(crate) fn create_for_testing(
deltas: &[&'a DeltaLayerInner],
images: &[&'a ImageLayerInner],
ctx: &'a RequestContext,
) -> Self {
Self::create_with_options(deltas, images, ctx, 1024 * 8192, 1024)
}
/// Create a new merge iterator with custom options.
///
/// Adjust `max_read_size` and `max_batch_size` to trade memory usage for performance. The size should scale
/// with the number of layers to compact. If there are a lot of layers, consider reducing the values, so that
/// the buffer does not take too much memory.
///
/// The default options for L0 compactions are:
/// - max_read_size: 1024 * 8192 (8MB)
/// - max_batch_size: 1024
///
/// The default options for gc-compaction are:
/// - max_read_size: 128 * 8192 (1MB)
/// - max_batch_size: 128
pub fn create_with_options(
deltas: &[&'a DeltaLayerInner],
images: &[&'a ImageLayerInner],
@@ -365,6 +351,14 @@ impl<'a> MergeIterator<'a> {
}
}
pub fn create(
deltas: &[&'a DeltaLayerInner],
images: &[&'a ImageLayerInner],
ctx: &'a RequestContext,
) -> Self {
Self::create_with_options(deltas, images, ctx, 1024 * 8192, 1024)
}
pub(crate) async fn next_inner<R: MergeIteratorItem>(&mut self) -> anyhow::Result<Option<R>> {
while let Some(mut iter) = self.heap.peek_mut() {
if !iter.is_loaded() {
@@ -483,7 +477,7 @@ mod tests {
let resident_layer_2 = produce_delta_layer(&tenant, &tline, test_deltas2.clone(), &ctx)
.await
.unwrap();
let mut merge_iter = MergeIterator::create_for_testing(
let mut merge_iter = MergeIterator::create(
&[
resident_layer_2.get_as_delta(&ctx).await.unwrap(),
resident_layer_1.get_as_delta(&ctx).await.unwrap(),
@@ -555,7 +549,7 @@ mod tests {
let resident_layer_3 = produce_delta_layer(&tenant, &tline, test_deltas3.clone(), &ctx)
.await
.unwrap();
let mut merge_iter = MergeIterator::create_for_testing(
let mut merge_iter = MergeIterator::create(
&[
resident_layer_1.get_as_delta(&ctx).await.unwrap(),
resident_layer_2.get_as_delta(&ctx).await.unwrap(),
@@ -676,7 +670,7 @@ mod tests {
// Test with different layer order for MergeIterator::create to ensure the order
// is stable.
let mut merge_iter = MergeIterator::create_for_testing(
let mut merge_iter = MergeIterator::create(
&[
resident_layer_4.get_as_delta(&ctx).await.unwrap(),
resident_layer_1.get_as_delta(&ctx).await.unwrap(),
@@ -688,7 +682,7 @@ mod tests {
);
assert_merge_iter_equal(&mut merge_iter, &expect).await;
let mut merge_iter = MergeIterator::create_for_testing(
let mut merge_iter = MergeIterator::create(
&[
resident_layer_1.get_as_delta(&ctx).await.unwrap(),
resident_layer_4.get_as_delta(&ctx).await.unwrap(),

View File

@@ -1994,13 +1994,7 @@ impl Timeline {
let l = l.get_as_delta(ctx).await.map_err(CompactionError::Other)?;
deltas.push(l);
}
MergeIterator::create_with_options(
&deltas,
&[],
ctx,
1024 * 8192, /* 8 MiB buffer per layer iterator */
1024,
)
MergeIterator::create(&deltas, &[], ctx)
};
// This iterator walks through all keys and is needed to calculate size used by each key
@@ -2834,7 +2828,7 @@ impl Timeline {
Ok(())
}
/// Check to bail out of gc compaction early if it would use too much memory.
/// Check if the memory usage is within the limit.
async fn check_memory_usage(
self: &Arc<Self>,
layer_selection: &[Layer],
@@ -2847,8 +2841,7 @@ impl Timeline {
let layer_desc = layer.layer_desc();
if layer_desc.is_delta() {
// Delta layers at most have 1MB buffer; 3x to make it safe (there're deltas as large as 16KB).
// Scale it by target_layer_size_bytes so that tests can pass (some tests, e.g., `test_pageserver_gc_compaction_preempt
// use 3MB layer size and we need to account for that).
// Multiply the layer size so that tests can pass.
estimated_memory_usage_mb +=
3.0 * (layer_desc.file_size / target_layer_size_bytes) as f64;
num_delta_layers += 1;

View File

@@ -14,6 +14,8 @@
use std::fs::File;
use std::io::{Error, ErrorKind};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
#[cfg(target_os = "linux")]
use std::os::unix::fs::OpenOptionsExt;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
@@ -97,7 +99,7 @@ impl VirtualFile {
pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
path: P,
#[cfg_attr(not(target_os = "linux"), allow(unused_mut))] mut open_options: OpenOptions,
open_options: &OpenOptions,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let mode = get_io_mode();
@@ -110,16 +112,21 @@ impl VirtualFile {
#[cfg(target_os = "linux")]
(IoMode::DirectRw, _) => true,
};
if set_o_direct {
let open_options = open_options.clone();
let open_options = if set_o_direct {
#[cfg(target_os = "linux")]
{
open_options = open_options.custom_flags(nix::libc::O_DIRECT);
let mut open_options = open_options;
open_options.custom_flags(nix::libc::O_DIRECT);
open_options
}
#[cfg(not(target_os = "linux"))]
unreachable!(
"O_DIRECT is not supported on this platform, IoMode's that result in set_o_direct=true shouldn't even be defined"
);
}
} else {
open_options
};
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(VirtualFile { inner, _mode: mode })
}
@@ -523,7 +530,7 @@ impl VirtualFileInner {
path: P,
ctx: &RequestContext,
) -> Result<VirtualFileInner, std::io::Error> {
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true).clone(), ctx).await
}
/// Open a file with given options.
@@ -551,11 +558,10 @@ impl VirtualFileInner {
// It would perhaps be nicer to check just for the read and write flags
// explicitly, but OpenOptions doesn't contain any functions to read flags,
// only to set them.
let reopen_options = open_options
.clone()
.create(false)
.create_new(false)
.truncate(false);
let mut reopen_options = open_options.clone();
reopen_options.create(false);
reopen_options.create_new(false);
reopen_options.truncate(false);
let vfile = VirtualFileInner {
handle: RwLock::new(handle),
@@ -1301,7 +1307,7 @@ mod tests {
opts: OpenOptions,
ctx: &RequestContext,
) -> Result<MaybeVirtualFile, anyhow::Error> {
let vf = VirtualFile::open_with_options_v2(&path, opts, ctx).await?;
let vf = VirtualFile::open_with_options_v2(&path, &opts, ctx).await?;
Ok(MaybeVirtualFile::VirtualFile(vf))
}
}
@@ -1368,7 +1374,7 @@ mod tests {
let _ = file_a.read_string_at(0, 1, &ctx).await.unwrap_err();
// Close the file and re-open for reading
let mut file_a = A::open(path_a, OpenOptions::new().read(true), &ctx).await?;
let mut file_a = A::open(path_a, OpenOptions::new().read(true).to_owned(), &ctx).await?;
// cannot write to a file opened in read-only mode
let _ = file_a
@@ -1387,7 +1393,8 @@ mod tests {
.read(true)
.write(true)
.create(true)
.truncate(true),
.truncate(true)
.to_owned(),
&ctx,
)
.await?;
@@ -1405,7 +1412,12 @@ mod tests {
let mut vfiles = Vec::new();
for _ in 0..100 {
let mut vfile = A::open(path_b.clone(), OpenOptions::new().read(true), &ctx).await?;
let mut vfile = A::open(
path_b.clone(),
OpenOptions::new().read(true).to_owned(),
&ctx,
)
.await?;
assert_eq!("FOOBAR", vfile.read_string_at(0, 6, &ctx).await?);
vfiles.push(vfile);
}
@@ -1454,7 +1466,7 @@ mod tests {
for _ in 0..VIRTUAL_FILES {
let f = VirtualFileInner::open_with_options(
&test_file_path,
OpenOptions::new().read(true),
OpenOptions::new().read(true).clone(),
&ctx,
)
.await?;

View File

@@ -1,7 +1,6 @@
//! Enum-dispatch to the `OpenOptions` type of the respective [`super::IoEngineKind`];
use std::os::fd::OwnedFd;
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use super::io_engine::IoEngine;
@@ -44,7 +43,7 @@ impl OpenOptions {
self.write
}
pub fn read(mut self, read: bool) -> Self {
pub fn read(&mut self, read: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.read(read);
@@ -57,7 +56,7 @@ impl OpenOptions {
self
}
pub fn write(mut self, write: bool) -> Self {
pub fn write(&mut self, write: bool) -> &mut OpenOptions {
self.write = write;
match &mut self.inner {
Inner::StdFs(x) => {
@@ -71,7 +70,7 @@ impl OpenOptions {
self
}
pub fn create(mut self, create: bool) -> Self {
pub fn create(&mut self, create: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.create(create);
@@ -84,7 +83,7 @@ impl OpenOptions {
self
}
pub fn create_new(mut self, create_new: bool) -> Self {
pub fn create_new(&mut self, create_new: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.create_new(create_new);
@@ -97,7 +96,7 @@ impl OpenOptions {
self
}
pub fn truncate(mut self, truncate: bool) -> Self {
pub fn truncate(&mut self, truncate: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.truncate(truncate);
@@ -125,8 +124,10 @@ impl OpenOptions {
}
}
}
}
pub fn mode(mut self, mode: u32) -> Self {
impl std::os::unix::prelude::OpenOptionsExt for OpenOptions {
fn mode(&mut self, mode: u32) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.mode(mode);
@@ -139,7 +140,7 @@ impl OpenOptions {
self
}
pub fn custom_flags(mut self, flags: i32) -> Self {
fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.custom_flags(flags);

View File

@@ -150,7 +150,7 @@ NeonWALReaderFree(NeonWALReader *state)
* fetched from timeline 'tli'.
*
* Returns NEON_WALREAD_SUCCESS if succeeded, NEON_WALREAD_ERROR if an error
* occurs, in which case 'err' has the description. Error always closes remote
* occurs, in which case 'err' has the desciption. Error always closes remote
* connection, if there was any, so socket subscription should be removed.
*
* NEON_WALREAD_WOULDBLOCK means caller should obtain socket to wait for with

View File

@@ -1989,14 +1989,8 @@ neon_start_unlogged_build(SMgrRelation reln)
neon_log(ERROR, "unknown relpersistence '%c'", reln->smgr_relpersistence);
}
#if PG_MAJORVERSION_NUM >= 17
/*
* We have to disable this check for pg14-16 because sorted build of GIST index requires
* to perform unlogged build several times
*/
if (smgrnblocks(reln, MAIN_FORKNUM) != 0)
neon_log(ERROR, "cannot perform unlogged index build, index is not empty ");
#endif
unlogged_build_rel = reln;
unlogged_build_phase = UNLOGGED_BUILD_PHASE_1;

View File

@@ -124,7 +124,6 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
}
else
{
wp->safekeepers_generation = INVALID_GENERATION;
host = wp->config->safekeepers_list;
}
wp_log(LOG, "safekeepers_generation=%u", wp->safekeepers_generation);
@@ -757,7 +756,7 @@ UpdateMemberSafekeeperPtr(WalProposer *wp, Safekeeper *sk)
{
SafekeeperId *sk_id = &wp->mconf.members.m[i];
if (sk_id->node_id == sk->greetResponse.nodeId)
if (wp->mconf.members.m[i].node_id == sk->greetResponse.nodeId)
{
/*
* If mconf or list of safekeepers to connect to changed (the
@@ -782,7 +781,7 @@ UpdateMemberSafekeeperPtr(WalProposer *wp, Safekeeper *sk)
{
SafekeeperId *sk_id = &wp->mconf.new_members.m[i];
if (sk_id->node_id == sk->greetResponse.nodeId)
if (wp->mconf.new_members.m[i].node_id == sk->greetResponse.nodeId)
{
if (wp->new_members_safekeepers[i] != NULL && wp->new_members_safekeepers[i] != sk)
{
@@ -1072,6 +1071,7 @@ RecvVoteResponse(Safekeeper *sk)
/* ready for elected message */
sk->state = SS_WAIT_ELECTED;
wp->n_votes++;
/* Are we already elected? */
if (wp->state == WPS_CAMPAIGN)
{

View File

@@ -845,6 +845,9 @@ typedef struct WalProposer
/* timeline globally starts at this LSN */
XLogRecPtr timelineStartLsn;
/* number of votes collected from safekeepers */
int n_votes;
/* number of successful connections over the lifetime of walproposer */
int n_connected;

View File

@@ -409,22 +409,14 @@ impl JwkCacheEntryLock {
if let Some(exp) = payload.expiration {
if now >= exp + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired));
}
}
if let Some(nbf) = payload.not_before {
if nbf >= now + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
JwtClaimsError::JwtTokenNotYetReadyToUse,
));
}
}
@@ -542,10 +534,10 @@ struct JwtPayload<'a> {
#[serde(rename = "aud", default)]
audience: OneOrMany,
/// Expiration - Time after which the JWT expires
#[serde(rename = "exp", deserialize_with = "numeric_date_opt", default)]
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
expiration: Option<SystemTime>,
/// Not before - Time before which the JWT is not valid
#[serde(rename = "nbf", deserialize_with = "numeric_date_opt", default)]
/// Not before - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
not_before: Option<SystemTime>,
// the following entries are only extracted for the sake of debug logging.
@@ -617,15 +609,8 @@ impl<'de> Deserialize<'de> for OneOrMany {
}
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
<Option<u64>>::deserialize(d)?
.map(|t| {
SystemTime::UNIX_EPOCH
.checked_add(Duration::from_secs(t))
.ok_or_else(|| {
serde::de::Error::custom(format_args!("timestamp out of bounds: {t}"))
})
})
.transpose()
let d = <Option<u64>>::deserialize(d)?;
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
}
struct JwkRenewalPermit<'a> {
@@ -761,11 +746,11 @@ pub enum JwtClaimsError {
#[error("invalid JWT token audience")]
InvalidJwtTokenAudience,
#[error("JWT token has expired (exp={0})")]
JwtTokenHasExpired(u64),
#[error("JWT token has expired")]
JwtTokenHasExpired,
#[error("JWT token is not yet ready to use (nbf={0})")]
JwtTokenNotYetReadyToUse(u64),
#[error("JWT token is not yet ready to use")]
JwtTokenNotYetReadyToUse,
}
#[allow(dead_code, reason = "Debug use only")]
@@ -1248,14 +1233,14 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
"nbf": now + 60,
"aud": "neon",
}},
error: JwtClaimsError::JwtTokenNotYetReadyToUse(now + 60),
error: JwtClaimsError::JwtTokenNotYetReadyToUse,
},
Test {
body: json! {{
"exp": now - 60,
"aud": ["neon"],
}},
error: JwtClaimsError::JwtTokenHasExpired(now - 60),
error: JwtClaimsError::JwtTokenHasExpired,
},
Test {
body: json! {{

View File

@@ -32,6 +32,12 @@ pub(crate) enum ComputeUserInfoParseError {
option: EndpointId,
},
#[error(
"Common name inferred from SNI ('{}') is not known",
.cn,
)]
UnknownCommonName { cn: String },
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
MalformedProjectName(EndpointId),
}
@@ -60,15 +66,22 @@ impl ComputeUserInfoMaybeEndpoint {
}
}
pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
let (subdomain, common_name) = sni.split_once('.')?;
pub(crate) fn endpoint_sni(
sni: &str,
common_names: &HashSet<String>,
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
let Some((subdomain, common_name)) = sni.split_once('.') else {
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
};
if !common_names.contains(common_name) {
return None;
return Err(ComputeUserInfoParseError::UnknownCommonName {
cn: common_name.into(),
});
}
if subdomain == SERVERLESS_DRIVER_SNI {
return None;
return Ok(None);
}
Some(EndpointId::from(subdomain))
Ok(Some(EndpointId::from(subdomain)))
}
impl ComputeUserInfoMaybeEndpoint {
@@ -100,8 +113,15 @@ impl ComputeUserInfoMaybeEndpoint {
})
.map(|name| name.into());
let endpoint_from_domain =
sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn)));
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
endpoint_sni(sni_str, cn)?
} else {
None
}
} else {
None
};
let endpoint = match (endpoint_option, endpoint_from_domain) {
// Invariant: if we have both project name variants, they should match.
@@ -404,34 +424,21 @@ mod tests {
}
#[test]
fn parse_unknown_sni() {
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestContext::test();
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.unwrap();
assert!(info.endpoint_id.is_none());
}
#[test]
fn parse_unknown_sni_with_options() {
let options = StartupMessageParams::new([
("user", "john_doe"),
("options", "endpoint=foo-bar-baz-1234"),
]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestContext::test();
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.unwrap();
assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234"));
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
}
_ => panic!("bad error: {err:?}"),
}
}
#[test]

View File

@@ -24,6 +24,9 @@ pub(crate) enum HandshakeError {
#[error("protocol violation")]
ProtocolViolation,
#[error("missing certificate")]
MissingCertificate,
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
@@ -39,6 +42,10 @@ impl ReportableError for HandshakeError {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
@@ -139,7 +146,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// try parse endpoint
let ep = conn_info
.server_name()
.and_then(|sni| endpoint_sni(sni, &tls.common_names));
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
if let Some(ep) = ep {
ctx.set_endpoint_id(ep);
}
@@ -154,8 +161,10 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(conn_info.server_name())
.ok_or(HandshakeError::MissingCertificate)?;
stream = PqStream {
framed: Framed {

View File

@@ -98,7 +98,8 @@ fn generate_tls_config<'a>(
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone_key())?;
let cert_resolver = CertResolver::new(key, vec![cert])?;
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
let common_names = cert_resolver.get_common_names();

View File

@@ -41,7 +41,7 @@ use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
pub(crate) const EXT_VERSION: &str = "0.3.1";
pub(crate) const EXT_VERSION: &str = "0.3.0";
pub(crate) const EXT_SCHEMA: &str = "auth";
#[derive(Clone)]

View File

@@ -199,7 +199,8 @@ fn get_conn_info(
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
endpoint_sni(hostname, &tls.common_names)?
.ok_or(ConnInfoError::MalformedEndpoint)?
} else {
hostname
.split_once('.')

View File

@@ -5,7 +5,6 @@ use anyhow::{Context, bail};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::sign::CertifiedKey;
use x509_cert::der::{Reader, SliceReader};
use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
@@ -26,8 +25,10 @@ pub fn configure_tls(
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
@@ -39,8 +40,11 @@ pub fn configure_tls(
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver
.add_cert_path(&key_path.to_string_lossy(), &cert_path.to_string_lossy())?;
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
@@ -79,42 +83,92 @@ pub fn configure_tls(
})
}
#[derive(Debug)]
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint),
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
fn parse_new(key_path: &str, cert_path: &str) -> anyhow::Result<Self> {
let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
Self::new(priv_key, cert_chain)
pub fn new() -> Self {
Self::default()
}
pub fn new(
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
) -> anyhow::Result<Self> {
let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let mut certs = HashMap::new();
let default = (cert.clone(), tls_server_end_point);
certs.insert(common_name, (cert, tls_server_end_point));
Ok(Self { certs, default })
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
fn add_cert_path(&mut self, key_path: &str, cert_path: &str) -> anyhow::Result<()> {
let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
self.add_cert(priv_key, cert_chain)
}
fn add_cert(
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let certificate = SliceReader::new(first_cert)
.context("Failed to parse cerficiate")?
.decode::<x509_cert::Certificate>()
.context("Failed to parse cerficiate")?;
let common_name = certificate.tbs_certificate.subject.to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
@@ -123,82 +177,12 @@ impl CertResolver {
}
}
fn parse_key_cert(
key_path: &str,
cert_path: &str,
) -> anyhow::Result<(PrivateKeyDer<'static>, Vec<CertificateDer<'static>>)> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
)
})?
};
Ok((priv_key, cert_chain))
}
fn process_key_cert(
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
) -> anyhow::Result<(String, Arc<CertifiedKey>, TlsServerEndPoint)> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let certificate = SliceReader::new(first_cert)
.context("Failed to parse cerficiate")?
.decode::<x509_cert::Certificate>()
.context("Failed to parse cerficiate")?;
let common_name = certificate.tbs_certificate.subject.to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
Ok((common_name, cert, tls_server_end_point))
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
Some(self.resolve(client_hello.server_name()).0)
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
@@ -206,7 +190,7 @@ impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint) {
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
@@ -216,17 +200,12 @@ impl CertResolver {
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return cert.clone();
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
// The customer has some custom DNS mapping - just return
// a default certificate.
//
// This will error if the customer uses anything stronger
// than sslmode=require. That's a choice they can make.
return self.default.clone();
return None;
}
}
} else {

View File

@@ -5181,8 +5181,7 @@ impl Service {
}
// We don't expect any new_shard_count shards to exist here, but drop them just in case
tenants
.retain(|id, s| !(id.tenant_id == *tenant_id && s.shard.count == *new_shard_count));
tenants.retain(|_id, s| s.shard.count != *new_shard_count);
detach_locations
};

View File

@@ -1,4 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::SystemTime;
use futures_util::StreamExt;
@@ -55,7 +56,7 @@ impl TimelineAnalysis {
pub(crate) async fn branch_cleanup_and_check_errors(
remote_client: &GenericRemoteStorage,
id: &TenantShardTimelineId,
tenant_objects: &mut TenantObjectListing,
tenant_objects: Arc<tokio::sync::Mutex<TenantObjectListing>>,
s3_active_branch: Option<&BranchData>,
console_branch: Option<BranchData>,
s3_data: Option<RemoteTimelineBlobData>,
@@ -150,7 +151,11 @@ pub(crate) async fn branch_cleanup_and_check_errors(
))
}
if !tenant_objects.check_ref(id.timeline_id, &layer, &metadata) {
if !tenant_objects
.lock()
.await
.check_ref(id.timeline_id, &layer, &metadata)
{
let path = remote_layer_path(
&id.tenant_shard_id.tenant_id,
&id.timeline_id,

View File

@@ -73,8 +73,12 @@ enum Command {
node_kind: NodeKind,
#[arg(short, long, default_value_t = false)]
json: bool,
/// If provided, only these tenants will be listed from the remote storage.
#[arg(long = "tenant-id", num_args = 0..)]
tenant_ids: Vec<TenantShardId>,
/// If provided, we will list all tenants, but then filter with the prefix.
#[arg(long = "tenant-id-prefix")]
tenant_id_prefix: Option<TenantId>,
#[arg(long = "post", default_value_t = false)]
post_to_storcon: bool,
#[arg(long, default_value = None)]
@@ -178,6 +182,7 @@ async fn main() -> anyhow::Result<()> {
Command::ScanMetadata {
json,
tenant_ids,
tenant_id_prefix,
node_kind,
post_to_storcon,
dump_db_connstr,
@@ -186,6 +191,9 @@ async fn main() -> anyhow::Result<()> {
verbose,
} => {
if let NodeKind::Safekeeper = node_kind {
if tenant_id_prefix.is_some() {
bail!("`tenant_id_prefix` is not supported for safekeeper node_kind");
}
let db_or_list = match (timeline_lsns, dump_db_connstr) {
(Some(timeline_lsns), _) => {
let timeline_lsns = serde_json::from_str(&timeline_lsns)
@@ -227,6 +235,7 @@ async fn main() -> anyhow::Result<()> {
bucket_config,
controller_client.as_ref(),
tenant_ids,
tenant_id_prefix,
json,
post_to_storcon,
verbose,
@@ -338,6 +347,7 @@ pub async fn run_cron_job(
bucket_config,
controller_client,
Vec::new(),
None,
true,
post_to_storcon,
false, // default to non-verbose mode
@@ -384,10 +394,12 @@ pub async fn pageserver_physical_gc_cmd(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub async fn scan_pageserver_metadata_cmd(
bucket_config: BucketConfig,
controller_client: Option<&control_api::Client>,
tenant_shard_ids: Vec<TenantShardId>,
tenant_id_prefix: Option<TenantId>,
json: bool,
post_to_storcon: bool,
verbose: bool,
@@ -398,7 +410,14 @@ pub async fn scan_pageserver_metadata_cmd(
"Posting pageserver scan health status to storage controller requires `--controller-api` and `--controller-jwt` to run"
));
}
match scan_pageserver_metadata(bucket_config.clone(), tenant_shard_ids, verbose).await {
match scan_pageserver_metadata(
bucket_config.clone(),
tenant_shard_ids,
tenant_id_prefix,
verbose,
)
.await
{
Err(e) => {
tracing::error!("Failed: {e}");
Err(e)

View File

@@ -1,5 +1,7 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use futures::SinkExt;
use futures_util::{StreamExt, TryStreamExt};
use pageserver::tenant::remote_timeline_client::remote_layer_path;
use pageserver_api::controller_api::MetadataHealthUpdateRequest;
@@ -7,6 +9,7 @@ use pageserver_api::shard::TenantShardId;
use remote_storage::GenericRemoteStorage;
use serde::Serialize;
use tracing::{Instrument, info_span};
use utils::generation::Generation;
use utils::id::TenantId;
use utils::shard::ShardCount;
@@ -14,10 +17,12 @@ use crate::checks::{
BlobDataParseResult, RemoteTimelineBlobData, TenantObjectListing, TimelineAnalysis,
branch_cleanup_and_check_errors, list_timeline_blobs,
};
use crate::metadata_stream::{stream_tenant_timelines, stream_tenants};
use crate::metadata_stream::{
stream_tenant_timelines, stream_tenants, stream_tenants_maybe_prefix,
};
use crate::{BucketConfig, NodeKind, RootTarget, TenantShardTimelineId, init_remote};
#[derive(Serialize, Default)]
#[derive(Serialize, Default, Clone)]
pub struct MetadataSummary {
tenant_count: usize,
timeline_count: usize,
@@ -102,13 +107,13 @@ impl MetadataSummary {
format!(
"Tenants: {}
Timelines: {}
Timeline-shards: {}
With errors: {}
With warnings: {}
With orphan layers: {}
Index versions: {version_summary}
",
Timelines: {}
Timeline-shards: {}
With errors: {}
With warnings: {}
With orphan layers: {}
Index versions: {version_summary}
",
self.tenant_count,
self.timeline_count,
self.timeline_shard_count,
@@ -138,24 +143,243 @@ Index versions: {version_summary}
pub async fn scan_pageserver_metadata(
bucket_config: BucketConfig,
tenant_ids: Vec<TenantShardId>,
tenant_id_prefix: Option<TenantId>,
verbose: bool,
) -> anyhow::Result<MetadataSummary> {
let (remote_client, target) = init_remote(bucket_config, NodeKind::Pageserver).await?;
let tenants = if tenant_ids.is_empty() {
futures::future::Either::Left(stream_tenants(&remote_client, &target))
} else {
futures::future::Either::Right(futures::stream::iter(tenant_ids.into_iter().map(Ok)))
};
if !tenant_ids.is_empty() && tenant_id_prefix.is_some() {
anyhow::bail!("`tenant_id_prefix` is not supported when `tenant_ids` is provided");
}
// How many tenants to process in parallel. We need to be mindful of pageservers
// accessing the same per tenant prefixes, so use a lower setting than pageservers.
const CONCURRENCY: usize = 32;
let (mut list_tenants_tx, list_tenants_rx) = futures::channel::mpsc::channel(1);
let remote_client_inner = remote_client.clone();
let target_inner = target.clone();
let list_tenants = tokio::spawn(async move {
let mut cnt = 0;
if tenant_ids.is_empty() {
if let Some(tenant_id_prefix) = tenant_id_prefix {
let stream = stream_tenants_maybe_prefix(
&remote_client_inner,
&target_inner,
Some(tenant_id_prefix.to_string()),
);
let mut stream = Box::pin(stream);
while let Some(tenant) = stream.next().await {
let tenant = tenant?;
list_tenants_tx.send(tenant).await?;
cnt += 1;
}
} else {
let stream = stream_tenants(&remote_client_inner, &target_inner);
let mut stream = Box::pin(stream);
while let Some(tenant) = stream.next().await {
let tenant = tenant?;
list_tenants_tx.send(tenant).await?;
cnt += 1;
}
}
} else {
for tenant_id in tenant_ids {
list_tenants_tx.send(tenant_id).await?;
cnt += 1;
}
}
tracing::info!("list_tenants: collected {} tenants", cnt);
Ok::<_, anyhow::Error>(())
});
// Generate a stream of TenantTimelineId
let timelines = tenants.map_ok(|t| stream_tenant_timelines(&remote_client, &target, t));
let timelines = timelines.try_buffered(CONCURRENCY);
let timelines = timelines.try_flatten();
let (mut list_timelines_tx, list_timelines_rx) = futures::channel::mpsc::channel(1);
let remote_client_inner = remote_client.clone();
let target_inner = target.clone();
let list_timelines = tokio::spawn(async move {
let stream = list_tenants_rx
.map(|tenant_id| {
stream_tenant_timelines(&remote_client_inner, &target_inner, tenant_id)
})
.buffered(8)
.try_flatten();
let mut stream = Box::pin(stream);
while let Some(item) = stream.next().await {
let item = item?;
list_timelines_tx.send(item).await?;
}
Ok::<_, anyhow::Error>(())
});
let (mut read_timelines_tx, read_timelines_rx) = futures::channel::mpsc::channel(1);
let remote_client_inner = remote_client.clone();
let target_inner = target.clone();
let read_timelines = tokio::spawn(async move {
let stream = list_timelines_rx
.map(|ttid| report_on_timeline(&remote_client_inner, &target_inner, ttid))
.buffered(32);
let mut stream = Box::pin(stream);
while let Some(item) = stream.next().await {
let item = item?;
read_timelines_tx.send(item).await?;
}
Ok::<_, anyhow::Error>(())
});
let summary = Arc::new(tokio::sync::Mutex::new(MetadataSummary::new()));
let summary_inner = summary.clone();
let (mut consolidate_tenants_tx, consolidate_tenants_rx) = futures::channel::mpsc::channel(32);
let consolidate_tenants = tokio::spawn(async move {
// We must gather all the TenantShardTimelineId->S3TimelineBlobData for each tenant, because different
// shards in the same tenant might refer to one anothers' keys if a shard split has happened.
let mut tenant_id = None;
let mut tenant_objects = TenantObjectListing::default();
let mut tenant_timeline_results = Vec::new();
// Iterate through all the timeline results. These are in key-order, so
// all results for the same tenant will be adjacent. We accumulate these,
// and then call `analyze_tenant` to flush, when we see the next tenant ID.
let mut highest_shard_count = ShardCount::MIN;
let mut read_timelines_rx = read_timelines_rx;
while let Some(i) = read_timelines_rx.next().await {
let (ttid, data) = i;
{
let mut guard = summary_inner.lock().await;
guard.update_data(&data);
}
match tenant_id {
Some(prev_tenant_id) => {
if prev_tenant_id != ttid.tenant_shard_id.tenant_id {
// New tenant: analyze this tenant's timelines, clear accumulated tenant_timeline_results
let tenant_objects = std::mem::take(&mut tenant_objects);
let timelines = std::mem::take(&mut tenant_timeline_results);
analyze_tenant(
summary_inner.clone(),
Arc::new(tokio::sync::Mutex::new(tenant_objects)),
timelines,
highest_shard_count,
&mut consolidate_tenants_tx,
)
.await?;
tenant_id = Some(ttid.tenant_shard_id.tenant_id);
highest_shard_count = ttid.tenant_shard_id.shard_count;
} else {
highest_shard_count =
highest_shard_count.max(ttid.tenant_shard_id.shard_count);
}
}
None => {
tenant_id = Some(ttid.tenant_shard_id.tenant_id);
highest_shard_count = highest_shard_count.max(ttid.tenant_shard_id.shard_count);
}
}
match &data.blob_data {
BlobDataParseResult::Parsed {
index_part: _,
index_part_generation: _index_part_generation,
s3_layers,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} => {
tenant_objects.push(ttid, s3_layers.clone());
}
BlobDataParseResult::Relic => (),
BlobDataParseResult::Incorrect {
errors: _,
s3_layers,
} => {
tenant_objects.push(ttid, s3_layers.clone());
}
}
tenant_timeline_results.push((ttid, data));
}
if !tenant_timeline_results.is_empty() {
analyze_tenant(
summary_inner.clone(),
Arc::new(tokio::sync::Mutex::new(tenant_objects)),
tenant_timeline_results,
highest_shard_count,
&mut consolidate_tenants_tx,
)
.await?;
}
Ok::<_, anyhow::Error>(())
});
let remote_client_inner = remote_client.clone();
let summary_inner = summary.clone();
let analyze_tenants = tokio::spawn(async move {
let stream = consolidate_tenants_rx
.map(|(ttid, tenant_objects, data)| {
let remote_client_inner = remote_client_inner.clone();
async move {
let generation = if let BlobDataParseResult::Parsed {
index_part: _,
index_part_generation,
s3_layers: _,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} = &data.blob_data
{
Some(*index_part_generation)
} else {
None
};
let res = branch_cleanup_and_check_errors(
&remote_client_inner,
&ttid,
tenant_objects.clone(),
None,
None,
Some(data),
)
.await;
(ttid, tenant_objects.clone(), generation, res)
}
})
.buffered(32);
let mut last_tenant = None;
let mut last_tenant_objects = None;
let mut timeline_generations = HashMap::new();
let mut stream = Box::pin(stream);
while let Some((ttid, tenant_objects, generation, res)) = stream.next().await {
if last_tenant != Some(ttid) {
if let Some(tenant_id) = last_tenant {
let timeline_generations = std::mem::take(&mut timeline_generations);
identify_orphans(
tenant_id.tenant_shard_id.tenant_id,
last_tenant_objects.take().unwrap(),
summary_inner.clone(),
&timeline_generations,
)
.await;
}
last_tenant = Some(ttid);
last_tenant_objects = Some(tenant_objects);
}
if let Some(generation) = generation {
timeline_generations.insert(ttid, generation);
}
{
let mut guard = summary_inner.lock().await;
guard.update_analysis(&ttid, &res, verbose);
}
}
if let Some(tenant_id) = last_tenant {
identify_orphans(
tenant_id.tenant_shard_id.tenant_id,
last_tenant_objects.take().unwrap(),
summary_inner.clone(),
&timeline_generations,
)
.await;
}
Ok::<_, anyhow::Error>(())
});
// Generate a stream of S3TimelineBlobData
async fn report_on_timeline(
@@ -163,93 +387,94 @@ pub async fn scan_pageserver_metadata(
target: &RootTarget,
ttid: TenantShardTimelineId,
) -> anyhow::Result<(TenantShardTimelineId, RemoteTimelineBlobData)> {
tracing::info!("listing blobs for timeline: {}", ttid);
let data = list_timeline_blobs(remote_client, ttid, target).await?;
Ok((ttid, data))
}
let timelines = timelines.map_ok(|ttid| report_on_timeline(&remote_client, &target, ttid));
let mut timelines = std::pin::pin!(timelines.try_buffered(CONCURRENCY));
// We must gather all the TenantShardTimelineId->S3TimelineBlobData for each tenant, because different
// shards in the same tenant might refer to one anothers' keys if a shard split has happened.
let mut tenant_id = None;
let mut tenant_objects = TenantObjectListing::default();
let mut tenant_timeline_results = Vec::new();
// DO NOT call any long-running tasks in this function; always route them through the channel and let
// other tokio tasks handle them.
async fn analyze_tenant(
remote_client: &GenericRemoteStorage,
tenant_id: TenantId,
summary: &mut MetadataSummary,
mut tenant_objects: TenantObjectListing,
summary: Arc<tokio::sync::Mutex<MetadataSummary>>,
tenant_objects: Arc<tokio::sync::Mutex<TenantObjectListing>>,
timelines: Vec<(TenantShardTimelineId, RemoteTimelineBlobData)>,
highest_shard_count: ShardCount,
verbose: bool,
) {
summary.tenant_count += 1;
let mut timeline_ids = HashSet::new();
let mut timeline_generations = HashMap::new();
for (ttid, data) in timelines {
async {
if ttid.tenant_shard_id.shard_count == highest_shard_count {
// Only analyze `TenantShardId`s with highest shard count.
// Stash the generation of each timeline, for later use identifying orphan layers
if let BlobDataParseResult::Parsed {
index_part,
index_part_generation,
s3_layers: _,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} = &data.blob_data
{
if index_part.deleted_at.is_some() {
// skip deleted timeline.
tracing::info!(
"Skip analysis of {} b/c timeline is already deleted",
ttid
);
return;
}
timeline_generations.insert(ttid, *index_part_generation);
}
// Apply checks to this timeline shard's metadata, and in the process update `tenant_objects`
// reference counts for layers across the tenant.
let analysis = branch_cleanup_and_check_errors(
remote_client,
&ttid,
&mut tenant_objects,
None,
None,
Some(data),
)
.await;
summary.update_analysis(&ttid, &analysis, verbose);
timeline_ids.insert(ttid.timeline_id);
} else {
tracing::info!(
"Skip analysis of {} b/c a lower shard count than {}",
ttid,
highest_shard_count.0,
);
}
}
.instrument(
info_span!("analyze-timeline", shard = %ttid.tenant_shard_id.shard_slug(), timeline = %ttid.timeline_id),
)
.await
output_tx: &mut futures::channel::mpsc::Sender<(
TenantShardTimelineId,
Arc<tokio::sync::Mutex<TenantObjectListing>>,
RemoteTimelineBlobData,
)>,
) -> anyhow::Result<()> {
{
let mut guard = summary.lock().await;
guard.tenant_count += 1;
}
summary.timeline_count += timeline_ids.len();
let mut timeline_ids = HashSet::new();
for (ttid, data) in timelines {
async {
if ttid.tenant_shard_id.shard_count == highest_shard_count {
// Only analyze `TenantShardId`s with highest shard count.
// Stash the generation of each timeline, for later use identifying orphan layers
if let BlobDataParseResult::Parsed {
index_part,
index_part_generation: _,
s3_layers: _,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} = &data.blob_data
{
if index_part.deleted_at.is_some() {
// skip deleted timeline.
tracing::info!("Skip analysis of {} b/c timeline is already deleted", ttid);
return Ok(());
}
}
// Apply checks to this timeline shard's metadata, and in the process update `tenant_objects`
// reference counts for layers across the tenant.
output_tx.send((ttid, tenant_objects.clone(), data)).await?;
timeline_ids.insert(ttid.timeline_id);
} else {
tracing::info!(
"Skip analysis of {} b/c a lower shard count than {}",
ttid,
highest_shard_count.0,
);
}
Ok::<_, anyhow::Error>(())
}.instrument(
info_span!("analyze-timeline", shard = %ttid.tenant_shard_id.shard_slug(), timeline = %ttid.timeline_id),
)
.await?;
}
{
let mut guard = summary.lock().await;
guard.timeline_count += timeline_ids.len();
}
Ok(())
}
async fn identify_orphans(
tenant_id: TenantId,
tenant_objects: Arc<tokio::sync::Mutex<TenantObjectListing>>,
summary: Arc<tokio::sync::Mutex<MetadataSummary>>,
timeline_generations: &HashMap<TenantShardTimelineId, Generation>,
) {
// Identifying orphan layers must be done on a tenant-wide basis, because individual
// shards' layers may be referenced by other shards.
//
// Orphan layers are not a corruption, and not an indication of a problem. They are just
// consuming some space in remote storage, and may be cleaned up at leisure.
for (shard_index, timeline_id, layer_file, generation) in tenant_objects.get_orphans() {
let orphans = { tenant_objects.lock().await.get_orphans() };
for (shard_index, timeline_id, layer_file, generation) in orphans {
let ttid = TenantShardTimelineId {
tenant_shard_id: TenantShardId {
tenant_id,
@@ -279,83 +504,20 @@ pub async fn scan_pageserver_metadata(
tracing::info!("Orphan layer detected: {orphan_path}");
summary.notify_timeline_orphan(&ttid);
{
let mut guard = summary.lock().await;
guard.notify_timeline_orphan(&ttid);
}
}
}
// Iterate through all the timeline results. These are in key-order, so
// all results for the same tenant will be adjacent. We accumulate these,
// and then call `analyze_tenant` to flush, when we see the next tenant ID.
let mut summary = MetadataSummary::new();
let mut highest_shard_count = ShardCount::MIN;
while let Some(i) = timelines.next().await {
let (ttid, data) = i?;
summary.update_data(&data);
// TODO: bail out early if any of the tasks fail
list_tenants.await??;
list_timelines.await??;
read_timelines.await??;
consolidate_tenants.await??;
analyze_tenants.await??;
match tenant_id {
Some(prev_tenant_id) => {
if prev_tenant_id != ttid.tenant_shard_id.tenant_id {
// New tenant: analyze this tenant's timelines, clear accumulated tenant_timeline_results
let tenant_objects = std::mem::take(&mut tenant_objects);
let timelines = std::mem::take(&mut tenant_timeline_results);
analyze_tenant(
&remote_client,
prev_tenant_id,
&mut summary,
tenant_objects,
timelines,
highest_shard_count,
verbose,
)
.instrument(info_span!("analyze-tenant", tenant = %prev_tenant_id))
.await;
tenant_id = Some(ttid.tenant_shard_id.tenant_id);
highest_shard_count = ttid.tenant_shard_id.shard_count;
} else {
highest_shard_count = highest_shard_count.max(ttid.tenant_shard_id.shard_count);
}
}
None => {
tenant_id = Some(ttid.tenant_shard_id.tenant_id);
highest_shard_count = highest_shard_count.max(ttid.tenant_shard_id.shard_count);
}
}
match &data.blob_data {
BlobDataParseResult::Parsed {
index_part: _,
index_part_generation: _index_part_generation,
s3_layers,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} => {
tenant_objects.push(ttid, s3_layers.clone());
}
BlobDataParseResult::Relic => (),
BlobDataParseResult::Incorrect {
errors: _,
s3_layers,
} => {
tenant_objects.push(ttid, s3_layers.clone());
}
}
tenant_timeline_results.push((ttid, data));
}
if !tenant_timeline_results.is_empty() {
let tenant_id = tenant_id.expect("Must be set if results are present");
analyze_tenant(
&remote_client,
tenant_id,
&mut summary,
tenant_objects,
tenant_timeline_results,
highest_shard_count,
verbose,
)
.instrument(info_span!("analyze-tenant", tenant = %tenant_id))
.await;
}
Ok(summary)
let summary = summary.lock().await;
Ok(summary.clone())
}

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import urllib.parse
from enum import StrEnum
from typing import TYPE_CHECKING, final
import requests
@@ -10,23 +9,11 @@ from requests.auth import AuthBase
from typing_extensions import override
from fixtures.log_helper import log
from fixtures.utils import wait_until
if TYPE_CHECKING:
from requests import PreparedRequest
COMPUTE_AUDIENCE = "compute"
"""
The value to place in the `aud` claim.
"""
@final
class ComputeClaimsScope(StrEnum):
ADMIN = "admin"
@final
class BearerAuth(AuthBase):
"""
@@ -63,35 +50,6 @@ class EndpointHttpClient(requests.Session):
res.raise_for_status()
return res.json()
def prewarm_lfc_status(self) -> dict[str, str]:
res = self.get(f"http://localhost:{self.external_port}/lfc/prewarm")
res.raise_for_status()
json: dict[str, str] = res.json()
return json
def prewarm_lfc(self):
self.post(f"http://localhost:{self.external_port}/lfc/prewarm").raise_for_status()
def prewarmed():
json = self.prewarm_lfc_status()
status, err = json["status"], json.get("error")
assert status == "completed", f"{status}, error {err}"
wait_until(prewarmed)
def offload_lfc(self):
url = f"http://localhost:{self.external_port}/lfc/offload"
self.post(url).raise_for_status()
def offloaded():
res = self.get(url)
res.raise_for_status()
json = res.json()
status, err = json["status"], json.get("error")
assert status == "completed", f"{status}, error {err}"
wait_until(offloaded)
def database_schema(self, database: str):
res = self.get(
f"http://localhost:{self.external_port}/database_schema?database={urllib.parse.quote(database, safe='')}",

View File

@@ -21,7 +21,6 @@ if TYPE_CHECKING:
Any,
)
from fixtures.endpoint.http import ComputeClaimsScope
from fixtures.pg_version import PgVersion
@@ -536,16 +535,12 @@ class NeonLocalCli(AbstractNeonCli):
res.check_returncode()
return res
def endpoint_generate_jwt(
self, endpoint_id: str, scope: ComputeClaimsScope | None = None
) -> str:
def endpoint_generate_jwt(self, endpoint_id: str) -> str:
"""
Generate a JWT for making requests to the endpoint's external HTTP
server.
"""
args = ["endpoint", "generate-jwt", endpoint_id]
if scope:
args += ["--scope", str(scope)]
cmd = self.raw_cli(args)
cmd.check_returncode()

View File

@@ -51,7 +51,7 @@ from fixtures.common_types import (
TimelineId,
)
from fixtures.compute_migrations import NUM_COMPUTE_MIGRATIONS
from fixtures.endpoint.http import ComputeClaimsScope, EndpointHttpClient
from fixtures.endpoint.http import EndpointHttpClient
from fixtures.log_helper import log
from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
from fixtures.neon_cli import NeonLocalCli, Pagectl
@@ -1185,9 +1185,7 @@ class NeonEnv:
"broker": {},
"safekeepers": [],
"pageservers": [],
"endpoint_storage": {
"listen_addr": f"127.0.0.1:{self.port_distributor.get_port()}",
},
"endpoint_storage": {"port": self.port_distributor.get_port()},
"generate_local_ssl_certs": self.generate_local_ssl_certs,
}
@@ -4220,7 +4218,7 @@ class Endpoint(PgProtocol, LogUtils):
self.config(config_lines)
self.__jwt = self.generate_jwt()
self.__jwt = self.env.neon_cli.endpoint_generate_jwt(self.endpoint_id)
return self
@@ -4267,14 +4265,6 @@ class Endpoint(PgProtocol, LogUtils):
return self
def generate_jwt(self, scope: ComputeClaimsScope | None = None) -> str:
"""
Generate a JWT for making requests to the endpoint's external HTTP
server.
"""
assert self.endpoint_id is not None
return self.env.neon_cli.endpoint_generate_jwt(self.endpoint_id, scope)
def endpoint_path(self) -> Path:
"""Path to endpoint directory"""
assert self.endpoint_id

View File

@@ -1,5 +1,4 @@
import math # Add this import
import os
import time
import traceback
from pathlib import Path
@@ -88,10 +87,7 @@ def test_cumulative_statistics_persistence(
- insert additional tuples that by itself are not enough to trigger auto-vacuum but in combination with the previous tuples are
- verify that autovacuum is triggered by the combination of tuples inserted before and after endpoint suspension
"""
project = neon_api.create_project(
pg_version,
f"Test cumulative statistics persistence, GITHUB_RUN_ID={os.getenv('GITHUB_RUN_ID')}",
)
project = neon_api.create_project(pg_version)
project_id = project["project"]["id"]
neon_api.wait_for_operation_to_finish(project_id)
endpoint_id = project["endpoints"][0]["id"]

View File

@@ -62,9 +62,7 @@ def test_ro_replica_lag(
pgbench_duration = f"-T{test_duration_min * 60 * 2}"
project = neon_api.create_project(
pg_version, f"Test readonly replica lag, GITHUB_RUN_ID={os.getenv('GITHUB_RUN_ID')}"
)
project = neon_api.create_project(pg_version)
project_id = project["project"]["id"]
log.info("Project ID: %s", project_id)
log.info("Primary endpoint ID: %s", project["endpoints"][0]["id"])
@@ -197,9 +195,7 @@ def test_replication_start_stop(
pgbench_duration = f"-T{2**num_replicas * configuration_test_time_sec}"
error_occurred = False
project = neon_api.create_project(
pg_version, f"Test replication start stop, GITHUB_RUN_ID={os.getenv('GITHUB_RUN_ID')}"
)
project = neon_api.create_project(pg_version)
project_id = project["project"]["id"]
log.info("Project ID: %s", project_id)
log.info("Primary endpoint ID: %s", project["endpoints"][0]["id"])

View File

@@ -206,7 +206,7 @@ class NeonProject:
self.neon_api = neon_api
self.pg_bin = pg_bin
proj = self.neon_api.create_project(
pg_version, f"Automatic random API test GITHUB_RUN_ID={os.getenv('GITHUB_RUN_ID')}"
pg_version, f"Automatic random API test {os.getenv('GITHUB_RUN_ID')}"
)
self.id: str = proj["project"]["id"]
self.name: str = proj["project"]["name"]

View File

@@ -202,8 +202,6 @@ def test_pageserver_gc_compaction_preempt(
env = neon_env_builder.init_start(initial_tenant_conf=conf)
env.pageserver.allowed_errors.append(".*The timeline or pageserver is shutting down.*")
env.pageserver.allowed_errors.append(".*flush task cancelled.*")
env.pageserver.allowed_errors.append(".*failed to pipe.*")
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline

View File

@@ -1,78 +0,0 @@
from __future__ import annotations
from http.client import FORBIDDEN, UNAUTHORIZED
from typing import TYPE_CHECKING
import jwt
import pytest
from fixtures.endpoint.http import COMPUTE_AUDIENCE, ComputeClaimsScope, EndpointHttpClient
from fixtures.utils import run_only_on_default_postgres
from requests import RequestException
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
@run_only_on_default_postgres("The code path being tested is not dependent on Postgres version")
def test_compute_no_scope_claim(neon_simple_env: NeonEnv):
"""
Test that if the JWT scope is not admin and no compute_id is specified,
the external HTTP server returns a 403 Forbidden error.
"""
env = neon_simple_env
endpoint = env.endpoints.create_start("main")
# Encode nothing in the token
token = jwt.encode({}, env.auth_keys.priv, algorithm="EdDSA")
# Create an admin-scoped HTTP client
client = EndpointHttpClient(
external_port=endpoint.external_http_port,
internal_port=endpoint.internal_http_port,
jwt=token,
)
try:
client.status()
pytest.fail("Exception should have been raised")
except RequestException as e:
assert e.response is not None
assert e.response.status_code == FORBIDDEN
@pytest.mark.parametrize(
"audience",
(COMPUTE_AUDIENCE, "invalid", None),
ids=["with_audience", "with_invalid_audience", "without_audience"],
)
@run_only_on_default_postgres("The code path being tested is not dependent on Postgres version")
def test_compute_admin_scope_claim(neon_simple_env: NeonEnv, audience: str | None):
"""
Test that an admin-scoped JWT can access the compute's external HTTP server
without the compute_id being specified in the claims.
"""
env = neon_simple_env
endpoint = env.endpoints.create_start("main")
data: dict[str, str | list[str]] = {"scope": str(ComputeClaimsScope.ADMIN)}
if audience:
data["aud"] = [audience]
token = jwt.encode(data, env.auth_keys.priv, algorithm="EdDSA")
# Create an admin-scoped HTTP client
client = EndpointHttpClient(
external_port=endpoint.external_http_port,
internal_port=endpoint.internal_http_port,
jwt=token,
)
try:
client.status()
if audience != COMPUTE_AUDIENCE:
pytest.fail("Exception should have been raised")
except RequestException as e:
assert e.response is not None
assert e.response.status_code == UNAUTHORIZED

View File

@@ -4,12 +4,10 @@ import pytest
from aiohttp import ClientSession
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnv
from fixtures.utils import run_only_on_default_postgres
from jwcrypto import jwk, jwt
@pytest.mark.asyncio
@run_only_on_default_postgres("test doesn't use postgres")
async def test_endpoint_storage_insert_retrieve_delete(neon_simple_env: NeonEnv):
"""
Inserts, retrieves, and deletes test file using a JWT token
@@ -37,6 +35,7 @@ async def test_endpoint_storage_insert_retrieve_delete(neon_simple_env: NeonEnv)
key = f"http://{base_url}/{tenant_id}/{timeline_id}/{endpoint_id}/key"
headers = {"Authorization": f"Bearer {token}"}
log.info(f"cache key url {key}")
log.info(f"token {token}")
async with ClientSession(headers=headers) as session:
async with session.get(key) as res:

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
#
# Test unlogged build for GIST index
#
def test_gist(neon_simple_env: NeonEnv):
env = neon_simple_env
endpoint = env.endpoints.create_start("main")
con = endpoint.connect()
cur = con.cursor()
iterations = 100
for _ in range(iterations):
cur.execute(
"CREATE TABLE pvactst (i INT, a INT[], p POINT) with (autovacuum_enabled = off)"
)
cur.execute(
"INSERT INTO pvactst SELECT i, array[1,2,3], point(i, i+1) FROM generate_series(1,1000) i"
)
cur.execute("CREATE INDEX gist_pvactst ON pvactst USING gist (p)")
cur.execute("VACUUM pvactst")
cur.execute("DROP TABLE pvactst")

View File

@@ -1,24 +1,11 @@
import random
import threading
import time
from enum import Enum
import pytest
from fixtures.endpoint.http import EndpointHttpClient
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnv
from fixtures.utils import USE_LFC
from prometheus_client.parser import text_string_to_metric_families as prom_parse_impl
class LfcQueryMethod(Enum):
COMPUTE_CTL = False
POSTGRES = True
PREWARM_LABEL = "compute_ctl_lfc_prewarm_requests_total"
OFFLOAD_LABEL = "compute_ctl_lfc_offload_requests_total"
QUERY_OPTIONS = LfcQueryMethod.POSTGRES, LfcQueryMethod.COMPUTE_CTL
def check_pinned_entries(cur):
@@ -32,20 +19,11 @@ def check_pinned_entries(cur):
assert n_pinned == 0
def prom_parse(client: EndpointHttpClient) -> dict[str, float]:
return {
sample.name: sample.value
for family in prom_parse_impl(client.metrics())
for sample in family.samples
if sample.name in (PREWARM_LABEL, OFFLOAD_LABEL)
}
@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping")
@pytest.mark.parametrize("query", QUERY_OPTIONS, ids=["postgres", "compute-ctl"])
def test_lfc_prewarm(neon_simple_env: NeonEnv, query: LfcQueryMethod):
def test_lfc_prewarm(neon_simple_env: NeonEnv):
env = neon_simple_env
n_records = 1000000
endpoint = env.endpoints.create_start(
branch_name="main",
config_lines=[
@@ -56,57 +34,30 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, query: LfcQueryMethod):
"neon.file_cache_prewarm_limit=1000",
],
)
pg_conn = endpoint.connect()
pg_cur = pg_conn.cursor()
pg_cur.execute("create extension neon version '1.6'")
pg_cur.execute("create database lfc")
lfc_conn = endpoint.connect(dbname="lfc")
lfc_cur = lfc_conn.cursor()
log.info(f"Inserting {n_records} rows")
lfc_cur.execute("create table t(pk integer primary key, payload text default repeat('?', 128))")
lfc_cur.execute(f"insert into t (pk) values (generate_series(1,{n_records}))")
log.info(f"Inserted {n_records} rows")
http_client = endpoint.http_client()
if query is LfcQueryMethod.COMPUTE_CTL:
status = http_client.prewarm_lfc_status()
assert status["status"] == "not_prewarmed"
assert "error" not in status
http_client.offload_lfc()
assert http_client.prewarm_lfc_status()["status"] == "not_prewarmed"
assert prom_parse(http_client) == {OFFLOAD_LABEL: 1, PREWARM_LABEL: 0}
else:
pg_cur.execute("select get_local_cache_state()")
lfc_state = pg_cur.fetchall()[0][0]
conn = endpoint.connect()
cur = conn.cursor()
cur.execute("create extension neon version '1.6'")
cur.execute("create table t(pk integer primary key, payload text default repeat('?', 128))")
cur.execute(f"insert into t (pk) values (generate_series(1,{n_records}))")
cur.execute("select get_local_cache_state()")
lfc_state = cur.fetchall()[0][0]
endpoint.stop()
endpoint.start()
# wait until compute_ctl completes downgrade of extension to default version
time.sleep(1)
pg_conn = endpoint.connect()
pg_cur = pg_conn.cursor()
pg_cur.execute("alter extension neon update to '1.6'")
conn = endpoint.connect()
cur = conn.cursor()
time.sleep(1) # wait until compute_ctl complete downgrade of extension to default version
cur.execute("alter extension neon update to '1.6'")
cur.execute("select prewarm_local_cache(%s)", (lfc_state,))
lfc_conn = endpoint.connect(dbname="lfc")
lfc_cur = lfc_conn.cursor()
if query is LfcQueryMethod.COMPUTE_CTL:
http_client.prewarm_lfc()
else:
pg_cur.execute("select prewarm_local_cache(%s)", (lfc_state,))
pg_cur.execute("select lfc_value from neon_lfc_stats where lfc_key='file_cache_used_pages'")
lfc_used_pages = pg_cur.fetchall()[0][0]
cur.execute("select lfc_value from neon_lfc_stats where lfc_key='file_cache_used_pages'")
lfc_used_pages = cur.fetchall()[0][0]
log.info(f"Used LFC size: {lfc_used_pages}")
pg_cur.execute("select * from get_prewarm_info()")
prewarm_info = pg_cur.fetchall()[0]
cur.execute("select * from get_prewarm_info()")
prewarm_info = cur.fetchall()[0]
log.info(f"Prewarm info: {prewarm_info}")
total, prewarmed, skipped, _ = prewarm_info
progress = (prewarmed + skipped) * 100 // total
log.info(f"Prewarm progress: {progress}%")
log.info(f"Prewarm progress: {(prewarm_info[1] + prewarm_info[2]) * 100 // prewarm_info[0]}%")
assert lfc_used_pages > 10000
assert (
@@ -115,23 +66,18 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, query: LfcQueryMethod):
and prewarm_info[0] == prewarm_info[1] + prewarm_info[2]
)
lfc_cur.execute("select sum(pk) from t")
assert lfc_cur.fetchall()[0][0] == n_records * (n_records + 1) / 2
cur.execute("select sum(pk) from t")
assert cur.fetchall()[0][0] == n_records * (n_records + 1) / 2
check_pinned_entries(pg_cur)
desired = {"status": "completed", "total": total, "prewarmed": prewarmed, "skipped": skipped}
if query is LfcQueryMethod.COMPUTE_CTL:
assert http_client.prewarm_lfc_status() == desired
assert prom_parse(http_client) == {OFFLOAD_LABEL: 0, PREWARM_LABEL: 1}
check_pinned_entries(cur)
@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping")
@pytest.mark.parametrize("query", QUERY_OPTIONS, ids=["postgres", "compute-ctl"])
def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMethod):
def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv):
env = neon_simple_env
n_records = 10000
n_threads = 4
endpoint = env.endpoints.create_start(
branch_name="main",
config_lines=[
@@ -141,58 +87,40 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMet
"neon.file_cache_prewarm_limit=1000000",
],
)
pg_conn = endpoint.connect()
pg_cur = pg_conn.cursor()
pg_cur.execute("create extension neon version '1.6'")
pg_cur.execute("CREATE DATABASE lfc")
lfc_conn = endpoint.connect(dbname="lfc")
lfc_cur = lfc_conn.cursor()
lfc_cur.execute(
conn = endpoint.connect()
cur = conn.cursor()
cur.execute("create extension neon version '1.6'")
cur.execute(
"create table accounts(id integer primary key, balance bigint default 0, payload text default repeat('?', 1000)) with (fillfactor=10)"
)
log.info(f"Inserting {n_records} rows")
lfc_cur.execute(f"insert into accounts(id) values (generate_series(1,{n_records}))")
log.info(f"Inserted {n_records} rows")
http_client = endpoint.http_client()
if query is LfcQueryMethod.COMPUTE_CTL:
http_client.offload_lfc()
else:
pg_cur.execute("select get_local_cache_state()")
lfc_state = pg_cur.fetchall()[0][0]
cur.execute(f"insert into accounts(id) values (generate_series(1,{n_records}))")
cur.execute("select get_local_cache_state()")
lfc_state = cur.fetchall()[0][0]
running = True
n_prewarms = 0
def workload():
lfc_conn = endpoint.connect(dbname="lfc")
lfc_cur = lfc_conn.cursor()
conn = endpoint.connect()
cur = conn.cursor()
n_transfers = 0
while running:
src = random.randint(1, n_records)
dst = random.randint(1, n_records)
lfc_cur.execute("update accounts set balance=balance-100 where id=%s", (src,))
lfc_cur.execute("update accounts set balance=balance+100 where id=%s", (dst,))
cur.execute("update accounts set balance=balance-100 where id=%s", (src,))
cur.execute("update accounts set balance=balance+100 where id=%s", (dst,))
n_transfers += 1
log.info(f"Number of transfers: {n_transfers}")
def prewarm():
pg_conn = endpoint.connect()
pg_cur = pg_conn.cursor()
conn = endpoint.connect()
cur = conn.cursor()
n_prewarms = 0
while running:
pg_cur.execute("alter system set neon.file_cache_size_limit='1MB'")
pg_cur.execute("select pg_reload_conf()")
pg_cur.execute("alter system set neon.file_cache_size_limit='1GB'")
pg_cur.execute("select pg_reload_conf()")
if query is LfcQueryMethod.COMPUTE_CTL:
http_client.prewarm_lfc()
else:
pg_cur.execute("select prewarm_local_cache(%s)", (lfc_state,))
nonlocal n_prewarms
cur.execute("alter system set neon.file_cache_size_limit='1MB'")
cur.execute("select pg_reload_conf()")
cur.execute("alter system set neon.file_cache_size_limit='1GB'")
cur.execute("select pg_reload_conf()")
cur.execute("select prewarm_local_cache(%s)", (lfc_state,))
n_prewarms += 1
log.info(f"Number of prewarms: {n_prewarms}")
@@ -212,10 +140,8 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMet
t.join()
prewarm_thread.join()
lfc_cur.execute("select sum(balance) from accounts")
total_balance = lfc_cur.fetchall()[0][0]
cur.execute("select sum(balance) from accounts")
total_balance = cur.fetchall()[0][0]
assert total_balance == 0
check_pinned_entries(pg_cur)
if query is LfcQueryMethod.COMPUTE_CTL:
assert prom_parse(http_client) == {OFFLOAD_LABEL: 1, PREWARM_LABEL: n_prewarms}
check_pinned_entries(cur)

View File

@@ -1334,13 +1334,6 @@ def test_sharding_split_failures(
tenant_id, timeline_id, shard_count=initial_shard_count, placement_policy='{"Attached":1}'
)
# Create bystander tenants with various shard counts. They should not be affected by the aborted
# splits. Regression test for https://github.com/neondatabase/cloud/issues/28589.
bystanders = {} # id → shard_count
for bystander_shard_count in [1, 2, 4, 8]:
id, _ = env.create_tenant(shard_count=bystander_shard_count)
bystanders[id] = bystander_shard_count
env.storage_controller.allowed_errors.extend(
[
# All split failures log a warning when then enqueue the abort operation
@@ -1401,8 +1394,6 @@ def test_sharding_split_failures(
locations = ps.http_client().tenant_list_locations()["tenant_shards"]
for loc in locations:
tenant_shard_id = TenantShardId.parse(loc[0])
if tenant_shard_id.tenant_id != tenant_id:
continue # skip bystanders
log.info(f"Shard {tenant_shard_id} seen on node {ps.id} in mode {loc[1]['mode']}")
assert tenant_shard_id.shard_count == initial_shard_count
if loc[1]["mode"] == "Secondary":
@@ -1423,8 +1414,6 @@ def test_sharding_split_failures(
locations = ps.http_client().tenant_list_locations()["tenant_shards"]
for loc in locations:
tenant_shard_id = TenantShardId.parse(loc[0])
if tenant_shard_id.tenant_id != tenant_id:
continue # skip bystanders
log.info(f"Shard {tenant_shard_id} seen on node {ps.id} in mode {loc[1]['mode']}")
assert tenant_shard_id.shard_count == split_shard_count
if loc[1]["mode"] == "Secondary":
@@ -1507,12 +1496,6 @@ def test_sharding_split_failures(
# the scheduler reaches an idle state
env.storage_controller.reconcile_until_idle(timeout_secs=30)
# Check that all bystanders are still around.
for bystander_id, bystander_shard_count in bystanders.items():
response = env.storage_controller.tenant_describe(bystander_id)
assert TenantId(response["tenant_id"]) == bystander_id
assert len(response["shards"]) == bystander_shard_count
env.storage_controller.consistency_check()

View File

@@ -5,14 +5,14 @@
],
"v16": [
"16.8",
"05ddf212e2e07b788b5c8b88bdcf98630941f6ae"
"37496f87b5324af53c56127e278ee5b1e8435253"
],
"v15": [
"15.12",
"b838c8969b7c63f3e637a769656f5f36793b797c"
"8ecb12f21d862dfa39f7204b8f5e1c00a2a225b3"
],
"v14": [
"14.17",
"c8dab02bfc003ae7bd59096919042d7840f3c194"
"d3c9d61fb7a362a165dac7060819dd9d6ad68c28"
]
}