Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Chi Z
7bd3e7a803 feat(scrubber): more parallelism for metadata check
Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-05-06 17:28:13 +08:00
125 changed files with 1627 additions and 3771 deletions

View File

@@ -53,77 +53,6 @@ concurrency:
cancel-in-progress: true
jobs:
cleanup:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: ghcr.io/neondatabase/build-tools:pinned-bookworm
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
env:
ORG_ID: org-solitary-dew-09443886
LIMIT: 100
SEARCH: "GITHUB_RUN_ID="
BASE_URL: https://console-stage.neon.build/api/v2
DRY_RUN: "false" # Set to "true" to just test out the workflow
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cleanup inactive Neon projects left over from prior runs
env:
API_KEY: ${{ secrets.NEON_STAGING_API_KEY }}
run: |
set -euo pipefail
NOW=$(date -u +%s)
DAYS_AGO=$((NOW - 5 * 86400))
REQUEST_URL="$BASE_URL/projects?limit=$LIMIT&search=$(printf '%s' "$SEARCH" | jq -sRr @uri)&org_id=$ORG_ID"
echo "Requesting project list from:"
echo "$REQUEST_URL"
response=$(curl -s -X GET "$REQUEST_URL" \
--header "Accept: application/json" \
--header "Content-Type: application/json" \
--header "Authorization: Bearer ${API_KEY}" )
echo "Response:"
echo "$response" | jq .
projects_to_delete=$(echo "$response" | jq --argjson cutoff "$DAYS_AGO" '
.projects[]
| select(.compute_last_active_at != null)
| select((.compute_last_active_at | fromdateiso8601) < $cutoff)
| {id, name, compute_last_active_at}
')
if [ -z "$projects_to_delete" ]; then
echo "No projects eligible for deletion."
exit 0
fi
echo "Projects that will be deleted:"
echo "$projects_to_delete" | jq -r '.id'
if [ "$DRY_RUN" = "false" ]; then
echo "$projects_to_delete" | jq -r '.id' | while read -r project_id; do
echo "Deleting project: $project_id"
curl -s -X DELETE "$BASE_URL/projects/$project_id" \
--header "Accept: application/json" \
--header "Content-Type: application/json" \
--header "Authorization: Bearer ${API_KEY}"
done
else
echo "Dry run enabled — no projects were deleted."
fi
bench:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
permissions:

3
Cargo.lock generated
View File

@@ -1284,7 +1284,6 @@ name = "compute_tools"
version = "0.1.0"
dependencies = [
"anyhow",
"async-compression",
"aws-config",
"aws-sdk-kms",
"aws-sdk-s3",
@@ -1303,7 +1302,6 @@ dependencies = [
"futures",
"http 1.1.0",
"indexmap 2.0.1",
"itertools 0.10.5",
"jsonwebtoken",
"metrics",
"nix 0.27.1",
@@ -1422,7 +1420,6 @@ dependencies = [
"clap",
"comfy-table",
"compute_api",
"endpoint_storage",
"futures",
"http-utils",
"humantime",

View File

@@ -243,7 +243,6 @@ azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rus
## Local libraries
compute_api = { version = "0.1", path = "./libs/compute_api/" }
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
pageserver = { path = "./pageserver" }

View File

@@ -1084,12 +1084,23 @@ RUN cargo install --locked --version 0.12.9 cargo-pgrx && \
/bin/bash -c 'cargo pgrx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
#########################################################################################
#
# Layer "rust extensions pgrx14"
#
# Version 14 is now required by a few
#########################################################################################
FROM pg-build-nonroot-with-cargo AS rust-extensions-build-pgrx14
ARG PG_VERSION
RUN cargo install --locked --version 0.14.1 cargo-pgrx && \
/bin/bash -c 'cargo pgrx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
#########################################################################################
#
# Layer "rust extensions pgrx14"
#
# Version 14 is now required by a few
# This layer should be used as a base for new pgrx extensions,
# and eventually get merged with `rust-extensions-build`
#
@@ -1322,8 +1333,8 @@ ARG PG_VERSION
# Do not update without approve from proxy team
# Make sure the version is reflected in proxy/src/serverless/local_conn_pool.rs
WORKDIR /ext-src
RUN wget https://github.com/neondatabase/pg_session_jwt/archive/refs/tags/v0.3.1.tar.gz -O pg_session_jwt.tar.gz && \
echo "62fec9e472cb805c53ba24a0765afdb8ea2720cfc03ae7813e61687b36d1b0ad pg_session_jwt.tar.gz" | sha256sum --check && \
RUN wget https://github.com/neondatabase/pg_session_jwt/archive/refs/tags/v0.3.0.tar.gz -O pg_session_jwt.tar.gz && \
echo "19be2dc0b3834d643706ed430af998bb4c2cdf24b3c45e7b102bb3a550e8660c pg_session_jwt.tar.gz" | sha256sum --check && \
mkdir pg_session_jwt-src && cd pg_session_jwt-src && tar xzf ../pg_session_jwt.tar.gz --strip-components=1 -C . && \
sed -i 's/pgrx = "0.12.6"/pgrx = { version = "0.12.9", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
sed -i 's/version = "0.12.6"/version = "0.12.9"/g' pgrx-tests/Cargo.toml && \
@@ -1351,8 +1362,7 @@ COPY compute/patches/anon_v2.patch .
# This is an experimental extension, never got to real production.
# !Do not remove! It can be present in shared_preload_libraries and compute will fail to start if library is not found.
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/2.1.0/postgresql_anonymizer-latest.tar.gz -O pg_anon.tar.gz && \
echo "48e7f5ae2f1ca516df3da86c5c739d48dd780a4e885705704ccaad0faa89d6c0 pg_anon.tar.gz" | sha256sum --check && \
RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/latest/postgresql_anonymizer-latest.tar.gz -O pg_anon.tar.gz && \
mkdir pg_anon-src && cd pg_anon-src && tar xzf ../pg_anon.tar.gz --strip-components=1 -C . && \
find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt && \
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "=0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
@@ -1971,8 +1981,7 @@ COPY --from=sql_exporter_preprocessor --chmod=0644 /home/nonroot/compute/etc/sql
COPY --from=sql_exporter_preprocessor --chmod=0644 /home/nonroot/compute/etc/neon_collector_autoscaling.yml /etc/neon_collector_autoscaling.yml
# Make the libraries we built available
COPY --chmod=0666 compute/etc/ld.so.conf.d/00-neon.conf /etc/ld.so.conf.d/00-neon.conf
RUN /sbin/ldconfig
RUN echo '/usr/local/lib' >> /etc/ld.so.conf && /sbin/ldconfig
# rsyslog config permissions
# directory for rsyslogd pid file

View File

@@ -1 +0,0 @@
/usr/local/lib

View File

@@ -10,7 +10,6 @@ default = []
testing = ["fail/failpoints"]
[dependencies]
async-compression.workspace = true
base64.workspace = true
aws-config.workspace = true
aws-sdk-s3.workspace = true
@@ -28,7 +27,6 @@ flate2.workspace = true
futures.workspace = true
http.workspace = true
indexmap.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
metrics.workspace = true
nix.workspace = true

View File

@@ -60,16 +60,12 @@ use utils::failpoint_support;
// Compatibility hack: if the control plane specified any remote-ext-config
// use the default value for extension storage proxy gateway.
// Remove this once the control plane is updated to pass the gateway URL
fn parse_remote_ext_base_url(arg: &str) -> Result<String> {
const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str =
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local";
Ok(if arg.starts_with("http") {
arg
fn parse_remote_ext_config(arg: &str) -> Result<String> {
if arg.starts_with("http") {
Ok(arg.trim_end_matches('/').to_string())
} else {
FALLBACK_PG_EXT_GATEWAY_BASE_URL
Ok("http://pg-ext-s3-gateway".to_string())
}
.to_owned())
}
#[derive(Parser)]
@@ -78,10 +74,8 @@ struct Cli {
#[arg(short = 'b', long, default_value = "postgres", env = "POSTGRES_PATH")]
pub pgbin: String,
/// The base URL for the remote extension storage proxy gateway.
/// Should be in the form of `http(s)://<gateway-hostname>[:<port>]`.
#[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")]
pub remote_ext_base_url: Option<String>,
#[arg(short = 'r', long, value_parser = parse_remote_ext_config)]
pub remote_ext_config: Option<String>,
/// The port to bind the external listening HTTP server to. Clients running
/// outside the compute will talk to the compute through this port. Keep
@@ -170,7 +164,7 @@ fn main() -> Result<()> {
pgversion: get_pg_version_string(&cli.pgbin),
external_http_port: cli.external_http_port,
internal_http_port: cli.internal_http_port,
remote_ext_base_url: cli.remote_ext_base_url.clone(),
ext_remote_storage: cli.remote_ext_config.clone(),
resize_swap_on_bind: cli.resize_swap_on_bind,
set_disk_quota_for_fs: cli.set_disk_quota_for_fs,
#[cfg(target_os = "linux")]
@@ -271,18 +265,4 @@ mod test {
fn verify_cli() {
Cli::command().debug_assert()
}
#[test]
fn parse_pg_ext_gateway_base_url() {
let arg = "http://pg-ext-s3-gateway2";
let result = super::parse_remote_ext_base_url(arg).unwrap();
assert_eq!(result, arg);
let arg = "pg-ext-s3-gateway";
let result = super::parse_remote_ext_base_url(arg).unwrap();
assert_eq!(
result,
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"
);
}
}

View File

@@ -348,7 +348,6 @@ async fn run_dump_restore(
"--no-security-labels".to_string(),
"--no-subscriptions".to_string(),
"--no-tablespaces".to_string(),
"--no-event-triggers".to_string(),
// format
"--format".to_string(),
"directory".to_string(),

View File

@@ -1,26 +1,4 @@
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use compute_api::privilege::Privilege;
use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeMetrics, ComputeStatus, LfcOffloadState,
LfcPrewarmState,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PgIdent,
};
use futures::StreamExt;
use futures::future::join_all;
use futures::stream::FuturesUnordered;
use itertools::Itertools;
use nix::sys::signal::{Signal, kill};
use nix::unistd::Pid;
use once_cell::sync::Lazy;
use postgres;
use postgres::NoTls;
use postgres::error::SqlState;
use remote_storage::{DownloadError, RemotePath};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::collections::HashMap;
use std::os::unix::fs::{PermissionsExt, symlink};
use std::path::Path;
use std::process::{Command, Stdio};
@@ -29,6 +7,24 @@ use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use compute_api::privilege::Privilege;
use compute_api::responses::{ComputeConfig, ComputeCtlConfig, ComputeMetrics, ComputeStatus};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PgIdent,
};
use futures::StreamExt;
use futures::future::join_all;
use futures::stream::FuturesUnordered;
use nix::sys::signal::{Signal, kill};
use nix::unistd::Pid;
use once_cell::sync::Lazy;
use postgres;
use postgres::NoTls;
use postgres::error::SqlState;
use remote_storage::{DownloadError, RemotePath};
use tokio::spawn;
use tracing::{Instrument, debug, error, info, instrument, warn};
use utils::id::{TenantId, TimelineId};
@@ -96,7 +92,7 @@ pub struct ComputeNodeParams {
pub internal_http_port: u16,
/// the address of extension storage proxy gateway
pub remote_ext_base_url: Option<String>,
pub ext_remote_storage: Option<String>,
}
/// Compute node info shared across several `compute_ctl` threads.
@@ -154,9 +150,6 @@ pub struct ComputeState {
/// set up the span relationship ourselves.
pub startup_span: Option<tracing::span::Span>,
pub lfc_prewarm_state: LfcPrewarmState,
pub lfc_offload_state: LfcOffloadState,
pub metrics: ComputeMetrics,
}
@@ -170,8 +163,6 @@ impl ComputeState {
pspec: None,
startup_span: None,
metrics: ComputeMetrics::default(),
lfc_prewarm_state: LfcPrewarmState::default(),
lfc_offload_state: LfcOffloadState::default(),
}
}
@@ -207,8 +198,6 @@ pub struct ParsedSpec {
pub pageserver_connstr: String,
pub safekeeper_connstrings: Vec<String>,
pub storage_auth_token: Option<String>,
pub endpoint_storage_addr: Option<SocketAddr>,
pub endpoint_storage_token: Option<String>,
}
impl TryFrom<ComputeSpec> for ParsedSpec {
@@ -262,18 +251,6 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
.or(Err("invalid timeline id"))?
};
let endpoint_storage_addr: Option<SocketAddr> = spec
.endpoint_storage_addr
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_addr"))
.unwrap_or_default()
.parse()
.ok();
let endpoint_storage_token = spec
.endpoint_storage_token
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_token"));
Ok(ParsedSpec {
spec,
pageserver_connstr,
@@ -281,8 +258,6 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
storage_auth_token,
tenant_id,
timeline_id,
endpoint_storage_addr,
endpoint_storage_token,
})
}
}
@@ -330,39 +305,11 @@ struct StartVmMonitorResult {
impl ComputeNode {
pub fn new(params: ComputeNodeParams, config: ComputeConfig) -> Result<Self> {
let connstr = params.connstr.as_str();
let mut conn_conf = postgres::config::Config::from_str(connstr)
let conn_conf = postgres::config::Config::from_str(connstr)
.context("cannot build postgres config from connstr")?;
let mut tokio_conn_conf = tokio_postgres::config::Config::from_str(connstr)
let tokio_conn_conf = tokio_postgres::config::Config::from_str(connstr)
.context("cannot build tokio postgres config from connstr")?;
// Users can set some configuration parameters per database with
// ALTER DATABASE ... SET ...
//
// There are at least these parameters:
//
// - role=some_other_role
// - default_transaction_read_only=on
// - statement_timeout=1, i.e., 1ms, which will cause most of the queries to fail
// - search_path=non_public_schema, this should be actually safe because
// we don't call any functions in user databases, but better to always reset
// it to public.
//
// that can affect `compute_ctl` and prevent it from properly configuring the database schema.
// Unset them via connection string options before connecting to the database.
// N.B. keep it in sync with `ZENITH_OPTIONS` in `get_maintenance_client()`.
//
// TODO(ololobus): we currently pass `-c default_transaction_read_only=off` from control plane
// as well. After rolling out this code, we can remove this parameter from control plane.
// In the meantime, double-passing is fine, the last value is applied.
// See: <https://github.com/neondatabase/cloud/blob/133dd8c4dbbba40edfbad475bf6a45073ca63faf/goapp/controlplane/internal/pkg/compute/provisioner/provisioner_common.go#L70>
const EXTRA_OPTIONS: &str = "-c role=cloud_admin -c default_transaction_read_only=off -c search_path=public -c statement_timeout=0";
let options = match conn_conf.get_options() {
Some(options) => format!("{} {}", options, EXTRA_OPTIONS),
None => EXTRA_OPTIONS.to_string(),
};
conn_conf.options(&options);
tokio_conn_conf.options(&options);
let mut new_state = ComputeState::new();
if let Some(spec) = config.spec {
let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?;
@@ -789,9 +736,6 @@ impl ComputeNode {
// Log metrics so that we can search for slow operations in logs
info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished");
if pspec.spec.prewarm_lfc_on_startup {
self.prewarm_lfc();
}
Ok(())
}
@@ -1478,20 +1422,15 @@ impl ComputeNode {
Err(e) => match e.code() {
Some(&SqlState::INVALID_PASSWORD)
| Some(&SqlState::INVALID_AUTHORIZATION_SPECIFICATION) => {
// Connect with `zenith_admin` if `cloud_admin` could not authenticate
// Connect with zenith_admin if cloud_admin could not authenticate
info!(
"cannot connect to Postgres: {}, retrying with 'zenith_admin' username",
"cannot connect to postgres: {}, retrying with `zenith_admin` username",
e
);
let mut zenith_admin_conf = postgres::config::Config::from(conf.clone());
zenith_admin_conf.application_name("compute_ctl:apply_config");
zenith_admin_conf.user("zenith_admin");
// It doesn't matter what were the options before, here we just want
// to connect and create a new superuser role.
const ZENITH_OPTIONS: &str = "-c role=zenith_admin -c default_transaction_read_only=off -c search_path=public -c statement_timeout=0";
zenith_admin_conf.options(ZENITH_OPTIONS);
let mut client =
zenith_admin_conf.connect(NoTls)
.context("broken cloud_admin credential: tried connecting with cloud_admin but could not authenticate, and zenith_admin does not work either")?;
@@ -1657,7 +1596,9 @@ impl ComputeNode {
self.pg_reload_conf()?;
if spec.mode == ComputeMode::Primary {
let conf = self.get_tokio_conn_conf(Some("compute_ctl:reconfigure"));
let mut conf =
tokio_postgres::Config::from_str(self.params.connstr.as_str()).unwrap();
conf.application_name("apply_config");
let conf = Arc::new(conf);
let spec = Arc::new(spec.clone());
@@ -1897,9 +1838,9 @@ LIMIT 100",
real_ext_name: String,
ext_path: RemotePath,
) -> Result<u64, DownloadError> {
let remote_ext_base_url =
let ext_remote_storage =
self.params
.remote_ext_base_url
.ext_remote_storage
.as_ref()
.ok_or(DownloadError::BadInput(anyhow::anyhow!(
"Remote extensions storage is not configured",
@@ -1961,7 +1902,7 @@ LIMIT 100",
let download_size = extension_server::download_extension(
&real_ext_name,
&ext_path,
remote_ext_base_url,
ext_remote_storage,
&self.params.pgbin,
)
.await
@@ -1996,40 +1937,23 @@ LIMIT 100",
tokio::spawn(conn);
// TODO: support other types of grants apart from schemas?
// check the role grants first - to gracefully handle read-replicas.
let select = "SELECT privilege_type
FROM pg_namespace
JOIN LATERAL (SELECT * FROM aclexplode(nspacl) AS x) acl ON true
JOIN pg_user users ON acl.grantee = users.usesysid
WHERE users.usename = $1
AND nspname = $2";
let rows = db_client
.query(select, &[role_name, schema_name])
.await
.with_context(|| format!("Failed to execute query: {select}"))?;
let already_granted: HashSet<String> = rows.into_iter().map(|row| row.get(0)).collect();
let grants = privileges
.iter()
.filter(|p| !already_granted.contains(p.as_str()))
// should not be quoted as it's part of the command.
// is already sanitized so it's ok
.map(|p| p.as_str())
.join(", ");
if !grants.is_empty() {
let query = format!(
"GRANT {} ON SCHEMA {} TO {}",
privileges
.iter()
// should not be quoted as it's part of the command.
// is already sanitized so it's ok
.map(|p| p.as_str())
.collect::<Vec<&'static str>>()
.join(", "),
// quote the schema and role name as identifiers to sanitize them.
let schema_name = schema_name.pg_quote();
let role_name = role_name.pg_quote();
let query = format!("GRANT {grants} ON SCHEMA {schema_name} TO {role_name}",);
db_client
.simple_query(&query)
.await
.with_context(|| format!("Failed to execute query: {}", query))?;
}
schema_name.pg_quote(),
role_name.pg_quote(),
);
db_client
.simple_query(&query)
.await
.with_context(|| format!("Failed to execute query: {}", query))?;
Ok(())
}
@@ -2087,7 +2011,7 @@ LIMIT 100",
&self,
spec: &ComputeSpec,
) -> Result<RemoteExtensionMetrics> {
if self.params.remote_ext_base_url.is_none() {
if self.params.ext_remote_storage.is_none() {
return Ok(RemoteExtensionMetrics {
num_ext_downloaded: 0,
largest_ext_size: 0,

View File

@@ -1,202 +0,0 @@
use crate::compute::ComputeNode;
use anyhow::{Context, Result, bail};
use async_compression::tokio::bufread::{ZstdDecoder, ZstdEncoder};
use compute_api::responses::LfcOffloadState;
use compute_api::responses::LfcPrewarmState;
use http::StatusCode;
use reqwest::Client;
use std::sync::Arc;
use tokio::{io::AsyncReadExt, spawn};
use tracing::{error, info};
#[derive(serde::Serialize, Default)]
pub struct LfcPrewarmStateWithProgress {
#[serde(flatten)]
base: LfcPrewarmState,
total: i32,
prewarmed: i32,
skipped: i32,
}
/// A pair of url and a token to query endpoint storage for LFC prewarm-related tasks
struct EndpointStoragePair {
url: String,
token: String,
}
const KEY: &str = "lfc_state";
impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair {
type Error = anyhow::Error;
fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> {
let Some(ref endpoint_id) = pspec.spec.endpoint_id else {
bail!("pspec.endpoint_id missing")
};
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
bail!("pspec.endpoint_storage_addr missing")
};
let tenant_id = pspec.tenant_id;
let timeline_id = pspec.timeline_id;
let url = format!("http://{base_uri}/{tenant_id}/{timeline_id}/{endpoint_id}/{KEY}");
let Some(ref token) = pspec.endpoint_storage_token else {
bail!("pspec.endpoint_storage_token missing")
};
let token = token.clone();
Ok(EndpointStoragePair { url, token })
}
}
impl ComputeNode {
// If prewarm failed, we want to get overall number of segments as well as done ones.
// However, this function should be reliable even if querying postgres failed.
pub async fn lfc_prewarm_state(&self) -> LfcPrewarmStateWithProgress {
info!("requesting LFC prewarm state from postgres");
let mut state = LfcPrewarmStateWithProgress::default();
{
state.base = self.state.lock().unwrap().lfc_prewarm_state.clone();
}
let client = match ComputeNode::get_maintenance_client(&self.tokio_conn_conf).await {
Ok(client) => client,
Err(err) => {
error!(%err, "connecting to postgres");
return state;
}
};
let row = match client
.query_one("select * from get_prewarm_info()", &[])
.await
{
Ok(row) => row,
Err(err) => {
error!(%err, "querying LFC prewarm status");
return state;
}
};
state.total = row.try_get(0).unwrap_or_default();
state.prewarmed = row.try_get(1).unwrap_or_default();
state.skipped = row.try_get(2).unwrap_or_default();
state
}
pub fn lfc_offload_state(&self) -> LfcOffloadState {
self.state.lock().unwrap().lfc_offload_state.clone()
}
/// Returns false if there is a prewarm request ongoing, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_PREWARM_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
if let LfcPrewarmState::Prewarming =
std::mem::replace(state, LfcPrewarmState::Prewarming)
{
return false;
}
}
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.prewarm_impl().await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
error!(%err);
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed {
error: err.to_string(),
};
});
true
}
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
let state = self.state.lock().unwrap();
state.pspec.as_ref().unwrap().try_into()
}
async fn prewarm_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
let res = request.send().await.context("querying endpoint storage")?;
let status = res.status();
if status != StatusCode::OK {
bail!("{status} querying endpoint storage")
}
let mut uncompressed = Vec::new();
let lfc_state = res
.bytes()
.await
.context("getting request body from endpoint storage")?;
ZstdDecoder::new(lfc_state.iter().as_slice())
.read_to_end(&mut uncompressed)
.await
.context("decoding LFC state")?;
let uncompressed_len = uncompressed.len();
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres");
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
.context("connecting to postgres")?
.query_one("select prewarm_local_cache($1)", &[&uncompressed])
.await
.context("loading LFC state into postgres")
.map(|_| ())
}
/// Returns false if there is an offload request ongoing, true otherwise
pub fn offload_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_OFFLOAD_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_offload_state;
if let LfcOffloadState::Offloading =
std::mem::replace(state, LfcOffloadState::Offloading)
{
return false;
}
}
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.offload_lfc_impl().await else {
cloned.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
error!(%err);
cloned.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
});
true
}
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
.context("connecting to postgres")?
.query_one("select get_local_cache_state()", &[])
.await
.context("querying LFC state")?
.try_get::<usize, &[u8]>(0)
.context("deserializing LFC state")
.map(ZstdEncoder::new)?
.read_to_end(&mut compressed)
.await
.context("compressing LFC state")?;
let compressed_len = compressed.len();
info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");
let request = Client::new().put(url).bearer_auth(token).body(compressed);
match request.send().await {
Ok(res) if res.status() == StatusCode::OK => Ok(()),
Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()),
Err(err) => Err(err).context("writing to endpoint storage"),
}
}
}

View File

@@ -223,12 +223,6 @@ pub fn write_postgres_conf(
// TODO: tune this after performance testing
writeln!(file, "pgaudit.log_rotation_age=5")?;
// Enable audit logs for pg_session_jwt extension
// TODO: Consider a good approach for shipping pg_session_jwt logs to the same sink as
// pgAudit - additional context in https://github.com/neondatabase/cloud/issues/28863
//
// writeln!(file, "pg_session_jwt.audit_log=on")?;
// Add audit shared_preload_libraries, if they are not present.
//
// The caller who sets the flag is responsible for ensuring that the necessary

View File

@@ -158,14 +158,14 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
pub async fn download_extension(
ext_name: &str,
ext_path: &RemotePath,
remote_ext_base_url: &str,
ext_remote_storage: &str,
pgbin: &str,
) -> Result<u64> {
info!("Download extension {:?} from {:?}", ext_name, ext_path);
// TODO add retry logic
let download_buffer =
match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await {
match download_extension_tar(ext_remote_storage, &ext_path.to_string()).await {
Ok(buffer) => buffer,
Err(error_message) => {
return Err(anyhow::anyhow!(
@@ -272,8 +272,8 @@ pub fn create_control_files(remote_extensions: &RemoteExtSpec, pgbin: &str) {
// Do request to extension storage proxy, e.g.,
// curl http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst
// using HTTP GET and return the response body as bytes.
async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Result<Bytes> {
let uri = format!("{}/{}", remote_ext_base_url, ext_path);
async fn download_extension_tar(ext_remote_storage: &str, ext_path: &str) -> Result<Bytes> {
let uri = format!("{}/{}", ext_remote_storage, ext_path);
let filename = Path::new(ext_path)
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("unknown"))

View File

@@ -1,10 +1,12 @@
use std::collections::HashSet;
use anyhow::{Result, anyhow};
use axum::{RequestExt, body::Body};
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use compute_api::requests::{COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope};
use compute_api::requests::ComputeClaims;
use futures::future::BoxFuture;
use http::{Request, Response, StatusCode};
use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, jwk::JwkSet};
@@ -23,14 +25,13 @@ pub(in crate::http) struct Authorize {
impl Authorize {
pub fn new(compute_id: String, jwks: JwkSet) -> Self {
let mut validation = Validation::new(Algorithm::EdDSA);
// Nothing is currently required
validation.required_spec_claims = HashSet::new();
validation.validate_exp = true;
// Unused by the control plane
validation.validate_nbf = false;
// Unused by the control plane
validation.validate_aud = false;
validation.set_audience(&[COMPUTE_AUDIENCE]);
// Nothing is currently required
validation.set_required_spec_claims(&[] as &[&str; 0]);
// Unused by the control plane
validation.validate_nbf = false;
Self {
compute_id,
@@ -63,47 +64,11 @@ impl AsyncAuthorizeRequest<Body> for Authorize {
Err(e) => return Err(JsonResponse::error(StatusCode::UNAUTHORIZED, e)),
};
match data.claims.scope {
// TODO: We should validate audience for every token, but
// instead of this ad-hoc validation, we should turn
// [`Validation::validate_aud`] on. This is merely a stopgap
// while we roll out `aud` deployment. We return a 401
// Unauthorized because when we eventually do use
// [`Validation`], we will hit the above `Err` match arm which
// returns 401 Unauthorized.
Some(ComputeClaimsScope::Admin) => {
let Some(ref audience) = data.claims.audience else {
return Err(JsonResponse::error(
StatusCode::UNAUTHORIZED,
"missing audience in authorization token claims",
));
};
if !audience.iter().any(|a| a == COMPUTE_AUDIENCE) {
return Err(JsonResponse::error(
StatusCode::UNAUTHORIZED,
"invalid audience in authorization token claims",
));
}
}
// If the scope is not [`ComputeClaimsScope::Admin`], then we
// must validate the compute_id
_ => {
let Some(ref claimed_compute_id) = data.claims.compute_id else {
return Err(JsonResponse::error(
StatusCode::FORBIDDEN,
"missing compute_id in authorization token claims",
));
};
if *claimed_compute_id != compute_id {
return Err(JsonResponse::error(
StatusCode::FORBIDDEN,
"invalid compute ID in authorization token claims",
));
}
}
if data.claims.compute_id != compute_id {
return Err(JsonResponse::error(
StatusCode::UNAUTHORIZED,
"invalid compute ID in authorization token claims",
));
}
// Make claims available to any subsequent middleware or request

View File

@@ -22,7 +22,7 @@ pub(in crate::http) async fn download_extension(
State(compute): State<Arc<ComputeNode>>,
) -> Response {
// Don't even try to download extensions if no remote storage is configured
if compute.params.remote_ext_base_url.is_none() {
if compute.params.ext_remote_storage.is_none() {
return JsonResponse::error(
StatusCode::PRECONDITION_FAILED,
"remote storage is not configured",

View File

@@ -1,39 +0,0 @@
use crate::compute_prewarm::LfcPrewarmStateWithProgress;
use crate::http::JsonResponse;
use axum::response::{IntoResponse, Response};
use axum::{Json, http::StatusCode};
use compute_api::responses::LfcOffloadState;
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
pub(in crate::http) async fn prewarm_state(compute: Compute) -> Json<LfcPrewarmStateWithProgress> {
Json(compute.lfc_prewarm_state().await)
}
// Following functions are marked async for axum, as it's more convenient than wrapping these
// in async lambdas at call site
pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadState> {
Json(compute.lfc_offload_state())
}
pub(in crate::http) async fn prewarm(compute: Compute) -> Response {
if compute.prewarm_lfc() {
StatusCode::ACCEPTED.into_response()
} else {
JsonResponse::error(
StatusCode::TOO_MANY_REQUESTS,
"Multiple requests for prewarm are not allowed",
)
}
}
pub(in crate::http) async fn offload(compute: Compute) -> Response {
if compute.offload_lfc() {
StatusCode::ACCEPTED.into_response()
} else {
JsonResponse::error(
StatusCode::TOO_MANY_REQUESTS,
"Multiple requests for prewarm offload are not allowed",
)
}
}

View File

@@ -11,7 +11,6 @@ pub(in crate::http) mod extensions;
pub(in crate::http) mod failpoints;
pub(in crate::http) mod grants;
pub(in crate::http) mod insights;
pub(in crate::http) mod lfc;
pub(in crate::http) mod metrics;
pub(in crate::http) mod metrics_json;
pub(in crate::http) mod status;

View File

@@ -23,7 +23,7 @@ use super::{
middleware::authorize::Authorize,
routes::{
check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
grants, insights, lfc, metrics, metrics_json, status, terminate,
grants, insights, metrics, metrics_json, status, terminate,
},
};
use crate::compute::ComputeNode;
@@ -85,8 +85,6 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
Router::<Arc<ComputeNode>>::new().route("/metrics", get(metrics::get_metrics));
let authenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))
.route("/lfc/offload", get(lfc::offload_state).post(lfc::offload))
.route("/check_writability", post(check_writability::is_writable))
.route("/configure", post(configure::configure))
.route("/database_schema", get(database_schema::get_schema_dump))

View File

@@ -11,7 +11,6 @@ pub mod http;
pub mod logger;
pub mod catalog;
pub mod compute;
pub mod compute_prewarm;
pub mod disk_quota;
pub mod extension_server;
pub mod installed_extensions;

View File

@@ -1,7 +1,7 @@
use metrics::core::{AtomicF64, AtomicU64, Collector, GenericCounter, GenericGauge};
use metrics::proto::MetricFamily;
use metrics::{
IntCounter, IntCounterVec, IntGaugeVec, UIntGaugeVec, register_gauge, register_int_counter,
IntCounterVec, IntGaugeVec, UIntGaugeVec, register_gauge, register_int_counter,
register_int_counter_vec, register_int_gauge_vec, register_uint_gauge_vec,
};
use once_cell::sync::Lazy;
@@ -97,24 +97,6 @@ pub(crate) static PG_TOTAL_DOWNTIME_MS: Lazy<GenericCounter<AtomicU64>> = Lazy::
.expect("failed to define a metric")
});
/// Needed as neon.file_cache_prewarm_batch == 0 doesn't mean we never tried to prewarm.
/// On the other hand, LFC_PREWARMED_PAGES is excessive as we can GET /lfc/prewarm
pub(crate) static LFC_PREWARM_REQUESTS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"compute_ctl_lfc_prewarm_requests_total",
"Total number of LFC prewarm requests made by compute_ctl",
)
.expect("failed to define a metric")
});
pub(crate) static LFC_OFFLOAD_REQUESTS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"compute_ctl_lfc_offload_requests_total",
"Total number of LFC offload requests made by compute_ctl",
)
.expect("failed to define a metric")
});
pub fn collect() -> Vec<MetricFamily> {
let mut metrics = COMPUTE_CTL_UP.collect();
metrics.extend(INSTALLED_EXTENSIONS.collect());
@@ -124,7 +106,5 @@ pub fn collect() -> Vec<MetricFamily> {
metrics.extend(AUDIT_LOG_DIR_SIZE.collect());
metrics.extend(PG_CURR_DOWNTIME_MS.collect());
metrics.extend(PG_TOTAL_DOWNTIME_MS.collect());
metrics.extend(LFC_PREWARM_REQUESTS.collect());
metrics.extend(LFC_OFFLOAD_REQUESTS.collect());
metrics
}

View File

@@ -30,7 +30,6 @@ mod pg_helpers_tests {
r#"fsync = off
wal_level = logical
hot_standby = on
prewarm_lfc_on_startup = off
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on

View File

@@ -41,7 +41,7 @@ storage_broker.workspace = true
http-utils.workspace = true
utils.workspace = true
whoami.workspace = true
endpoint_storage.workspace = true
compute_api.workspace = true
workspace_hack.workspace = true
tracing.workspace = true

View File

@@ -16,11 +16,10 @@ use std::time::Duration;
use anyhow::{Context, Result, anyhow, bail};
use clap::Parser;
use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::ComputeMode;
use control_plane::broker::StorageBroker;
use control_plane::endpoint::ComputeControlPlane;
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_PORT, EndpointStorage};
use control_plane::local_env;
use control_plane::local_env::{
EndpointStorageConf, InitForceMode, LocalEnv, NeonBroker, NeonLocalInitConf,
@@ -644,10 +643,9 @@ struct EndpointStartCmdArgs {
#[clap(
long,
help = "Configure the remote extensions storage proxy gateway URL to request for extensions.",
alias = "remote-ext-config"
help = "Configure the remote extensions storage proxy gateway to request for extensions."
)]
remote_ext_base_url: Option<String>,
remote_ext_config: Option<String>,
#[clap(
long,
@@ -707,9 +705,6 @@ struct EndpointStopCmdArgs {
struct EndpointGenerateJwtCmdArgs {
#[clap(help = "Postgres endpoint id")]
endpoint_id: String,
#[clap(short = 's', long, help = "Scope to generate the JWT with", value_parser = ComputeClaimsScope::from_str)]
scope: Option<ComputeClaimsScope>,
}
#[derive(clap::Subcommand)]
@@ -1023,7 +1018,7 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
})
.collect(),
endpoint_storage: EndpointStorageConf {
listen_addr: ENDPOINT_STORAGE_DEFAULT_ADDR,
port: ENDPOINT_STORAGE_DEFAULT_PORT,
},
pg_distrib_dir: None,
neon_distrib_dir: None,
@@ -1415,16 +1410,9 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
EndpointCmd::Start(args) => {
let endpoint_id = &args.endpoint_id;
let pageserver_id = args.endpoint_pageserver_id;
let remote_ext_base_url = &args.remote_ext_base_url;
let remote_ext_config = &args.remote_ext_config;
let default_generation = env
.storage_controller
.timelines_onto_safekeepers
.then_some(1);
let safekeepers_generation = args
.safekeepers_generation
.or(default_generation)
.map(SafekeeperGeneration::new);
let safekeepers_generation = args.safekeepers_generation.map(SafekeeperGeneration::new);
// If --safekeepers argument is given, use only the listed
// safekeeper nodes; otherwise all from the env.
let safekeepers = if let Some(safekeepers) = parse_safekeepers(&args.safekeepers)? {
@@ -1496,29 +1484,14 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
None
};
let exp = (std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?
+ Duration::from_secs(86400))
.as_secs();
let claims = endpoint_storage::claims::EndpointStorageClaims {
tenant_id: endpoint.tenant_id,
timeline_id: endpoint.timeline_id,
endpoint_id: endpoint_id.to_string(),
exp,
};
let endpoint_storage_token = env.generate_auth_token(&claims)?;
let endpoint_storage_addr = env.endpoint_storage.listen_addr.to_string();
println!("Starting existing endpoint {endpoint_id}...");
endpoint
.start(
&auth_token,
endpoint_storage_token,
endpoint_storage_addr,
safekeepers_generation,
safekeepers,
pageservers,
remote_ext_base_url.as_ref(),
remote_ext_config.as_ref(),
stripe_size.0 as usize,
args.create_test_user,
args.start_timeout,
@@ -1567,16 +1540,12 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
endpoint.stop(&args.mode, args.destroy)?;
}
EndpointCmd::GenerateJwt(args) => {
let endpoint = {
let endpoint_id = &args.endpoint_id;
cplane
.endpoints
.get(endpoint_id)
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?
};
let jwt = endpoint.generate_jwt(args.scope)?;
let endpoint_id = &args.endpoint_id;
let endpoint = cplane
.endpoints
.get(endpoint_id)
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let jwt = endpoint.generate_jwt()?;
print!("{jwt}");
}

View File

@@ -45,9 +45,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
use compute_api::requests::{ComputeClaims, ConfigurationRequest};
use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TlsConfig,
};
@@ -632,17 +630,9 @@ impl Endpoint {
}
/// Generate a JWT with the correct claims.
pub fn generate_jwt(&self, scope: Option<ComputeClaimsScope>) -> Result<String> {
pub fn generate_jwt(&self) -> Result<String> {
self.env.generate_auth_token(&ComputeClaims {
audience: match scope {
Some(ComputeClaimsScope::Admin) => Some(vec![COMPUTE_AUDIENCE.to_owned()]),
_ => None,
},
compute_id: match scope {
Some(ComputeClaimsScope::Admin) => None,
_ => Some(self.endpoint_id.clone()),
},
scope,
compute_id: self.endpoint_id.clone(),
})
}
@@ -650,12 +640,10 @@ impl Endpoint {
pub async fn start(
&self,
auth_token: &Option<String>,
endpoint_storage_token: String,
endpoint_storage_addr: String,
safekeepers_generation: Option<SafekeeperGeneration>,
safekeepers: Vec<NodeId>,
pageservers: Vec<(Host, u16)>,
remote_ext_base_url: Option<&String>,
remote_ext_config: Option<&String>,
shard_stripe_size: usize,
create_test_user: bool,
start_timeout: Duration,
@@ -745,9 +733,6 @@ impl Endpoint {
drop_subscriptions_before_start: self.drop_subscriptions_before_start,
audit_log_level: ComputeAudit::Disabled,
logs_export_host: None::<String>,
endpoint_storage_addr: Some(endpoint_storage_addr),
endpoint_storage_token: Some(endpoint_storage_token),
prewarm_lfc_on_startup: false,
};
// this strange code is needed to support respec() in tests
@@ -825,8 +810,8 @@ impl Endpoint {
.stderr(logfile.try_clone()?)
.stdout(logfile);
if let Some(remote_ext_base_url) = remote_ext_base_url {
cmd.args(["--remote-ext-base-url", remote_ext_base_url]);
if let Some(remote_ext_config) = remote_ext_config {
cmd.args(["--remote-ext-config", remote_ext_config]);
}
let child = cmd.spawn()?;
@@ -918,7 +903,7 @@ impl Endpoint {
self.external_http_address.port()
),
)
.bearer_auth(self.generate_jwt(None::<ComputeClaimsScope>)?)
.bearer_auth(self.generate_jwt()?)
.send()
.await?;
@@ -995,7 +980,7 @@ impl Endpoint {
self.external_http_address.port()
))
.header(CONTENT_TYPE.as_str(), "application/json")
.bearer_auth(self.generate_jwt(None::<ComputeClaimsScope>)?)
.bearer_auth(self.generate_jwt()?)
.body(
serde_json::to_string(&ConfigurationRequest {
spec,

View File

@@ -3,19 +3,17 @@ use crate::local_env::LocalEnv;
use anyhow::{Context, Result};
use camino::Utf8PathBuf;
use std::io::Write;
use std::net::SocketAddr;
use std::time::Duration;
/// Directory within .neon which will be used by default for LocalFs remote storage.
pub const ENDPOINT_STORAGE_REMOTE_STORAGE_DIR: &str = "local_fs_remote_storage/endpoint_storage";
pub const ENDPOINT_STORAGE_DEFAULT_ADDR: SocketAddr =
SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), 9993);
pub const ENDPOINT_STORAGE_DEFAULT_PORT: u16 = 9993;
pub struct EndpointStorage {
pub bin: Utf8PathBuf,
pub data_dir: Utf8PathBuf,
pub pemfile: Utf8PathBuf,
pub addr: SocketAddr,
pub port: u16,
}
impl EndpointStorage {
@@ -24,7 +22,7 @@ impl EndpointStorage {
bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(),
pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(),
addr: env.endpoint_storage.listen_addr,
port: env.endpoint_storage.port,
}
}
@@ -33,7 +31,7 @@ impl EndpointStorage {
}
fn listen_addr(&self) -> Utf8PathBuf {
format!("{}:{}", self.addr.ip(), self.addr.port()).into()
format!("127.0.0.1:{}", self.port).into()
}
pub fn init(&self) -> Result<()> {

View File

@@ -20,9 +20,7 @@ use utils::auth::encode_from_key_file;
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use crate::broker::StorageBroker;
use crate::endpoint_storage::{
ENDPOINT_STORAGE_DEFAULT_ADDR, ENDPOINT_STORAGE_REMOTE_STORAGE_DIR, EndpointStorage,
};
use crate::endpoint_storage::{ENDPOINT_STORAGE_REMOTE_STORAGE_DIR, EndpointStorage};
use crate::pageserver::{PAGESERVER_REMOTE_STORAGE_DIR, PageServerNode};
use crate::safekeeper::SafekeeperNode;
@@ -153,10 +151,10 @@ pub struct NeonLocalInitConf {
pub generate_local_ssl_certs: bool,
}
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
#[derive(Serialize, Default, Deserialize, PartialEq, Eq, Clone, Debug)]
#[serde(default)]
pub struct EndpointStorageConf {
pub listen_addr: SocketAddr,
pub port: u16,
}
/// Broker config for cluster internal communication.
@@ -243,14 +241,6 @@ impl Default for NeonStorageControllerConf {
}
}
impl Default for EndpointStorageConf {
fn default() -> Self {
Self {
listen_addr: ENDPOINT_STORAGE_DEFAULT_ADDR,
}
}
}
impl NeonBroker {
pub fn client_url(&self) -> Url {
let url = if let Some(addr) = self.listen_https_addr {

View File

@@ -10,8 +10,7 @@ use camino::{Utf8Path, Utf8PathBuf};
use hyper0::Uri;
use nix::unistd::Pid;
use pageserver_api::controller_api::{
NodeConfigureRequest, NodeDescribeResponse, NodeRegisterRequest,
SafekeeperSchedulingPolicyRequest, SkSchedulingPolicy, TenantCreateRequest,
NodeConfigureRequest, NodeDescribeResponse, NodeRegisterRequest, TenantCreateRequest,
TenantCreateResponse, TenantLocateResponse,
};
use pageserver_api::models::{
@@ -21,7 +20,7 @@ use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api::ResponseErrorMessageExt;
use pem::Pem;
use postgres_backend::AuthType;
use reqwest::{Method, Response};
use reqwest::Method;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::process::Command;
@@ -571,11 +570,6 @@ impl StorageController {
let peer_jwt_token = encode_from_key_file(&peer_claims, private_key)
.expect("failed to generate jwt token");
args.push(format!("--peer-jwt-token={peer_jwt_token}"));
let claims = Claims::new(None, Scope::SafekeeperData);
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
args.push(format!("--safekeeper-jwt-token={jwt_token}"));
}
if let Some(public_key) = &self.public_key {
@@ -620,10 +614,6 @@ impl StorageController {
self.env.base_data_dir.display()
));
if self.env.safekeepers.iter().any(|sk| sk.auth_enabled) && self.private_key.is_none() {
anyhow::bail!("Safekeeper set up for auth but no private key specified");
}
if self.config.timelines_onto_safekeepers {
args.push("--timelines-onto-safekeepers".to_string());
}
@@ -650,10 +640,6 @@ impl StorageController {
)
.await?;
if self.config.timelines_onto_safekeepers {
self.register_safekeepers().await?;
}
Ok(())
}
@@ -757,23 +743,6 @@ impl StorageController {
where
RQ: Serialize + Sized,
RS: DeserializeOwned + Sized,
{
let response = self.dispatch_inner(method, path, body).await?;
Ok(response
.json()
.await
.map_err(pageserver_client::mgmt_api::Error::ReceiveBody)?)
}
/// Simple HTTP request wrapper for calling into storage controller
async fn dispatch_inner<RQ>(
&self,
method: reqwest::Method,
path: String,
body: Option<RQ>,
) -> anyhow::Result<Response>
where
RQ: Serialize + Sized,
{
// In the special case of the `storage_controller start` subcommand, we wish
// to use the API endpoint of the newly started storage controller in order
@@ -816,31 +785,10 @@ impl StorageController {
let response = builder.send().await?;
let response = response.error_from_body().await?;
Ok(response)
}
/// Register the safekeepers in the storage controller
#[instrument(skip(self))]
async fn register_safekeepers(&self) -> anyhow::Result<()> {
for sk in self.env.safekeepers.iter() {
let sk_id = sk.id;
let body = serde_json::json!({
"id": sk_id,
"created_at": "2023-10-25T09:11:25Z",
"updated_at": "2024-08-28T11:32:43Z",
"region_id": "aws-us-east-2",
"host": "127.0.0.1",
"port": sk.pg_port,
"http_port": sk.http_port,
"https_port": sk.https_port,
"version": 5957,
"availability_zone_id": format!("us-east-2b-{sk_id}"),
});
self.upsert_safekeeper(sk_id, body).await?;
self.safekeeper_scheduling_policy(sk_id, SkSchedulingPolicy::Active)
.await?;
}
Ok(())
Ok(response
.json()
.await
.map_err(pageserver_client::mgmt_api::Error::ReceiveBody)?)
}
/// Call into the attach_hook API, for use before handing out attachments to pageservers
@@ -868,42 +816,6 @@ impl StorageController {
Ok(response.generation)
}
#[instrument(skip(self))]
pub async fn upsert_safekeeper(
&self,
node_id: NodeId,
request: serde_json::Value,
) -> anyhow::Result<()> {
let resp = self
.dispatch_inner::<serde_json::Value>(
Method::POST,
format!("control/v1/safekeeper/{node_id}"),
Some(request),
)
.await?;
if !resp.status().is_success() {
anyhow::bail!(
"setting scheduling policy unsuccessful for safekeeper {node_id}: {}",
resp.status()
);
}
Ok(())
}
#[instrument(skip(self))]
pub async fn safekeeper_scheduling_policy(
&self,
node_id: NodeId,
scheduling_policy: SkSchedulingPolicy,
) -> anyhow::Result<()> {
self.dispatch::<SafekeeperSchedulingPolicyRequest, ()>(
Method::POST,
format!("control/v1/safekeeper/{node_id}/scheduling_policy"),
Some(SafekeeperSchedulingPolicyRequest { scheduling_policy }),
)
.await
}
#[instrument(skip(self))]
pub async fn inspect(
&self,

View File

@@ -14,14 +14,6 @@ PG_VERSION=${PG_VERSION:-14}
CONFIG_FILE_ORG=/var/db/postgres/configs/config.json
CONFIG_FILE=/tmp/config.json
# Test that the first library path that the dynamic loader looks in is the path
# that we use for custom compiled software
first_path="$(ldconfig --verbose 2>/dev/null \
| grep --invert-match ^$'\t' \
| cut --delimiter=: --fields=1 \
| head --lines=1)"
test "$first_path" == '/usr/local/lib' || true # Remove the || true in a follow-up PR. Needed for backwards compat.
echo "Waiting pageserver become ready."
while ! nc -z pageserver 6400; do
sleep 1;

View File

@@ -12,7 +12,6 @@ ERROR: invalid JWT encoding
-- Test creating a session with an expired JWT
SELECT auth.jwt_session_init('eyJhbGciOiJFZERTQSJ9.eyJleHAiOjE3NDI1NjQ0MzIsImlhdCI6MTc0MjU2NDI1MiwianRpIjo0MjQyNDIsInN1YiI6InVzZXIxMjMifQ.A6FwKuaSduHB9O7Gz37g0uoD_U9qVS0JNtT7YABGVgB7HUD1AMFc9DeyhNntWBqncg8k5brv-hrNTuUh5JYMAw');
ERROR: Token used after it has expired
DETAIL: exp=1742564432
-- Test creating a session with a valid JWT
SELECT auth.jwt_session_init('eyJhbGciOiJFZERTQSJ9.eyJleHAiOjQ4OTYxNjQyNTIsImlhdCI6MTc0MjU2NDI1MiwianRpIjo0MzQzNDMsInN1YiI6InVzZXIxMjMifQ.2TXVgjb6JSUq6_adlvp-m_SdOxZSyGS30RS9TLB0xu2N83dMSs2NybwE1NMU8Fb0tcAZR_ET7M2rSxbTrphfCg');
jwt_session_init

View File

@@ -5,4 +5,3 @@ listen_http_addr='0.0.0.0:9898'
remote_storage={ endpoint='http://minio:9000', bucket_name='neon', bucket_region='eu-north-1', prefix_in_bucket='/pageserver' }
control_plane_api='http://0.0.0.0:6666' # No storage controller in docker compose, specify a junk address
control_plane_emergency_mode=true
virtual_file_io_mode="buffered" # the CI runners where we run the docker compose tests have slow disks

View File

@@ -7,8 +7,6 @@ Author: Christian Schwarz
A brief RFC / GitHub Epic describing a vectored version of the `Timeline::get` method that is at the heart of Pageserver.
**EDIT**: the implementation of this feature is described in [Vlad's (internal) tech talk](https://drive.google.com/file/d/1vfY24S869UP8lEUUDHRWKF1AJn8fpWoJ/view?usp=drive_link).
# Motivation
During basebackup, we issue many `Timeline::get` calls for SLRU pages that are *adjacent* in key space.

View File

@@ -1,362 +0,0 @@
# Direct IO For Pageserver
Date: Apr 30, 2025
## Summary
This document is a retroactive RFC. It
- provides some background on what direct IO is,
- motivates why Pageserver should be using it for its IO, and
- describes how we changed Pageserver to use it.
The [initial proposal](https://github.com/neondatabase/neon/pull/8240) that kicked off the work can be found in this closed GitHub PR.
People primarily involved in this project were:
- Yuchen Liang <yuchen@neon.tech>
- Vlad Lazar <vlad@neon.tech>
- Christian Schwarz <christian@neon.tech>
## Timeline
For posterity, here is the rough timeline of the development work that got us to where we are today.
- Jan 2024: [integrate `tokio-epoll-uring`](https://github.com/neondatabase/neon/pull/5824) along with owned buffers API
- March 2024: `tokio-epoll-uring` enabled in all regions in buffered IO mode
- Feb 2024 to June 2024: PS PageCache Bypass For Data Blocks
- Feb 2024: [Vectored Get Implementation](https://github.com/neondatabase/neon/pull/6576) bypasses delta & image layer blocks for page requests
- Apr to June 2024: [Epic: bypass PageCache for use data blocks](https://github.com/neondatabase/neon/issues/7386) addresses remaining users
- Aug to Nov 2024: direct IO: first code; preliminaries; read path coding; BufferedWriter; benchmarks show perf regressions too high, no-go.
- Nov 2024 to Jan 2025: address perf regressions by developing page_service pipelining (aka batching) and concurrent IO ([Epic](https://github.com/neondatabase/neon/issues/9376))
- Feb to March 2024: rollout batching, then concurrent+direct IO => read path and InMemoryLayer is now direct IO
- Apr 2025: develop & roll out direct IO for the write path
## Background: Terminology & Glossary
**kernel page cache**: the Linux kernel's page cache is a write-back cache for filesystem contents.
The cached unit is memory-page-sized & aligned chunks of the files that are being cached (typically 4k).
The cache lives in kernel memory and is not directly accessible through userspace.
**Buffered IO**: an application's read/write system calls go through the kernel page cache.
For example, a 10 byte sized read or write to offset 5000 in a file will load the file contents
at offset `[4096,8192)` into a free page in the kernel page cache. If necessary, it will evict
a page to make room (cf eviction). Then, the kernel performs a memory-to-memory copy of 10 bytes
from/to the offset `4` (`5000 = 4096 + 4`) within the cached page. If it's a write, the kernel keeps
track of the fact that the page is now "dirty" in some ancillary structure.
**Writeback**: a buffered read/write syscall returns after the memory-to-memory copy. The modifications
made by e.g. write system calls are not even *issued* to disk, let alone durable. Instead, the kernel
asynchronously writes back dirtied pages based on a variety of conditions. For us, the most relevant
ones are a) explicit request by userspace (`fsync`) and b) memory pressure.
**Memory pressure**: the kernel page cache is a best effort service and a user of spare memory capacity.
If there is no free memory, the kernel page allocator will take pages used by page cache to satisfy allocations.
Before reusing a page like that, the page has to be written back (writeback, see above).
The far-reaching consequence of this is that **any allocation of anonymous memory can do IO** if the only
way to get that memory is by eviction & re-using a dirty page cache page.
Notably, this includes a simple `malloc` in userspace, because eventually that boils down to `mmap(..., MAP_ANON, ...)`.
I refer to this effect as the "malloc latency backscatter" caused by buffered IO.
**Direct IO** allows application's read/write system calls to bypass the kernel page cache. The filesystem
is still involved because it is ultimately in charge of mapping the concept of files & offsets within them
to sectors on block devices. Typically, the filesystem poses size and alignment requirements for memory buffers
and file offsets (statx `Dio_mem_align` / `Dio_offset_align`), see [this gist](https://gist.github.com/problame/1c35cac41b7cd617779f8aae50f97155).
The IO operations will fail at runtime with EINVAL if the alignment requirements are not met.
**"buffered" vs "direct"**: the central distinction between buffered and direct IO is about who allocates and
fills the IO buffers, and who controls when exactly the IOs are issued. In buffered IO, it's the syscall handlers,
kernel page cache, and memory management subsystems (cf "writeback"). In direct IO, all of it is done by
the application.
It takes more effort by the application to program with direct instead of buffered IO.
The return is precise control over and a clear distinction between consumption/modification of memory vs disk.
**Pageserver PageCache**: Pageserver has an additional `PageCache` (referred to as PS PageCache from here on, as opposed to "kernel page cache").
Its caching unit is 8KiB blocks of the layer files written by Pageserver.
A miss in PageCache is filled by reading from the filesystem, through the `VirtualFile` abstraction layer.
The default size is tiny (64MiB), very much like Postgres's `shared_buffers`.
We ran production at 128MiB for a long time but gradually moved it up to 2GiB over the past ~year.
**VirtualFile** is Pageserver's abstraction for file IO, very similar to the facility in Postgres that bears the same name.
Its historical purpose appears to be working around open file descriptor limitations, which is practically irrelevant on Linux.
However, the facility in Pageserver is useful as an intermediary layer for metrics and abstracts over the different kinds of
IO engines that Pageserver supports (`std-fs` vs `tokio-epoll-uring`).
## Background: History Of Caching In Pageserver
For multiple years, Pageserver's `PageCache` was on the path of all read _and write_ IO.
It performed write-back to the kernel using buffered IO.
We converted it into a read-only cache of immutable data in [PR 4994](https://github.com/neondatabase/neon/pull/4994).
The introduction of `tokio-epoll-uring` required converting the code base to used owned IO buffers.
The `PageCache` pages are usable as owned IO buffers.
We then started bypassing PageCache for user data blocks.
Data blocks are the 8k blocks of data in layer files that hold the multiple `Value`s, as opposed to the disk btree index blocks that tell us which values exist in a file at what offsets.
The disk btree embedded in delta & image layers remains `PageCache`'d.
Epics for that work were:
- Vectored `Timeline::get` (cf RFC 30) skipped delta and image layer data block `PageCache`ing outright.
- Epic https://github.com/neondatabase/neon/issues/7386 took care of the remaining users for data blocks:
- Materialized page cache (cached materialized pages; shown to be ~0% hit rate in practice)
- InMemoryLayer
- Compaction
The outcome of the above:
1. All data blocks are always read through the `VirtualFile` APIs, hitting the kernel buffered read path (=> kernel page cache).
2. Indirect blocks (=disk btree blocks) would be cached in the PS `PageCache`.
In production we size the PS `PageCache` to be 2GiB.
Thus drives hit rate up to ~99.95% and the eviction rate / replacement rates down to less than 200/second on a 1-minute average, on the busiest machines.
High baseline replacement rates are treated as a signal of resource exhaustion (page cache insufficient to host working set of the PS).
The response to this is to migrate tenants away, or increase PS `PageCache` size.
It is currently manual but could be automated, e.g., in Storage Controller.
In the future, we may eliminate the `PageCache` even for indirect blocks.
For example with an LRU cache that has as unit the entire disk btree content
instead of individual blocks.
## High-Level Design
So, before work on this project started, all data block reads and the entire write path of Pageserver were using kernel-buffered IO, i.e., the kernel page cache.
We now want to get the kernel page cache out of the picture by using direct IO for all interaction with the filesystem.
This achieves the following system properties:
**Predictable VirtualFile latencies**
* With buffered IO, reads are sometimes fast, sometimes slow, depending on kernel page cache hit/miss.
* With buffered IO, appends when writing out new layer files during ingest or compaction are sometimes fast, sometimes slow because of write-back backpressure.
* With buffered IO, the "malloc backscatter" phenomenon pointed out in the Glossary section is not something we actively observe.
But we do have occasional spikes in Dirty memory amount and Memory PSI graphs, so it may already be affecting to some degree.
* By switching to direct IO, above operations will have the (predictable) device latency -- always.
Reads and appends always go to disk.
And malloc will not have to write back dirty data.
**Explicitness & Tangibility of resource usage**
* In a multi-tenant system, it is generally desirable and valuable to be *explicit* about the main resources we use for each tenant.
* By using direct IO, we become explicit about the resources *disk IOPs* and *memory capacity* in a way that was previously being conflated through the kernel page cache, outside our immediate control.
* We will be able to build per-tenant observability of resource usage ("what tenant is causing the actual IOs that are sent to the disk?").
* We will be able to build accounting & QoS by implementing an IO scheduler that is tenant aware. The kernel is not tenant-aware and can't do that.
**CPU Efficiency**
* The involvement of the kernel page cache means one additional memory-to-memory copy on read and write path.
* Direct IO will eliminate that memory-to-memory copy, if we can make the userspace buffers used for the IO calls satisfy direct IO alignment requirements.
The **trade-off** is that we no longer get the theoretical benefits of the kernel page cache. These are:
- read latency improvements for repeat reads of the same data ("locality of reference")
- asterisk: only if that state is still cache-resident by time of next access
- write throughput by having kernel page cache batch small VFS writes into bigger disk writes
- asterisk: only if memory pressure is low enough that the kernel can afford to delay writeback
We are **happy to make this trade-off**:
- Because of the advantages listed above.
- Because we empirically have enough DRAM on Pageservers to serve metadata (=index blocks) from PS PageCache.
(At just 2GiB PS PageCache size, we average a 99.95% hit rate).
So, the latency of going to disk is only for data block reads, not the index traversal.
- Because **the kernel page cache is ineffective** at high tenant density anyway (#tenants/pageserver instance).
And because dense packing of tenants will always be desirable to drive COGS down, we should design the system for it.
(See the appendix for a more detailed explanation why this is).
- So, we accept that some reads that used to be fast by circumstance will have higher but **predictable** latency than before.
### Desired End State
The desired end state of the project is as follows, and with some asterisks, we have achieved it.
All IOs of the Pageserver data path use direct IO, thereby bypassing the kernel page cache.
In particular, the "data path" includes
- the wal ingest path
- compaction
- anything on the `Timeline::get` / `Timeline::get_vectored` path.
The production Pageserver config is tuned such that virtually all non-data blocks are cached in the PS PageCache.
Hit rate target is 99.95%.
There are no regressions to ingest latency.
The total "wait-for-disk time" contribution to random getpage request latency is `O(1 read IOP latency)`.
We accomplish that by having a near 100% PS PageCache hit rate so that layer index traversal effectively never needs not wait for IO.
Thereby, it can issue all the data blocks as it traverses the index, and only wait at the end of it (concurrent IO).
The amortized "wait-for-disk time" contribution of this direct IO proposal to a series of sequential getpage requests is `1/32 * read IOP latency` for each getpage request.
We accomplish this by server-side batching of up to 32 reads into a single `Timeline::get_vectored` call.
(This is an ideal world where our batches are full - that's not the case in prod today because of lack of queue depth).
## Design & Implementation
### Prerequisites
A lot of prerequisite work had to happen to enable use of direct IO.
To meet the "wait-for-disk time" requirements from the DoD, we implement for the read path:
- page_service level server-side batching (config field `page_service_pipelining`)
- concurrent IO (config field `get_vectored_concurrent_io`)
The work for both of these these was tracked [in the epic](https://github.com/neondatabase/neon/issues/9376).
Server-side batching will likely be obsoleted by the [#proj-compute-communicator](https://github.com/neondatabase/neon/pull/10799).
The Concurrent IO work is described in retroactive RFC `2025-04-30-pageserver-concurrent-io-on-read-path.md`.
The implementation is relatively brittle and needs further investment, see the `Future Work` section in that RFC.
For the write path, and especially WAL ingest, we need to hide write latency.
We accomplish this by implementing a (`BufferedWriter`) type that does double-buffering: flushes of the filled
buffer happen in a sidecar tokio task while new writes fill a new buffer.
We refactor InMemoryLayer as well as BlobWriter (=> delta and image layer writers) to use this new `BufferedWriter`.
The most comprehensive write-up of this work is in [the PR description](https://github.com/neondatabase/neon/pull/11558).
### Ensuring Adherence to Alignment Requirements
Direct IO puts requirements on
- memory buffer alignment
- io size (=memory buffer size)
- file offset alignment
The requirements are specific to a combination of filesystem/block-device/architecture(hardware page size!).
In Neon production environments we currently use ext4 with Linux 6.1.X on AWS and Azure storage-optimized instances (locally attached NVMe).
Instead of dynamic discovery using `statx`, we statically hard-code 512 bytes as the buffer/offset alignment and size-multiple.
We made this decision because:
- a) it is compatible with all the environments we need to run in
- b) our primary workload can be small-random-read-heavy (we do merge adjacent reads if possible, but the worst case is that all `Value`s that needs to be read are far apart)
- c) 512-byte tail latency on the production instance types is much better than 4k (p99.9: 3x lower, p99.99 5x lower).
- d) hard-coding at compile-time allows us to use the Rust type system to enforce the use of only aligned IO buffers, eliminating a source of runtime errors typically associated with direct IO.
This was [discussed here](https://neondb.slack.com/archives/C07BZ38E6SD/p1725036790965549?thread_ts=1725026845.455259&cid=C07BZ38E6SD).
The new `IoBufAligned` / `IoBufAlignedMut` marker traits indicate that a given buffer meets memory alignment requirements.
All `VirtualFile` APIs and several software layers built on top of them only accept buffers that implement those traits.
Implementors of the marker traits are:
- `IoBuffer` / `IoBufferMut`: used for most reads and writes
- `PageWriteGuardBuf`: for filling PS PageCache pages (index blocks!)
The alignment requirement is infectious; it permeates bottom-up throughout the code base.
We stop the infection at roughly the same layers in the code base where we stopped permeating the
use of owned-buffers-style API for tokio-epoll-uring. The way the stopping works is by introducing
a memory-to-memory copy from/to some unaligned memory location on the stack/current/heap.
The places where we currently stop permeating are sort of arbitrary. For example, it would probably
make sense to replace more usage of `Bytes` that we know holds 8k pages with 8k-sized `IoBuffer`s.
The `IoBufAligned` / `IoBufAlignedMut` types do not protect us from the following types of runtime errors:
- non-adherence to file offset alignment requirements
- non-adherence to io size requirements
The following higher-level constructs ensure we meet the requirements:
- read path: the `ChunkedVectoredReadBuilder` and `mod vectored_dio_read` ensure reads happen at aligned offsets and in appropriate size multiples.
- write path: `BufferedWriter` only writes in multiples of the capacity, at offsets that are `start_offset+N*capacity`; see its doc comment.
Note that these types are used always, regardless of whether direct IO is enabled or not.
There are some cases where this adds unnecessary overhead to buffered IO (e.g. all memcpy's inflated to multiples of 512).
But we could not identify meaningful impact in practice when we shipped these changes while we were still using buffered IO.
### Configuration / Feature Flagging
In the previous section we described how all users of VirtualFile were changed to always adhere to direct IO alignment and size-multiple requirements.
To actually enable direct IO, all we need to do is set the `O_DIRECT` flag in `open` syscalls / io_uring operations.
We set `O_DIRECT` based on:
- the VirtualFile API used to create/open the VirtualFile instance
- the `virtual_file_io_mode` configuration flag
- the OpenOptions `read` and/or `write` flags.
The VirtualFile APIs suffixed with `_v2` are the only ones that _may_ open with `O_DIRECT` depending on the other two factors in above list.
Other APIs never use `O_DIRECT`.
(The name is bad and should really be `_maybe_direct_io`.)
The reason for having new APIs is because all code used VirtualFile but implementation and rollout happened in consecutive phases (read path, InMemoryLayer, write path).
At the VirtualFile level, context on whether an instance of VirtualFile is on read path, InMemoryLayer, or write path is not available.
The `_v2` APIs then check make the decision to set `O_DIRECT` based on the `virtual_file_io_mode` flag and the OpenOptions `read`/`write` flags.
The result is the following runtime behavior:
|what|OpenOptions|`v_f_io_mode`<br/>=`buffered`|`v_f_io_mode`<br/>=`direct`|`v_f_io_mode`<br/>=`direct-rw`|
|-|-|-|-|-|
|`DeltaLayerInner`|read|()|O_DIRECT|O_DIRECT|
|`ImageLayerInner`|read|()|O_DIRECT|O_DIRECT|
|`InMemoryLayer`|read + write|()|()*|O_DIRECT|
|`DeltaLayerWriter`| write | () | () | O_DIRECT |
|`ImageLayerWriter`| write | () | () | O_DIRECT |
|`download_layer_file`|write |()|()|O_DIRECT|
The `InMemoryLayer` is marked with `*` because there was a period when it *did* use O_DIRECT under `=direct`.
That period was when we implemented and shipped the first version of `BufferedWriter`.
We used it in `InMemoryLayer` and `download_layer_file` but it was only sensitive to `v_f_io_mode` in `InMemoryLayer`.
The introduction of `=direct-rw`, and the switch of the remaining write path to `BufferedWriter`, happened later,
in https://github.com/neondatabase/neon/pull/11558.
Note that this way of feature flagging inside VirtualFile makes it less and less a general purpose POSIX file access abstraction.
For example, with `=direct-rw` enabled, it is no longer possible to open a `VirtualFile` without `O_DIRECT`. It'll always be set.
## Correctness Validation
The correctness risks with this project were:
- Memory safety issues in the `IoBuffer` / `IoBufferMut` implementation.
These types expose an API that is largely identical to that of the `bytes` crate and/or Vec.
- Runtime errors (=> downtime / unavailability) because of non-adherence to alignment/size-multiple requirements, resulting in EINVAL on the read path.
We sadly do not have infrastructure to run pageserver under `cargo miri`.
So for memory safety issues, we relied on careful peer review.
We do assert the production-like alignment requirements in testing builds.
However, these asserts were added retroactively.
The actual validation before rollout happened in staging and pre-prod.
We eventually enabled `=direct`/`=direct-rw` for Rust unit tests and the regression test suite.
I cannot recall a single instance of staging/pre-prod/production errors caused by non-adherence to alignment/size-multiple requirements.
Evidently developer testing was good enough.
## Performance Validation
The read path went through a lot of iterations of benchmarking in staging and pre-prod.
The benchmarks in those environments demonstrated performance regressions early in the implementation.
It was actually this performance testing that made us implement batching and concurrent IO to avoid unacceptable regressions.
The write path was much quicker to validate because `bench_ingest` covered all of the (less numerous) access patterns.
## Future Work
There is minor and major follow-up work that can be considered in the future.
Check the (soon-to-be-closed) Epic https://github.com/neondatabase/neon/issues/8130's "Follow-Ups" section for a current list.
Read Path:
- PS PageCache hit rate is crucial to unlock concurrent IO and reasonable latency for random reads generally.
Instead of reactively sizing PS PageCache, we should estimate the required PS PageCache size
and potentially also use that to drive placement decisions of shards from StorageController
https://github.com/neondatabase/neon/issues/9288
- ... unless we get rid of PS PageCache entirely and cache the index block in a more specialized cache.
But even then, an estimation of the working set would be helpful to figure out caching strategy.
Write Path:
- BlobWriter and its users could switch back to a borrowed API https://github.com/neondatabase/neon/issues/10129
- ... unless we want to implement bypass mode for large writes https://github.com/neondatabase/neon/issues/10101
- The `TempVirtualFile` introduced as part of this project could internalize more of the common usage pattern: https://github.com/neondatabase/neon/issues/11692
- Reduce conditional compilation around `virtual_file_io_mode`: https://github.com/neondatabase/neon/issues/11676
Both:
- A performance simulation mode that pads VirtualFile op latencies to typical NVMe latencies, even if the underlying storage is faster.
This would avoid misleadingly good performance on developer systems and in benchmarks on systems that are less busy than production hosts.
However, padding latencies at microsecond scale is non-trivial.
Misc:
- We should finish trimming VirtualFile's scope to be truly limited to core data path read & write.
Abstractions for reading & writing pageserver config, location config, heatmaps, etc, should use
APIs in a different package (`VirtualFile::crashsafe_overwrite` and `VirtualFile::read_to_string`
are good entrypoints for cleanup.) https://github.com/neondatabase/neon/issues/11809
# Appendix
## Why Kernel Page Cache Is Ineffective At Tenant High Density
In the Motivation section, we stated:
> - **The kernel page cache ineffective** at high tenant density anyways (#tenants/pageserver instance).
The reason is that the Pageserver workload sent from Computes is whatever is a Compute cache(s) miss.
That's either sequential scans or random reads.
A random read workload simply causes cache thrashing because a packed Pageserver NVMe drive (`im4gn.2xlarge`) has ~100x more capacity than DRAM available.
It is complete waste to have the kernel page cache cache data blocks in this case.
Sequential read workloads *can* benefit iff those pages have been updated recently (=no image layer yet) and together in time/LSN space.
In such cases, the WAL records of those updates likely sit on the same delta layer block.
When Compute does a sequential scan, it sends a series of single-page requests for these individual pages.
When Pageserver processes the second request in such a series, it goes to the same delta layer block and have a kernel page cache hit.
This dependence on kernel page cache for sequential scan performance is significant, but the solution is at a higher level than generic data block caching.
We can either add a small per-connection LRU cache for such delta layer blocks.
Or we can merge those sequential requests into a larger vectored get request, which is designed to never read a block twice.
This amortizes the read latency for our delta layer block across the vectored get batch size (which currently is up to 32).
There are Pageserver-internal workloads that do sequential access (compaction, image layer generation), but these
1. are not latency-critical and can do batched access outside of the `page_service` protocol constraints (image layer generation)
2. don't actually need to reconstruct images and therefore can use totally different access methods (=> compaction can use k-way merge iterators with their own internal buffering / prefetching).

View File

@@ -1,251 +0,0 @@
# Concurrent IO for Pageserver Read Path
Date: May 6, 2025
## Summary
This document is a retroactive RFC on the Pageserver Concurrent IO work that happened in late 2024 / early 2025.
The gist of it is that Pageserver's `Timeline::get_vectored` now _issues_ the data block read operations against layer files
_as it traverses the layer map_ and only _wait_ once, for all of them, after traversal is complete.
Assuming a good PS PageCache hits on the index blocks during traversal, this drives down the "wait-for-disk" time
contribution down from `random_read_io_latency * O(number_of_values)` to `random_read_io_latency * O(1 + traversal)`.
The motivation for why this work had to happen when it happened was the switch of Pageserver to
- not cache user data blocks in PS PageCache and
- switch to use direct IO.
More context on this are given in complimentary RFC `./rfcs/2025-04-30-direct-io-for-pageserver.md`.
### Refs
- Epic: https://github.com/neondatabase/neon/issues/9378
- Prototyping happened during the Lisbon 2024 Offsite hackathon: https://github.com/neondatabase/neon/pull/9002
- Main implementation PR with good description: https://github.com/neondatabase/neon/issues/9378
Design and implementation by:
- Vlad Lazar <vlad@neon.tech>
- Christian Schwarz <christian@neon.tech>
## Background & Motivation
The Pageserver read path (`Timeline::get_vectored`) consists of two high-level steps:
- Retrieve the delta and image `Value`s required to reconstruct the requested Page@LSN (`Timeline::get_values_reconstruct_data`).
- Pass these values to walredo to reconstruct the page images.
The read path used to be single-key but has been made multi-key some time ago.
([Internal tech talk by Vlad](https://drive.google.com/file/d/1vfY24S869UP8lEUUDHRWKF1AJn8fpWoJ/view?usp=drive_link))
However, for simplicity, most of this doc will explain things in terms of a single key being requested.
The `Value` retrieval step above can be broken down into the following functions:
- **Traversal** of the layer map to figure out which `Value`s from which layer files are required for the page reconstruction.
- **Read IO Planning**: planning of the read IOs that need to be issued to the layer files / filesystem / disk.
The main job here is to coalesce the small value reads into larger filesystem-level read operations.
This layer also takes care of direct IO alignment and size-multiple requirements (cf the RFC for details.)
Check `struct VectoredReadPlanner` and `mod vectored_dio_read` for how it's done.
- **Perform the read IO** using `tokio-epoll-uring`.
Before this project, above functions were sequentially interleaved, meaning:
1. we would advance traversal, ...
2. discover, that we need to read a value, ...
3. read it from disk using `tokio-epoll-uring`, ...
4. goto 1 unless we're done.
This meant that if N `Value`s need to be read to reconstruct a page,
the time we spend waiting for disk will be we `random_read_io_latency * O(number_of_values)`.
## Design
The **traversal** and **read IO Planning** jobs still happen sequentially, layer by layer, as before.
But instead of performing the read IOs inline, we submit the IOs to a concurrent tokio task for execution.
After the last read from the last layer is submitted, we wait for the IOs to complete.
Assuming the filesystem / disk is able to actually process the submitted IOs without queuing,
we arrive at _time spent waiting for disk_ ~ `random_read_io_latency * O(1 + traversal)`.
Note this whole RFC is concerned with the steady state where all layer files required for reconstruction are resident on local NVMe.
Traversal will stall on on-demand layer download if a layer is not yet resident.
It cannot proceed without the layer being resident beccause its next step depends on the contents of the layer index.
### Avoiding Waiting For IO During Traversal
The `traversal` component in above time-spent-waiting-for-disk estimation is dominant and needs to be minimized.
Before this project, traversal needed to perform IOs for the following:
1. The time we are waiting on PS PageCache to page in the visited layers' disk btree index blocks.
2. When visiting a delta layer, reading the data block that contains a `Value` for a requested key,
to determine whether the `Value::will_init` the page and therefore traversal can stop for this key.
The solution for (1) is to raise the PS PageCache size such that the hit rate is practically 100%.
(Check out the `Background: History Of Caching In Pageserver` section in the RFC on Direct IO for more details.)
The solution for (2) is source `will_init` from the disk btree index keys, which fortunately
already encode this bit of information since the introduction of the current storage/layer format.
### Concurrent IOs, Submission & Completion
To separate IO submission from waiting for its completion,
we introduce the notion of an `IoConcurrency` struct through which IOs are issued.
An IO is an opaque future that
- captures the `tx` side of a `oneshot` channel
- performs the read IO by calling `VirtualFile::read_exact_at().await`
- sending the result into the `tx`
Issuing an IO means `Box`ing the future above and handing that `Box` over to the `IoConcurrency` struct.
The traversal code that submits the IO stores the the corresponding `oneshot::Receiver`
in the `VectoredValueReconstructState`, in the the place where we previously stored
the sequentially read `img` and `records` fields.
When we're done with traversal, we wait for all submitted IOs:
for each key, there is a future that awaits all the `oneshot::Receiver`s
for that key, and then calls into walredo to reconstruct the page image.
Walredo is now invoked concurrently for each value instead of sequentially.
Walredo itself remains unchanged.
The spawned IO futures are driven to completion by a sidecar tokio task that
is separate from the task that performs all the layer visiting and spawning of IOs.
That tasks receives the IO futures via an unbounded mpsc channel and
drives them to completion inside a `FuturedUnordered`.
### Error handling, Panics, Cancellation-Safety
There are two error classes during reconstruct data retrieval:
* traversal errors: index lookup, move to next layer, and the like
* value read IO errors
A traversal error fails the entire `get_vectored` request, as before this PR.
A value read error only fails reconstruction of that value.
Panics and dropping of the `get_vectored` future before it completes
leaves the sidecar task running and does not cancel submitted IOs
(see next section for details on sidecar task lifecycle).
All of this is safe, but, today's preference in the team is to close out
all resource usage explicitly if possible, rather than cancelling + forgetting
about it on drop. So, there is warning if we drop a
`VectoredValueReconstructState`/`ValuesReconstructState` that still has uncompleted IOs.
### Sidecar Task Lifecycle
The sidecar tokio task is spawned as part of the `IoConcurrency::spawn_from_conf` struct.
The `IoConcurrency` object acts as a handle through which IO futures are submitted.
The spawned tokio task holds the `Timeline::gate` open.
It is _not_ sensitive to `Timeline::cancel`, but instead to the `IoConcurrency` object being dropped.
Once the `IoConcurrency` struct is dropped, no new IO futures can come in
but already submitted IO futures will be driven to completion regardless.
We _could_ safely stop polling these futures because `tokio-epoll-uring` op futures are cancel-safe.
But the underlying kernel and hardware resources are not magically freed up by that.
So, again, in the interest of closing out all outstanding resource usage, we make timeline shutdown wait for sidecar tasks and their IOs to complete.
Under normal conditions, this should be in the low hundreds of microseconds.
It is advisable to make the `IoConcurrency` as long-lived as possible to minimize the amount of
tokio task churn (=> lower pressure on tokio). Generally this means creating it "high up" in the call stack.
The pain with this is that the `IoConcurrency` reference needs to be propagated "down" to
the (short-lived) functions/scope where we issue the IOs.
We would like to use `RequestContext` for this propagation in the future (issue [here](https://github.com/neondatabase/neon/issues/10460)).
For now, we just add another argument to the relevant code paths.
### Feature Gating
The `IoConcurrency` is an `enum` with two variants: `Sequential` and `SidecarTask`.
The behavior from before this project is available through `IoConcurrency::Sequential`,
which awaits the IO futures in place, without "spawning" or "submitting" them anywhere.
The `get_vectored_concurrent_io` pageserver config variable determines the runtime value,
**except** for the places that use `IoConcurrency::sequential` to get an `IoConcurrency` object.
### Alternatives Explored & Caveats Encountered
A few words on the rationale behind having a sidecar *task* and what
alternatives were considered but abandoned.
#### Why We Need A Sidecar *Task* / Why Just `FuturesUnordered` Doesn't Work
We explored to not have a sidecar task, and instead have a `FuturesUnordered` per
`Timeline::get_vectored`. We would queue all IO futures in it and poll it for the
first time after traversal is complete (i.e., at `collect_pending_ios`).
The obvious disadvantage, but not showstopper, is that we wouldn't be submitting
IOs until traversal is complete.
The showstopper however, is that deadlocks happen if we don't drive the
IO futures to completion independently of the traversal task.
The reason is that both the IO futures and the traversal task may hold _some_,
_and_ try to acquire _more_, shared limited resources.
For example, both the travseral task and IO future may try to acquire
* a `VirtualFile` file descriptor cache slot async mutex (observed during impl)
* a `tokio-epoll-uring` submission slot (observed during impl)
* a `PageCache` slot (currently this is not the case but we may move more code into the IO futures in the future)
#### Why We Don't Do `tokio::task`-per-IO-future
Another option is to spawn a short-lived `tokio::task` for each IO future.
We implemented and benchmarked it during development, but found little
throughput improvement and moderate mean & tail latency degradation.
Concerns about pressure on the tokio scheduler led us to abandon this variant.
## Future Work
In addition to what is listed here, also check the "Punted" list in the epic:
https://github.com/neondatabase/neon/issues/9378
### Enable `Timeline::get`
The only major code path that still uses `IoConcurrency::sequential` is `Timeline::get`.
The impact is that roughly the following parts of pageserver do not benefit yet:
- parts of basebackup
- reads performed by the ingest path
- most internal operations that read metadata keys (e.g. `collect_keyspace`!)
The solution is to propagate `IoConcurrency` via `RequestContext`:https://github.com/neondatabase/neon/issues/10460
The tricky part is to figure out at which level of the code the `IoConcurrency` is spawned (and added to the RequestContext).
Also, propagation via `RequestContext` makes makes it harder to tell during development whether a given
piece of code uses concurrent vs sequential mode: one has to recurisvely walk up the call tree to find the
place that puts the `IoConcurrency` into the `RequestContext`.
We'd have to use `::Sequential` as the conservative default value in a fresh `RequestContext`, and add some
observability to weed out places that fail to enrich with a properly spanwed `IoConcurrency::spawn_from_conf`.
### Concurrent On-Demand Downloads enabled by Detached Indices
As stated earlier, traversal stalls on on-demand download because its next step depends on the contents of the layer index.
Once we have separated indices from data blocks (=> https://github.com/neondatabase/neon/issues/11695)
we will only need to stall if the index is not resident. The download of the data blocks can happen concurrently or in the background. For example:
- Move the `Layer::get_or_maybe_download().await` inside the IO futures.
This goes in the opposite direction of the next "future work" item below, but it's easy to do.
- Serve the IO future directly from object storage and dispatch the layer download
to some other actor, e.g., an actor that is responsible for both downloads & eviction.
### New `tokio-epoll-uring` API That Separates Submission & Wait-For-Completion
Instead of `$op().await` style API, it would be useful to have a different `tokio-epoll-uring` API
that separates enqueuing (without necessarily `io_uring_enter`ing the kernel each time), submission,
and then wait for completion.
The `$op().await` API is too opaque, so we _have_ to stuff it into a `FuturesUnordered`.
A split API as sketched above would allow traversal to ensure an IO operation is enqueued to the kernel/disk (and get back-pressure iff the io_uring squeue is full).
While avoiding spending of CPU cycles on processing of completions while we're still traversing.
The idea gets muddied by the fact that we may self-deadlock if we submit too much without completing.
So, the submission part of the split API needs to process completions if squeue is full.
In any way, this split API is precondition for the bigger issue with the design presented here,
which we dicsuss in the next section.
### Opaque Futures Are Brittle
The use of opaque futures to represent submitted IOs is a clever hack to minimize changes & allow for near-perfect feature-gating.
However, we take on **brittleness** because callers must guarantee that the submitted futures are independent.
By our experience, it is non-trivial to identify or rule out the interdependencies.
See the lengthy doc comment on the `IoConcurrency::spawn_io` method for more details.
The better interface and proper subsystem boundary is a _descriptive_ struct of what needs to be done ("read this range from this VirtualFile into this buffer")
and get back a means to wait for completion.
The subsystem can thereby reason by its own how operations may be related;
unlike today, where the submitted opaque future can do just about anything.

View File

@@ -343,7 +343,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
TimelineId::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";
fn token() -> String {
let claims = endpoint_storage::claims::EndpointStorageClaims {
let claims = endpoint_storage::Claims {
tenant_id: TENANT_ID,
timeline_id: TIMELINE_ID,
endpoint_id: ENDPOINT_ID.into(),
@@ -489,8 +489,16 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
}
fn delete_prefix_token(uri: &str) -> String {
use serde::Serialize;
let parts = uri.split("/").collect::<Vec<&str>>();
let claims = endpoint_storage::claims::DeletePrefixClaims {
#[derive(Serialize)]
struct PrefixClaims {
tenant_id: TenantId,
timeline_id: Option<TimelineId>,
endpoint_id: Option<endpoint_storage::EndpointId>,
exp: u64,
}
let claims = PrefixClaims {
tenant_id: parts.get(1).map(|c| c.parse().unwrap()).unwrap(),
timeline_id: parts.get(2).map(|c| c.parse().unwrap()),
endpoint_id: parts.get(3).map(ToString::to_string),

View File

@@ -1,52 +0,0 @@
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use utils::id::{EndpointId, TenantId, TimelineId};
/// Claims to add, remove, or retrieve endpoint data. Used by compute_ctl
#[derive(Deserialize, Serialize, PartialEq)]
pub struct EndpointStorageClaims {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub endpoint_id: EndpointId,
pub exp: u64,
}
/// Claims to remove tenant, timeline, or endpoint data. Used by control plane
#[derive(Deserialize, Serialize, PartialEq)]
pub struct DeletePrefixClaims {
pub tenant_id: TenantId,
/// None when tenant is deleted (endpoint_id is also None in this case)
pub timeline_id: Option<TimelineId>,
/// None when timeline is deleted
pub endpoint_id: Option<EndpointId>,
pub exp: u64,
}
impl Display for EndpointStorageClaims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"EndpointClaims(tenant_id={} timeline_id={} endpoint_id={} exp={})",
self.tenant_id, self.timeline_id, self.endpoint_id, self.exp
)
}
}
impl Display for DeletePrefixClaims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"DeletePrefixClaims(tenant_id={} timeline_id={} endpoint_id={}, exp={})",
self.tenant_id,
self.timeline_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string()),
self.endpoint_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string()),
self.exp
)
}
}

View File

@@ -1,5 +1,3 @@
pub mod claims;
use crate::claims::{DeletePrefixClaims, EndpointStorageClaims};
use anyhow::Result;
use axum::extract::{FromRequestParts, Path};
use axum::response::{IntoResponse, Response};
@@ -15,7 +13,7 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use utils::id::{EndpointId, TenantId, TimelineId};
use utils::id::{TenantId, TimelineId};
// simplified version of utils::auth::JwtAuth
pub struct JwtAuth {
@@ -81,6 +79,26 @@ pub struct Storage {
pub max_upload_file_limit: usize,
}
pub type EndpointId = String; // If needed, reuse small string from proxy/src/types.rc
#[derive(Deserialize, Serialize, PartialEq)]
pub struct Claims {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub endpoint_id: EndpointId,
pub exp: u64,
}
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Claims(tenant_id {} timeline_id {} endpoint_id {} exp {})",
self.tenant_id, self.timeline_id, self.endpoint_id, self.exp
)
}
}
#[derive(Deserialize, Serialize)]
struct KeyRequest {
tenant_id: TenantId,
@@ -89,13 +107,6 @@ struct KeyRequest {
path: String,
}
#[derive(Deserialize, Serialize, PartialEq)]
struct PrefixKeyRequest {
tenant_id: TenantId,
timeline_id: Option<TimelineId>,
endpoint_id: Option<EndpointId>,
}
#[derive(Debug, PartialEq)]
pub struct S3Path {
pub path: RemotePath,
@@ -154,7 +165,7 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| bad_request(e, "invalid token"))?;
let claims: EndpointStorageClaims = state
let claims: Claims = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "decoding token"))?;
@@ -167,7 +178,7 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
path.endpoint_id.clone()
};
let route = EndpointStorageClaims {
let route = Claims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,
endpoint_id,
@@ -182,13 +193,38 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
}
}
#[derive(Deserialize, Serialize, PartialEq)]
pub struct PrefixKeyPath {
pub tenant_id: TenantId,
pub timeline_id: Option<TimelineId>,
pub endpoint_id: Option<EndpointId>,
}
impl Display for PrefixKeyPath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PrefixKeyPath(tenant_id {} timeline_id {} endpoint_id {})",
self.tenant_id,
self.timeline_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string()),
self.endpoint_id
.as_ref()
.map(ToString::to_string)
.unwrap_or("".to_string())
)
}
}
#[derive(Debug, PartialEq)]
pub struct PrefixS3Path {
pub path: RemotePath,
}
impl From<&DeletePrefixClaims> for PrefixS3Path {
fn from(path: &DeletePrefixClaims) -> Self {
impl From<&PrefixKeyPath> for PrefixS3Path {
fn from(path: &PrefixKeyPath) -> Self {
let timeline_id = path
.timeline_id
.as_ref()
@@ -214,27 +250,21 @@ impl FromRequestParts<Arc<Storage>> for PrefixS3Path {
state: &Arc<Storage>,
) -> Result<Self, Self::Rejection> {
let Path(path) = parts
.extract::<Path<PrefixKeyRequest>>()
.extract::<Path<PrefixKeyPath>>()
.await
.map_err(|e| bad_request(e, "invalid route"))?;
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|e| bad_request(e, "invalid token"))?;
let claims: DeletePrefixClaims = state
let claims: PrefixKeyPath = state
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "invalid token"))?;
let route = DeletePrefixClaims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,
endpoint_id: path.endpoint_id,
exp: claims.exp,
};
if route != claims {
return Err(unauthorized(route, claims));
if path != claims {
return Err(unauthorized(path, claims));
}
Ok((&route).into())
Ok((&path).into())
}
}
@@ -267,7 +297,7 @@ mod tests {
#[test]
fn s3_path() {
let auth = EndpointStorageClaims {
let auth = Claims {
tenant_id: TENANT_ID,
timeline_id: TIMELINE_ID,
endpoint_id: ENDPOINT_ID.into(),
@@ -297,11 +327,10 @@ mod tests {
#[test]
fn prefix_s3_path() {
let mut path = DeletePrefixClaims {
let mut path = PrefixKeyPath {
tenant_id: TENANT_ID,
timeline_id: None,
endpoint_id: None,
exp: 0,
};
let prefix_path = |s: String| RemotePath::from_string(&s).unwrap();
assert_eq!(

View File

@@ -1,58 +1,16 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
use std::str::FromStr;
use serde::{Deserialize, Serialize};
use crate::privilege::Privilege;
use crate::responses::ComputeCtlConfig;
use crate::spec::{ComputeSpec, ExtVersion, PgIdent};
/// The value to place in the [`ComputeClaims::audience`] claim.
pub static COMPUTE_AUDIENCE: &str = "compute";
/// Available scopes for a compute's JWT.
#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ComputeClaimsScope {
/// An admin-scoped token allows access to all of `compute_ctl`'s authorized
/// facilities.
Admin,
}
impl FromStr for ComputeClaimsScope {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"admin" => Ok(ComputeClaimsScope::Admin),
_ => Err(anyhow::anyhow!("invalid compute claims scope \"{s}\"")),
}
}
}
/// When making requests to the `compute_ctl` external HTTP server, the client
/// must specify a set of claims in `Authorization` header JWTs such that
/// `compute_ctl` can authorize the request.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename = "snake_case")]
pub struct ComputeClaims {
/// The compute ID that will validate the token. The only case in which this
/// can be [`None`] is if [`Self::scope`] is
/// [`ComputeClaimsScope::Admin`].
pub compute_id: Option<String>,
/// The scope of what the token authorizes.
pub scope: Option<ComputeClaimsScope>,
/// The recipient the token is intended for.
///
/// See [RFC 7519](https://www.rfc-editor.org/rfc/rfc7519#section-4.1.3) for
/// more information.
///
/// TODO: Remove the [`Option`] wrapper when control plane learns to send
/// the claim.
#[serde(rename = "aud")]
pub audience: Option<Vec<String>>,
pub compute_id: String,
}
/// Request of the /configure API

View File

@@ -46,30 +46,6 @@ pub struct ExtensionInstallResponse {
pub version: ExtVersion,
}
#[derive(Serialize, Default, Debug, Clone)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcPrewarmState {
#[default]
NotPrewarmed,
Prewarming,
Completed,
Failed {
error: String,
},
}
#[derive(Serialize, Default, Debug, Clone)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcOffloadState {
#[default]
NotOffloaded,
Offloading,
Completed,
Failed {
error: String,
},
}
/// Response of the /status API
#[derive(Serialize, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]

View File

@@ -172,15 +172,6 @@ pub struct ComputeSpec {
/// Hostname and the port of the otel collector. Leave empty to disable Postgres logs forwarding.
/// Example: config-shy-breeze-123-collector-monitoring.neon-telemetry.svc.cluster.local:10514
pub logs_export_host: Option<String>,
/// Address of endpoint storage service
pub endpoint_storage_addr: Option<String>,
/// JWT for authorizing requests to endpoint storage service
pub endpoint_storage_token: Option<String>,
/// If true, download LFC state from endpoint_storage and pass it to Postgres on startup
#[serde(default)]
pub prewarm_lfc_on_startup: bool,
}
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.

View File

@@ -84,11 +84,6 @@
"value": "on",
"vartype": "bool"
},
{
"name": "prewarm_lfc_on_startup",
"value": "off",
"vartype": "bool"
},
{
"name": "neon.safekeepers",
"value": "127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501",

View File

@@ -16,7 +16,6 @@ pub struct Collector {
const NMETRICS: usize = 2;
static CLK_TCK_F64: Lazy<f64> = Lazy::new(|| {
// SAFETY: libc::sysconf is safe, it merely returns a value.
let long = unsafe { libc::sysconf(libc::_SC_CLK_TCK) };
if long == -1 {
panic!("sysconf(_SC_CLK_TCK) failed");

View File

@@ -182,7 +182,6 @@ pub struct ConfigToml {
pub tracing: Option<Tracing>,
pub enable_tls_page_service_api: bool,
pub dev_mode: bool,
pub timeline_import_config: TimelineImportConfig,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -301,12 +300,6 @@ impl From<OtelExporterProtocol> for tracing_utils::Protocol {
}
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TimelineImportConfig {
pub import_job_concurrency: NonZeroUsize,
pub import_job_soft_size_limit: NonZeroUsize,
}
pub mod statvfs {
pub mod mock {
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -666,10 +659,6 @@ impl Default for ConfigToml {
tracing: None,
enable_tls_page_service_api: false,
dev_mode: false,
timeline_import_config: TimelineImportConfig {
import_job_concurrency: NonZeroUsize::new(128).unwrap(),
import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(),
},
}
}
}

View File

@@ -1832,7 +1832,6 @@ pub mod virtual_file {
Eq,
Hash,
strum_macros::EnumString,
strum_macros::EnumIter,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
@@ -1844,8 +1843,10 @@ pub mod virtual_file {
/// Uses buffered IO.
Buffered,
/// Uses direct IO for reads only.
#[cfg(target_os = "linux")]
Direct,
/// Use direct IO for reads and writes.
#[cfg(target_os = "linux")]
DirectRw,
}
@@ -1853,13 +1854,26 @@ pub mod virtual_file {
pub fn preferred() -> Self {
// The default behavior when running Rust unit tests without any further
// flags is to use the newest behavior (DirectRw).
// The CI uses the environment variable to unit tests for all different modes.
// The CI uses the following environment variable to unit tests for all
// different modes.
// NB: the Python regression & perf tests have their own defaults management
// that writes pageserver.toml; they do not use this variable.
static ENV_OVERRIDE: LazyLock<Option<IoMode>> = LazyLock::new(|| {
utils::env::var_serde_json_string("NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IO_MODE")
});
ENV_OVERRIDE.unwrap_or(IoMode::DirectRw)
if cfg!(test) {
static CACHED: LazyLock<IoMode> = LazyLock::new(|| {
utils::env::var_serde_json_string(
"NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IO_MODE",
)
.unwrap_or(
#[cfg(target_os = "linux")]
IoMode::DirectRw,
#[cfg(not(target_os = "linux"))]
IoMode::Buffered,
)
});
*CACHED
} else {
IoMode::Buffered
}
}
}
@@ -1869,7 +1883,9 @@ pub mod virtual_file {
fn try_from(value: u8) -> Result<Self, Self::Error> {
Ok(match value {
v if v == (IoMode::Buffered as u8) => IoMode::Buffered,
#[cfg(target_os = "linux")]
v if v == (IoMode::Direct as u8) => IoMode::Direct,
#[cfg(target_os = "linux")]
v if v == (IoMode::DirectRw as u8) => IoMode::DirectRw,
x => return Err(x),
})

View File

@@ -299,7 +299,6 @@ pub struct PullTimelineRequest {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub http_hosts: Vec<String>,
pub ignore_tombstone: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]

View File

@@ -295,9 +295,6 @@ pub struct TenantId(Id);
id_newtype!(TenantId);
/// If needed, reuse small string from proxy/src/types.rc
pub type EndpointId = String;
// A pair uniquely identifying Neon instance.
#[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TenantTimelineId {

View File

@@ -17,7 +17,7 @@ impl std::fmt::Display for RateLimitStats {
}
impl RateLimit {
pub const fn new(interval: Duration) -> Self {
pub fn new(interval: Duration) -> Self {
Self {
last: None,
interval,

View File

@@ -14,7 +14,6 @@ use pageserver_api::key::Key;
use pageserver_api::models::virtual_file::IoMode;
use pageserver_api::shard::TenantShardId;
use pageserver_api::value::Value;
use strum::IntoEnumIterator;
use tokio_util::sync::CancellationToken;
use utils::bin_ser::BeSer;
use utils::id::{TenantId, TimelineId};
@@ -245,7 +244,13 @@ fn criterion_benchmark(c: &mut Criterion) {
];
let exploded_parameters = {
let mut out = Vec::new();
for io_mode in IoMode::iter() {
for io_mode in [
IoMode::Buffered,
#[cfg(target_os = "linux")]
IoMode::Direct,
#[cfg(target_os = "linux")]
IoMode::DirectRw,
] {
for param in expect.clone() {
let HandPickedParameters {
volume_mib,

View File

@@ -230,8 +230,6 @@ pub struct PageServerConf {
/// such as authentication requirements for HTTP and PostgreSQL APIs.
/// This is insecure and should only be used in development environments.
pub dev_mode: bool,
pub timeline_import_config: pageserver_api::config::TimelineImportConfig,
}
/// Token for authentication to safekeepers
@@ -406,7 +404,6 @@ impl PageServerConf {
tracing,
enable_tls_page_service_api,
dev_mode,
timeline_import_config,
} = config_toml;
let mut conf = PageServerConf {
@@ -460,7 +457,6 @@ impl PageServerConf {
tracing,
enable_tls_page_service_api,
dev_mode,
timeline_import_config,
// ------------------------------------------------------------
// fields that require additional validation or custom handling

View File

@@ -1038,23 +1038,21 @@ impl PageServerHandler {
tracing::info_span!(
parent: &parent_span,
"handle_get_page_request",
request_id = %req.hdr.reqid,
rel = %req.rel,
blkno = %req.blkno,
req_lsn = %req.hdr.request_lsn,
not_modified_since_lsn = %req.hdr.not_modified_since,
not_modified_since_lsn = %req.hdr.not_modified_since
)
}};
($shard_id:expr) => {{
tracing::info_span!(
parent: &parent_span,
"handle_get_page_request",
request_id = %req.hdr.reqid,
rel = %req.rel,
blkno = %req.blkno,
req_lsn = %req.hdr.request_lsn,
not_modified_since_lsn = %req.hdr.not_modified_since,
shard_id = %$shard_id,
shard_id = %$shard_id
)
}};
}

View File

@@ -40,7 +40,7 @@ use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta};
use super::tenant::{PageReconstructError, Timeline};
use crate::aux_file;
use crate::context::{PerfInstrumentFutureExt, RequestContext, RequestContextBuilder};
use crate::context::{PerfInstrumentFutureExt, RequestContext};
use crate::keyspace::{KeySpace, KeySpaceAccum};
use crate::metrics::{
RELSIZE_CACHE_ENTRIES, RELSIZE_CACHE_HITS, RELSIZE_CACHE_MISSES, RELSIZE_CACHE_MISSES_OLD,
@@ -275,30 +275,24 @@ impl Timeline {
continue;
}
let nblocks = {
let ctx = RequestContextBuilder::from(&ctx)
.perf_span(|crnt_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: crnt_perf_span,
"GET_REL_SIZE",
reltag=%tag,
lsn=%lsn,
)
})
.attached_child();
match self
.get_rel_size(*tag, Version::Lsn(lsn), &ctx)
.maybe_perf_instrument(&ctx, |crnt_perf_span| crnt_perf_span.clone())
.await
{
Ok(nblocks) => nblocks,
Err(err) => {
result_slots[response_slot_idx].write(Err(err));
slots_filled += 1;
continue;
}
let nblocks = match self
.get_rel_size(*tag, Version::Lsn(lsn), &ctx)
.maybe_perf_instrument(&ctx, |crnt_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: crnt_perf_span,
"GET_REL_SIZE",
reltag=%tag,
lsn=%lsn,
)
})
.await
{
Ok(nblocks) => nblocks,
Err(err) => {
result_slots[response_slot_idx].write(Err(err));
slots_filled += 1;
continue;
}
};
@@ -314,17 +308,6 @@ impl Timeline {
let key = rel_block_to_key(*tag, *blknum);
let ctx = RequestContextBuilder::from(&ctx)
.perf_span(|crnt_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: crnt_perf_span,
"GET_BATCH",
batch_size = %page_count,
)
})
.attached_child();
let key_slots = keys_slots.entry(key).or_default();
key_slots.push((response_slot_idx, ctx));
@@ -340,7 +323,14 @@ impl Timeline {
let query = VersionedKeySpaceQuery::scattered(query);
let res = self
.get_vectored(query, io_concurrency, ctx)
.maybe_perf_instrument(ctx, |current_perf_span| current_perf_span.clone())
.maybe_perf_instrument(ctx, |current_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: current_perf_span,
"GET_BATCH",
batch_size = %page_count,
)
})
.await;
match res {

View File

@@ -94,23 +94,10 @@ impl Header {
pub enum WriteBlobError {
#[error(transparent)]
Flush(FlushTaskError),
#[error("blob too large ({len} bytes)")]
BlobTooLarge { len: usize },
#[error(transparent)]
Other(anyhow::Error),
}
impl WriteBlobError {
pub fn is_cancel(&self) -> bool {
match self {
WriteBlobError::Flush(e) => e.is_cancel(),
WriteBlobError::Other(_) => false,
}
}
pub fn into_anyhow(self) -> anyhow::Error {
match self {
WriteBlobError::Flush(e) => e.into_anyhow(),
WriteBlobError::Other(e) => e,
}
}
WriteBlobRaw(anyhow::Error),
}
impl BlockCursor<'_> {
@@ -340,9 +327,7 @@ where
return (
(
io_buf.slice_len(),
Err(WriteBlobError::Other(anyhow::anyhow!(
"blob too large ({len} bytes)"
))),
Err(WriteBlobError::BlobTooLarge { len }),
),
srcbuf,
);
@@ -406,7 +391,7 @@ where
// Verify the header, to ensure we don't write invalid/corrupt data.
let header = match Header::decode(&raw_with_header)
.context("decoding blob header")
.map_err(WriteBlobError::Other)
.map_err(WriteBlobError::WriteBlobRaw)
{
Ok(header) => header,
Err(err) => return (raw_with_header, Err(err)),
@@ -416,7 +401,7 @@ where
let raw_len = raw_with_header.len();
return (
raw_with_header,
Err(WriteBlobError::Other(anyhow::anyhow!(
Err(WriteBlobError::WriteBlobRaw(anyhow::anyhow!(
"header length mismatch: {header_total_len} != {raw_len}"
))),
);

View File

@@ -2,7 +2,6 @@
pub mod batch_split_writer;
pub mod delta_layer;
pub mod errors;
pub mod filter_iterator;
pub mod image_layer;
pub mod inmemory_layer;

View File

@@ -10,7 +10,6 @@ use utils::id::TimelineId;
use utils::lsn::Lsn;
use utils::shard::TenantShardId;
use super::errors::PutError;
use super::layer::S3_UPLOAD_LIMIT;
use super::{
DeltaLayerWriter, ImageLayerWriter, PersistentLayerDesc, PersistentLayerKey, ResidentLayer,
@@ -236,7 +235,7 @@ impl<'a> SplitImageLayerWriter<'a> {
key: Key,
img: Bytes,
ctx: &RequestContext,
) -> Result<(), PutError> {
) -> anyhow::Result<()> {
// The current estimation is an upper bound of the space that the key/image could take
// because we did not consider compression in this estimation. The resulting image layer
// could be smaller than the target size.
@@ -254,8 +253,7 @@ impl<'a> SplitImageLayerWriter<'a> {
self.cancel.clone(),
ctx,
)
.await
.map_err(PutError::Other)?;
.await?;
let prev_image_writer = std::mem::replace(&mut self.inner, next_image_writer);
self.batches.add_unfinished_image_writer(
prev_image_writer,
@@ -348,7 +346,7 @@ impl<'a> SplitDeltaLayerWriter<'a> {
lsn: Lsn,
val: Value,
ctx: &RequestContext,
) -> Result<(), PutError> {
) -> anyhow::Result<()> {
// The current estimation is key size plus LSN size plus value size estimation. This is not an accurate
// number, and therefore the final layer size could be a little bit larger or smaller than the target.
//
@@ -368,8 +366,7 @@ impl<'a> SplitDeltaLayerWriter<'a> {
self.cancel.clone(),
ctx,
)
.await
.map_err(PutError::Other)?,
.await?,
));
}
let (_, inner) = self.inner.as_mut().unwrap();
@@ -389,8 +386,7 @@ impl<'a> SplitDeltaLayerWriter<'a> {
self.cancel.clone(),
ctx,
)
.await
.map_err(PutError::Other)?;
.await?;
let (start_key, prev_delta_writer) =
self.inner.replace((key, next_delta_writer)).unwrap();
self.batches.add_unfinished_delta_writer(
@@ -400,11 +396,11 @@ impl<'a> SplitDeltaLayerWriter<'a> {
);
} else if inner.estimated_size() >= S3_UPLOAD_LIMIT {
// We have to produce a very large file b/c a key is updated too often.
return Err(PutError::Other(anyhow::anyhow!(
anyhow::bail!(
"a single key is updated too often: key={}, estimated_size={}, and the layer file cannot be produced",
key,
inner.estimated_size()
)));
);
}
}
self.last_key_written = key;

View File

@@ -55,7 +55,6 @@ use utils::bin_ser::SerializeError;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use super::errors::PutError;
use super::{
AsLayerDesc, LayerName, OnDiskValue, OnDiskValueIo, PersistentLayerDesc, ResidentLayer,
ValuesReconstructState,
@@ -478,15 +477,12 @@ impl DeltaLayerWriterInner {
lsn: Lsn,
val: Value,
ctx: &RequestContext,
) -> Result<(), PutError> {
) -> anyhow::Result<()> {
let (_, res) = self
.put_value_bytes(
key,
lsn,
Value::ser(&val)
.map_err(anyhow::Error::new)
.map_err(PutError::Other)?
.slice_len(),
Value::ser(&val)?.slice_len(),
val.will_init(),
ctx,
)
@@ -501,7 +497,7 @@ impl DeltaLayerWriterInner {
val: FullSlice<Buf>,
will_init: bool,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), PutError>)
) -> (FullSlice<Buf>, anyhow::Result<()>)
where
Buf: IoBuf + Send,
{
@@ -517,24 +513,19 @@ impl DeltaLayerWriterInner {
.blob_writer
.write_blob_maybe_compressed(val, ctx, compression)
.await;
let res = res.map_err(PutError::WriteBlob);
let off = match res {
Ok((off, _)) => off,
Err(e) => return (val, Err(e)),
Err(e) => return (val, Err(anyhow::anyhow!(e))),
};
let blob_ref = BlobRef::new(off, will_init);
let delta_key = DeltaKey::from_key_lsn(&key, lsn);
let res = self
.tree
.append(&delta_key.0, blob_ref.0)
.map_err(anyhow::Error::new)
.map_err(PutError::Other);
let res = self.tree.append(&delta_key.0, blob_ref.0);
self.num_keys += 1;
(val, res)
(val, res.map_err(|e| anyhow::anyhow!(e)))
}
fn size(&self) -> u64 {
@@ -703,7 +694,7 @@ impl DeltaLayerWriter {
lsn: Lsn,
val: Value,
ctx: &RequestContext,
) -> Result<(), PutError> {
) -> anyhow::Result<()> {
self.inner
.as_mut()
.unwrap()
@@ -718,7 +709,7 @@ impl DeltaLayerWriter {
val: FullSlice<Buf>,
will_init: bool,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), PutError>)
) -> (FullSlice<Buf>, anyhow::Result<()>)
where
Buf: IoBuf + Send,
{
@@ -1450,6 +1441,14 @@ impl DeltaLayerInner {
offset
}
pub fn iter<'a>(&'a self, ctx: &'a RequestContext) -> DeltaLayerIterator<'a> {
self.iter_with_options(
ctx,
1024 * 8192, // The default value. Unit tests might use a different value. 1024 * 8K = 8MB buffer.
1024, // The default value. Unit tests might use a different value
)
}
pub fn iter_with_options<'a>(
&'a self,
ctx: &'a RequestContext,
@@ -1635,6 +1634,7 @@ pub(crate) mod test {
use crate::tenant::disk_btree::tests::TestDisk;
use crate::tenant::harness::{TIMELINE_ID, TenantHarness};
use crate::tenant::storage_layer::{Layer, ResidentLayer};
use crate::tenant::vectored_blob_io::StreamingVectoredReadPlanner;
use crate::tenant::{TenantShard, Timeline};
/// Construct an index for a fictional delta layer and and then
@@ -2311,7 +2311,8 @@ pub(crate) mod test {
for batch_size in [1, 2, 4, 8, 3, 7, 13] {
println!("running with batch_size={batch_size} max_read_size={max_read_size}");
// Test if the batch size is correctly determined
let mut iter = delta_layer.iter_with_options(&ctx, max_read_size, batch_size);
let mut iter = delta_layer.iter(&ctx);
iter.planner = StreamingVectoredReadPlanner::new(max_read_size, batch_size);
let mut num_items = 0;
for _ in 0..3 {
iter.next_batch().await.unwrap();
@@ -2328,7 +2329,8 @@ pub(crate) mod test {
iter.key_values_batch.clear();
}
// Test if the result is correct
let mut iter = delta_layer.iter_with_options(&ctx, max_read_size, batch_size);
let mut iter = delta_layer.iter(&ctx);
iter.planner = StreamingVectoredReadPlanner::new(max_read_size, batch_size);
assert_delta_iter_equal(&mut iter, &test_deltas).await;
}
}

View File

@@ -1,24 +0,0 @@
use crate::tenant::blob_io::WriteBlobError;
#[derive(Debug, thiserror::Error)]
pub enum PutError {
#[error(transparent)]
WriteBlob(WriteBlobError),
#[error(transparent)]
Other(anyhow::Error),
}
impl PutError {
pub fn is_cancel(&self) -> bool {
match self {
PutError::WriteBlob(e) => e.is_cancel(),
PutError::Other(_) => false,
}
}
pub fn into_anyhow(self) -> anyhow::Error {
match self {
PutError::WriteBlob(e) => e.into_anyhow(),
PutError::Other(e) => e,
}
}
}

View File

@@ -157,7 +157,7 @@ mod tests {
.await
.unwrap();
let merge_iter = MergeIterator::create_for_testing(
let merge_iter = MergeIterator::create(
&[resident_layer_1.get_as_delta(&ctx).await.unwrap()],
&[],
&ctx,
@@ -182,7 +182,7 @@ mod tests {
result.extend(test_deltas1[90..100].iter().cloned());
assert_filter_iter_equal(&mut filter_iter, &result).await;
let merge_iter = MergeIterator::create_for_testing(
let merge_iter = MergeIterator::create(
&[resident_layer_1.get_as_delta(&ctx).await.unwrap()],
&[],
&ctx,

View File

@@ -53,7 +53,6 @@ use utils::bin_ser::SerializeError;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use super::errors::PutError;
use super::layer_name::ImageLayerName;
use super::{
AsLayerDesc, LayerName, OnDiskValue, OnDiskValueIo, PersistentLayerDesc, ResidentLayer,
@@ -685,6 +684,14 @@ impl ImageLayerInner {
}
}
pub(crate) fn iter<'a>(&'a self, ctx: &'a RequestContext) -> ImageLayerIterator<'a> {
self.iter_with_options(
ctx,
1024 * 8192, // The default value. Unit tests might use a different value. 1024 * 8K = 8MB buffer.
1024, // The default value. Unit tests might use a different value
)
}
pub(crate) fn iter_with_options<'a>(
&'a self,
ctx: &'a RequestContext,
@@ -843,14 +850,8 @@ impl ImageLayerWriterInner {
key: Key,
img: Bytes,
ctx: &RequestContext,
) -> Result<(), PutError> {
if !self.key_range.contains(&key) {
return Err(PutError::Other(anyhow::anyhow!(
"key {:?} not in range {:?}",
key,
self.key_range
)));
}
) -> anyhow::Result<()> {
ensure!(self.key_range.contains(&key));
let compression = self.conf.image_compression;
let uncompressed_len = img.len() as u64;
self.uncompressed_bytes += uncompressed_len;
@@ -860,7 +861,7 @@ impl ImageLayerWriterInner {
.write_blob_maybe_compressed(img.slice_len(), ctx, compression)
.await;
// TODO: re-use the buffer for `img` further upstack
let (off, compression_info) = res.map_err(PutError::WriteBlob)?;
let (off, compression_info) = res?;
if compression_info.compressed_size.is_some() {
// The image has been considered for compression at least
self.uncompressed_bytes_eligible += uncompressed_len;
@@ -872,10 +873,7 @@ impl ImageLayerWriterInner {
let mut keybuf: [u8; KEY_SIZE] = [0u8; KEY_SIZE];
key.write_to_byte_slice(&mut keybuf);
self.tree
.append(&keybuf, off)
.map_err(anyhow::Error::new)
.map_err(PutError::Other)?;
self.tree.append(&keybuf, off)?;
#[cfg(feature = "testing")]
{
@@ -1095,7 +1093,7 @@ impl ImageLayerWriter {
key: Key,
img: Bytes,
ctx: &RequestContext,
) -> Result<(), PutError> {
) -> anyhow::Result<()> {
self.inner.as_mut().unwrap().put_image(key, img, ctx).await
}
@@ -1242,6 +1240,7 @@ mod test {
use crate::context::RequestContext;
use crate::tenant::harness::{TIMELINE_ID, TenantHarness};
use crate::tenant::storage_layer::{Layer, ResidentLayer};
use crate::tenant::vectored_blob_io::StreamingVectoredReadPlanner;
use crate::tenant::{TenantShard, Timeline};
#[tokio::test]
@@ -1508,7 +1507,8 @@ mod test {
for batch_size in [1, 2, 4, 8, 3, 7, 13] {
println!("running with batch_size={batch_size} max_read_size={max_read_size}");
// Test if the batch size is correctly determined
let mut iter = img_layer.iter_with_options(&ctx, max_read_size, batch_size);
let mut iter = img_layer.iter(&ctx);
iter.planner = StreamingVectoredReadPlanner::new(max_read_size, batch_size);
let mut num_items = 0;
for _ in 0..3 {
iter.next_batch().await.unwrap();
@@ -1525,7 +1525,8 @@ mod test {
iter.key_values_batch.clear();
}
// Test if the result is correct
let mut iter = img_layer.iter_with_options(&ctx, max_read_size, batch_size);
let mut iter = img_layer.iter(&ctx);
iter.planner = StreamingVectoredReadPlanner::new(max_read_size, batch_size);
assert_img_iter_equal(&mut iter, &test_imgs, Lsn(0x10)).await;
}
}

View File

@@ -23,7 +23,7 @@ use super::{
LayerVisibilityHint, PerfInstrumentFutureExt, PersistentLayerDesc, ValuesReconstructState,
};
use crate::config::PageServerConf;
use crate::context::{RequestContext, RequestContextBuilder};
use crate::context::{DownloadBehavior, RequestContext, RequestContextBuilder};
use crate::span::debug_assert_current_span_has_tenant_and_timeline_id;
use crate::task_mgr::TaskKind;
use crate::tenant::Timeline;
@@ -1076,17 +1076,24 @@ impl LayerInner {
return Err(DownloadError::DownloadRequired);
}
let ctx = RequestContextBuilder::from(ctx)
.perf_span(|crnt_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: crnt_perf_span,
"DOWNLOAD_LAYER",
layer = %self,
reason = %reason,
)
})
.attached_child();
let ctx = if ctx.has_perf_span() {
let dl_ctx = RequestContextBuilder::from(ctx)
.task_kind(TaskKind::LayerDownload)
.download_behavior(DownloadBehavior::Download)
.root_perf_span(|| {
info_span!(
target: PERF_TRACE_TARGET,
"DOWNLOAD_LAYER",
layer = %self,
reason = %reason
)
})
.detached_child();
ctx.perf_follows_from(&dl_ctx);
dl_ctx
} else {
ctx.attached_child()
};
async move {
tracing::info!(%reason, "downloading on-demand");
@@ -1094,7 +1101,7 @@ impl LayerInner {
let init_cancelled = scopeguard::guard((), |_| LAYER_IMPL_METRICS.inc_init_cancelled());
let res = self
.download_init_and_wait(timeline, permit, ctx.attached_child())
.maybe_perf_instrument(&ctx, |current_perf_span| current_perf_span.clone())
.maybe_perf_instrument(&ctx, |crnt_perf_span| crnt_perf_span.clone())
.await?;
scopeguard::ScopeGuard::into_inner(init_cancelled);
@@ -1702,7 +1709,7 @@ impl DownloadError {
}
}
#[derive(Debug, PartialEq, Copy, Clone)]
#[derive(Debug, PartialEq)]
pub(crate) enum NeedsDownload {
NotFound,
NotFile(std::fs::FileType),

View File

@@ -19,6 +19,14 @@ pub(crate) enum LayerRef<'a> {
}
impl<'a> LayerRef<'a> {
#[allow(dead_code)]
fn iter(self, ctx: &'a RequestContext) -> LayerIterRef<'a> {
match self {
Self::Image(x) => LayerIterRef::Image(x.iter(ctx)),
Self::Delta(x) => LayerIterRef::Delta(x.iter(ctx)),
}
}
fn iter_with_options(
self,
ctx: &'a RequestContext,
@@ -314,28 +322,6 @@ impl MergeIteratorItem for ((Key, Lsn, Value), Arc<PersistentLayerKey>) {
}
impl<'a> MergeIterator<'a> {
#[cfg(test)]
pub(crate) fn create_for_testing(
deltas: &[&'a DeltaLayerInner],
images: &[&'a ImageLayerInner],
ctx: &'a RequestContext,
) -> Self {
Self::create_with_options(deltas, images, ctx, 1024 * 8192, 1024)
}
/// Create a new merge iterator with custom options.
///
/// Adjust `max_read_size` and `max_batch_size` to trade memory usage for performance. The size should scale
/// with the number of layers to compact. If there are a lot of layers, consider reducing the values, so that
/// the buffer does not take too much memory.
///
/// The default options for L0 compactions are:
/// - max_read_size: 1024 * 8192 (8MB)
/// - max_batch_size: 1024
///
/// The default options for gc-compaction are:
/// - max_read_size: 128 * 8192 (1MB)
/// - max_batch_size: 128
pub fn create_with_options(
deltas: &[&'a DeltaLayerInner],
images: &[&'a ImageLayerInner],
@@ -365,6 +351,14 @@ impl<'a> MergeIterator<'a> {
}
}
pub fn create(
deltas: &[&'a DeltaLayerInner],
images: &[&'a ImageLayerInner],
ctx: &'a RequestContext,
) -> Self {
Self::create_with_options(deltas, images, ctx, 1024 * 8192, 1024)
}
pub(crate) async fn next_inner<R: MergeIteratorItem>(&mut self) -> anyhow::Result<Option<R>> {
while let Some(mut iter) = self.heap.peek_mut() {
if !iter.is_loaded() {
@@ -483,7 +477,7 @@ mod tests {
let resident_layer_2 = produce_delta_layer(&tenant, &tline, test_deltas2.clone(), &ctx)
.await
.unwrap();
let mut merge_iter = MergeIterator::create_for_testing(
let mut merge_iter = MergeIterator::create(
&[
resident_layer_2.get_as_delta(&ctx).await.unwrap(),
resident_layer_1.get_as_delta(&ctx).await.unwrap(),
@@ -555,7 +549,7 @@ mod tests {
let resident_layer_3 = produce_delta_layer(&tenant, &tline, test_deltas3.clone(), &ctx)
.await
.unwrap();
let mut merge_iter = MergeIterator::create_for_testing(
let mut merge_iter = MergeIterator::create(
&[
resident_layer_1.get_as_delta(&ctx).await.unwrap(),
resident_layer_2.get_as_delta(&ctx).await.unwrap(),
@@ -676,7 +670,7 @@ mod tests {
// Test with different layer order for MergeIterator::create to ensure the order
// is stable.
let mut merge_iter = MergeIterator::create_for_testing(
let mut merge_iter = MergeIterator::create(
&[
resident_layer_4.get_as_delta(&ctx).await.unwrap(),
resident_layer_1.get_as_delta(&ctx).await.unwrap(),
@@ -688,7 +682,7 @@ mod tests {
);
assert_merge_iter_equal(&mut merge_iter, &expect).await;
let mut merge_iter = MergeIterator::create_for_testing(
let mut merge_iter = MergeIterator::create(
&[
resident_layer_1.get_as_delta(&ctx).await.unwrap(),
resident_layer_4.get_as_delta(&ctx).await.unwrap(),

View File

@@ -340,7 +340,7 @@ pub(crate) fn log_compaction_error(
} else {
match level {
Level::ERROR if degrade_to_warning => warn!("Compaction failed and discarded: {err:#}"),
Level::ERROR => error!("Compaction failed: {err:?}"),
Level::ERROR => error!("Compaction failed: {err:#}"),
Level::INFO => info!("Compaction failed: {err:#}"),
level => unimplemented!("unexpected level {level:?}"),
}

View File

@@ -987,16 +987,6 @@ impl From<PageReconstructError> for CreateImageLayersError {
}
}
impl From<super::storage_layer::errors::PutError> for CreateImageLayersError {
fn from(e: super::storage_layer::errors::PutError) -> Self {
if e.is_cancel() {
CreateImageLayersError::Cancelled
} else {
CreateImageLayersError::Other(e.into_anyhow())
}
}
}
impl From<GetVectoredError> for CreateImageLayersError {
fn from(e: GetVectoredError) -> Self {
match e {
@@ -2127,14 +2117,22 @@ impl Timeline {
debug_assert_current_span_has_tenant_and_timeline_id();
// Regardless of whether we're going to try_freeze_and_flush
// or not, stop ingesting any more data.
// or not, stop ingesting any more data. Walreceiver only provides
// cancellation but no "wait until gone", because it uses the Timeline::gate.
// So, only after the self.gate.close() below will we know for sure that
// no walreceiver tasks are left.
// For `try_freeze_and_flush=true`, this means that we might still be ingesting
// data during the call to `self.freeze_and_flush()` below.
// That's not ideal, but, we don't have the concept of a ChildGuard,
// which is what we'd need to properly model early shutdown of the walreceiver
// task sub-tree before the other Timeline task sub-trees.
let walreceiver = self.walreceiver.lock().unwrap().take();
tracing::debug!(
is_some = walreceiver.is_some(),
"Waiting for WalReceiverManager..."
);
if let Some(walreceiver) = walreceiver {
walreceiver.shutdown().await;
walreceiver.cancel();
}
// ... and inform any waiters for newer LSNs that there won't be any.
self.last_record_lsn.shutdown();
@@ -5925,16 +5923,6 @@ impl From<layer_manager::Shutdown> for CompactionError {
}
}
impl From<super::storage_layer::errors::PutError> for CompactionError {
fn from(e: super::storage_layer::errors::PutError) -> Self {
if e.is_cancel() {
CompactionError::ShuttingDown
} else {
CompactionError::Other(e.into_anyhow())
}
}
}
#[serde_as]
#[derive(serde::Serialize)]
struct RecordedDuration(#[serde_as(as = "serde_with::DurationMicroSeconds")] Duration);

View File

@@ -1277,8 +1277,6 @@ impl Timeline {
return Ok(CompactionOutcome::YieldForL0);
}
let gc_cutoff = *self.applied_gc_cutoff_lsn.read();
// 2. Repartition and create image layers if necessary
match self
.repartition(
@@ -1289,7 +1287,7 @@ impl Timeline {
)
.await
{
Ok(((dense_partitioning, sparse_partitioning), lsn)) if lsn >= gc_cutoff => {
Ok(((dense_partitioning, sparse_partitioning), lsn)) => {
// Disables access_stats updates, so that the files we read remain candidates for eviction after we're done with them
let image_ctx = RequestContextBuilder::from(ctx)
.access_stats_behavior(AccessStatsBehavior::Skip)
@@ -1343,10 +1341,6 @@ impl Timeline {
}
}
Ok(_) => {
info!("skipping repartitioning due to image compaction LSN being below GC cutoff");
}
// Suppress errors when cancelled.
Err(_) if self.cancel.is_cancelled() => {}
Err(err) if err.is_cancel() => {}
@@ -2000,13 +1994,7 @@ impl Timeline {
let l = l.get_as_delta(ctx).await.map_err(CompactionError::Other)?;
deltas.push(l);
}
MergeIterator::create_with_options(
&deltas,
&[],
ctx,
1024 * 8192, /* 8 MiB buffer per layer iterator */
1024,
)
MergeIterator::create(&deltas, &[], ctx)
};
// This iterator walks through all keys and is needed to calculate size used by each key
@@ -2210,7 +2198,8 @@ impl Timeline {
.as_mut()
.unwrap()
.put_value(key, lsn, value, ctx)
.await?;
.await
.map_err(CompactionError::Other)?;
} else {
let owner = self.shard_identity.get_shard_number(&key);
@@ -2839,7 +2828,7 @@ impl Timeline {
Ok(())
}
/// Check to bail out of gc compaction early if it would use too much memory.
/// Check if the memory usage is within the limit.
async fn check_memory_usage(
self: &Arc<Self>,
layer_selection: &[Layer],
@@ -2852,8 +2841,7 @@ impl Timeline {
let layer_desc = layer.layer_desc();
if layer_desc.is_delta() {
// Delta layers at most have 1MB buffer; 3x to make it safe (there're deltas as large as 16KB).
// Scale it by target_layer_size_bytes so that tests can pass (some tests, e.g., `test_pageserver_gc_compaction_preempt
// use 3MB layer size and we need to account for that).
// Multiply the layer size so that tests can pass.
estimated_memory_usage_mb +=
3.0 * (layer_desc.file_size / target_layer_size_bytes) as f64;
num_delta_layers += 1;
@@ -3612,13 +3600,6 @@ impl Timeline {
last_key = Some(key);
}
accumulated_values.push((key, lsn, val));
if accumulated_values.len() >= 65536 {
// Assume all of them are images, that would be 512MB of data in memory for a single key.
return Err(CompactionError::Other(anyhow!(
"too many values for a single key, giving up gc-compaction"
)));
}
} else {
let last_key: &mut Key = last_key.as_mut().unwrap();
stat.on_unique_key_visited(); // TODO: adjust statistics for partial compaction

View File

@@ -149,7 +149,14 @@ pub async fn doit(
}
.await?;
flow::run(timeline.clone(), control_file, storage.clone(), ctx).await?;
flow::run(
timeline.clone(),
base_lsn,
control_file,
storage.clone(),
ctx,
)
.await?;
//
// Communicate that shard is done.

View File

@@ -34,9 +34,7 @@ use std::sync::Arc;
use anyhow::{bail, ensure};
use bytes::Bytes;
use futures::stream::FuturesOrdered;
use itertools::Itertools;
use pageserver_api::config::TimelineImportConfig;
use pageserver_api::key::{
CHECKPOINT_KEY, CONTROLFILE_KEY, DBDIR_KEY, Key, TWOPHASEDIR_KEY, rel_block_to_key,
rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
@@ -48,9 +46,8 @@ use pageserver_api::shard::ShardIdentity;
use postgres_ffi::relfile_utils::parse_relfilename;
use postgres_ffi::{BLCKSZ, pg_constants};
use remote_storage::RemotePath;
use tokio::sync::Semaphore;
use tokio_stream::StreamExt;
use tracing::{debug, instrument};
use tokio::task::JoinSet;
use tracing::{Instrument, debug, info_span, instrument};
use utils::bin_ser::BeSer;
use utils::lsn::Lsn;
@@ -66,40 +63,38 @@ use crate::tenant::storage_layer::{ImageLayerWriter, Layer};
pub async fn run(
timeline: Arc<Timeline>,
pgdata_lsn: Lsn,
control_file: ControlFile,
storage: RemoteStorageWrapper,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let planner = Planner {
Flow {
timeline,
pgdata_lsn,
control_file,
storage: storage.clone(),
shard: timeline.shard_identity,
tasks: Vec::default(),
};
let import_config = &timeline.conf.timeline_import_config;
let plan = planner.plan(import_config).await?;
plan.execute(timeline, import_config, ctx).await
tasks: Vec::new(),
storage,
}
.run(ctx)
.await
}
struct Planner {
struct Flow {
timeline: Arc<Timeline>,
pgdata_lsn: Lsn,
control_file: ControlFile,
storage: RemoteStorageWrapper,
shard: ShardIdentity,
tasks: Vec<AnyImportTask>,
storage: RemoteStorageWrapper,
}
struct Plan {
jobs: Vec<ChunkProcessingJob>,
}
impl Planner {
/// Creates an import plan
///
/// This function is and must remain pure: given the same input, it will generate the same import plan.
async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result<Plan> {
impl Flow {
/// Perform the ingestion into [`Self::timeline`].
/// Assumes the timeline is empty (= no layers).
pub async fn run(mut self, ctx: &RequestContext) -> anyhow::Result<()> {
let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align();
self.pgdata_lsn = pgdata_lsn;
let datadir = PgDataDir::new(&self.storage).await?;
// Import dbdir (00:00:00 keyspace)
@@ -120,7 +115,7 @@ impl Planner {
}
// Import SLRUs
if self.shard.is_shard_zero() {
if self.timeline.tenant_shard_id.is_shard_zero() {
// pg_xact (01:00 keyspace)
self.import_slru(SlruKind::Clog, &self.storage.pgdata().join("pg_xact"))
.await?;
@@ -171,16 +166,14 @@ impl Planner {
let mut last_end_key = Key::MIN;
let mut current_chunk = Vec::new();
let mut current_chunk_size: usize = 0;
let mut jobs = Vec::new();
let mut parallel_jobs = Vec::new();
for task in std::mem::take(&mut self.tasks).into_iter() {
if current_chunk_size + task.total_size()
> import_config.import_job_soft_size_limit.into()
{
if current_chunk_size + task.total_size() > 1024 * 1024 * 1024 {
let key_range = last_end_key..task.key_range().start;
jobs.push(ChunkProcessingJob::new(
parallel_jobs.push(ChunkProcessingJob::new(
key_range.clone(),
std::mem::take(&mut current_chunk),
pgdata_lsn,
&self,
));
last_end_key = key_range.end;
current_chunk_size = 0;
@@ -188,13 +181,45 @@ impl Planner {
current_chunk_size += task.total_size();
current_chunk.push(task);
}
jobs.push(ChunkProcessingJob::new(
parallel_jobs.push(ChunkProcessingJob::new(
last_end_key..Key::MAX,
current_chunk,
pgdata_lsn,
&self,
));
Ok(Plan { jobs })
// Start all jobs simultaneosly
let mut work = JoinSet::new();
// TODO: semaphore?
for job in parallel_jobs {
let ctx: RequestContext =
ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Error);
work.spawn(async move { job.run(&ctx).await }.instrument(info_span!("parallel_job")));
}
let mut results = Vec::new();
while let Some(result) = work.join_next().await {
match result {
Ok(res) => {
results.push(res);
}
Err(_joinset_err) => {
results.push(Err(anyhow::anyhow!(
"parallel job panicked or cancelled, check pageserver logs"
)));
}
}
}
if results.iter().all(|r| r.is_ok()) {
Ok(())
} else {
let mut msg = String::new();
for result in results {
if let Err(err) = result {
msg.push_str(&format!("{err:?}\n\n"));
}
}
bail!("Some parallel jobs failed:\n\n{msg}");
}
}
#[instrument(level = tracing::Level::DEBUG, skip_all, fields(dboid=%db.dboid, tablespace=%db.spcnode, path=%db.path))]
@@ -241,7 +266,7 @@ impl Planner {
let end_key = rel_block_to_key(file.rel_tag, start_blk + (len / 8192) as u32);
self.tasks
.push(AnyImportTask::RelBlocks(ImportRelBlocksTask::new(
self.shard,
*self.timeline.get_shard_identity(),
start_key..end_key,
&file.path,
self.storage.clone(),
@@ -264,7 +289,7 @@ impl Planner {
}
async fn import_slru(&mut self, kind: SlruKind, path: &RemotePath) -> anyhow::Result<()> {
assert!(self.shard.is_shard_zero());
assert!(self.timeline.tenant_shard_id.is_shard_zero());
let segments = self.storage.listfilesindir(path).await?;
let segments: Vec<(String, u32, usize)> = segments
@@ -319,68 +344,6 @@ impl Planner {
}
}
impl Plan {
async fn execute(
self,
timeline: Arc<Timeline>,
import_config: &TimelineImportConfig,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let mut work = FuturesOrdered::new();
let semaphore = Arc::new(Semaphore::new(import_config.import_job_concurrency.into()));
let jobs_in_plan = self.jobs.len();
let mut jobs = self.jobs.into_iter().enumerate().peekable();
let mut results = Vec::new();
// Run import jobs concurrently up to the limit specified by the pageserver configuration.
// Note that we process completed futures in the oreder of insertion. This will be the
// building block for resuming imports across pageserver restarts or tenant migrations.
while results.len() < jobs_in_plan {
tokio::select! {
permit = semaphore.clone().acquire_owned(), if jobs.peek().is_some() => {
let permit = permit.expect("never closed");
let (job_idx, job) = jobs.next().expect("we peeked");
let job_timeline = timeline.clone();
let ctx = ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Error);
work.push_back(tokio::task::spawn(async move {
let _permit = permit;
let res = job.run(job_timeline, &ctx).await;
(job_idx, res)
}));
},
maybe_complete_job_idx = work.next() => {
match maybe_complete_job_idx {
Some(Ok((_job_idx, res))) => {
results.push(res);
},
Some(Err(_)) => {
results.push(Err(anyhow::anyhow!(
"parallel job panicked or cancelled, check pageserver logs"
)));
}
None => {}
}
}
}
}
if results.iter().all(|r| r.is_ok()) {
Ok(())
} else {
let mut msg = String::new();
for result in results {
if let Err(err) = result {
msg.push_str(&format!("{err:?}\n\n"));
}
}
bail!("Some parallel jobs failed:\n\n{msg}");
}
}
}
//
// dbdir iteration tools
//
@@ -750,6 +713,7 @@ impl From<ImportSlruBlocksTask> for AnyImportTask {
}
struct ChunkProcessingJob {
timeline: Arc<Timeline>,
range: Range<Key>,
tasks: Vec<AnyImportTask>,
@@ -757,24 +721,25 @@ struct ChunkProcessingJob {
}
impl ChunkProcessingJob {
fn new(range: Range<Key>, tasks: Vec<AnyImportTask>, pgdata_lsn: Lsn) -> Self {
assert!(pgdata_lsn.is_valid());
fn new(range: Range<Key>, tasks: Vec<AnyImportTask>, env: &Flow) -> Self {
assert!(env.pgdata_lsn.is_valid());
Self {
timeline: env.timeline.clone(),
range,
tasks,
pgdata_lsn,
pgdata_lsn: env.pgdata_lsn,
}
}
async fn run(self, timeline: Arc<Timeline>, ctx: &RequestContext) -> anyhow::Result<()> {
async fn run(self, ctx: &RequestContext) -> anyhow::Result<()> {
let mut writer = ImageLayerWriter::new(
timeline.conf,
timeline.timeline_id,
timeline.tenant_shard_id,
self.timeline.conf,
self.timeline.timeline_id,
self.timeline.tenant_shard_id,
&self.range,
self.pgdata_lsn,
&timeline.gate,
timeline.cancel.clone(),
&self.timeline.gate,
self.timeline.cancel.clone(),
ctx,
)
.await?;
@@ -786,20 +751,24 @@ impl ChunkProcessingJob {
let resident_layer = if nimages > 0 {
let (desc, path) = writer.finish(ctx).await?;
Layer::finish_creating(timeline.conf, &timeline, desc, &path)?
Layer::finish_creating(self.timeline.conf, &self.timeline, desc, &path)?
} else {
// dropping the writer cleans up
return Ok(());
};
// this is sharing the same code as create_image_layers
let mut guard = timeline.layers.write().await;
let mut guard = self.timeline.layers.write().await;
guard
.open_mut()?
.track_new_image_layers(&[resident_layer.clone()], &timeline.metrics);
.track_new_image_layers(&[resident_layer.clone()], &self.timeline.metrics);
crate::tenant::timeline::drop_wlock(guard);
timeline
// Schedule the layer for upload but don't add barriers such as
// wait for completion or index upload, so we don't inhibit upload parallelism.
// TODO: limit upload parallelism somehow (e.g. by limiting concurrency of jobs?)
// TODO: or regulate parallelism by upload queue depth? Prob should happen at a higher level.
self.timeline
.remote_client
.schedule_layer_file_upload(resident_layer)?;

View File

@@ -63,7 +63,6 @@ pub struct WalReceiver {
/// All task spawned by [`WalReceiver::start`] and its children are sensitive to this token.
/// It's a child token of [`Timeline`] so that timeline shutdown can cancel WalReceiver tasks early for `freeze_and_flush=true`.
cancel: CancellationToken,
task: tokio::task::JoinHandle<()>,
}
impl WalReceiver {
@@ -80,7 +79,7 @@ impl WalReceiver {
let loop_status = Arc::new(std::sync::RwLock::new(None));
let manager_status = Arc::clone(&loop_status);
let cancel = timeline.cancel.child_token();
let task = WALRECEIVER_RUNTIME.spawn({
WALRECEIVER_RUNTIME.spawn({
let cancel = cancel.clone();
async move {
debug_assert_current_span_has_tenant_and_timeline_id();
@@ -121,25 +120,14 @@ impl WalReceiver {
Self {
manager_status,
cancel,
task,
}
}
#[instrument(skip_all, level = tracing::Level::DEBUG)]
pub async fn shutdown(self) {
pub fn cancel(&self) {
debug_assert_current_span_has_tenant_and_timeline_id();
debug!("cancelling walreceiver tasks");
self.cancel.cancel();
match self.task.await {
Ok(()) => debug!("Shutdown success"),
Err(je) if je.is_cancelled() => unreachable!("not used"),
Err(je) if je.is_panic() => {
// already logged by panic hook
}
Err(je) => {
error!("shutdown walreceiver task join error: {je}")
}
}
}
pub(crate) fn status(&self) -> Option<ConnectionManagerStatus> {

View File

@@ -14,6 +14,8 @@
use std::fs::File;
use std::io::{Error, ErrorKind};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
#[cfg(target_os = "linux")]
use std::os::unix::fs::OpenOptionsExt;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
@@ -74,8 +76,6 @@ pub struct VirtualFile {
impl VirtualFile {
/// Open a file in read-only mode. Like File::open.
///
/// Insensitive to `virtual_file_io_mode` setting.
pub async fn open<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
@@ -97,20 +97,36 @@ impl VirtualFile {
Self::open_with_options_v2(path.as_ref(), OpenOptions::new().read(true), ctx).await
}
/// `O_DIRECT` will be enabled base on `virtual_file_io_mode`.
pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
path: P,
mut open_options: OpenOptions,
open_options: &OpenOptions,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let mode = get_io_mode();
let direct = match (mode, open_options.is_write()) {
let set_o_direct = match (mode, open_options.is_write()) {
(IoMode::Buffered, _) => false,
#[cfg(target_os = "linux")]
(IoMode::Direct, false) => true,
#[cfg(target_os = "linux")]
(IoMode::Direct, true) => false,
#[cfg(target_os = "linux")]
(IoMode::DirectRw, _) => true,
};
open_options = open_options.direct(direct);
let open_options = open_options.clone();
let open_options = if set_o_direct {
#[cfg(target_os = "linux")]
{
let mut open_options = open_options;
open_options.custom_flags(nix::libc::O_DIRECT);
open_options
}
#[cfg(not(target_os = "linux"))]
unreachable!(
"O_DIRECT is not supported on this platform, IoMode's that result in set_o_direct=true shouldn't even be defined"
);
} else {
open_options
};
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(VirtualFile { inner, _mode: mode })
}
@@ -514,7 +530,7 @@ impl VirtualFileInner {
path: P,
ctx: &RequestContext,
) -> Result<VirtualFileInner, std::io::Error> {
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true).clone(), ctx).await
}
/// Open a file with given options.
@@ -542,11 +558,10 @@ impl VirtualFileInner {
// It would perhaps be nicer to check just for the read and write flags
// explicitly, but OpenOptions doesn't contain any functions to read flags,
// only to set them.
let reopen_options = open_options
.clone()
.create(false)
.create_new(false)
.truncate(false);
let mut reopen_options = open_options.clone();
reopen_options.create(false);
reopen_options.create_new(false);
reopen_options.truncate(false);
let vfile = VirtualFileInner {
handle: RwLock::new(handle),
@@ -782,12 +797,6 @@ impl VirtualFileInner {
where
Buf: tokio_epoll_uring::IoBufMut + Send,
{
self.validate_direct_io(
Slice::stable_ptr(&buf).addr(),
Slice::bytes_total(&buf),
offset,
);
let file_guard = match self
.lock_file()
.await
@@ -813,8 +822,6 @@ impl VirtualFileInner {
offset: u64,
ctx: &RequestContext,
) -> (FullSlice<B>, Result<usize, Error>) {
self.validate_direct_io(buf.as_ptr().addr(), buf.len(), offset);
let file_guard = match self.lock_file().await {
Ok(file_guard) => file_guard,
Err(e) => return (buf, Err(e)),
@@ -829,64 +836,6 @@ impl VirtualFileInner {
(buf, result)
})
}
/// Validate all reads and writes to adhere to the O_DIRECT requirements of our production systems.
///
/// Validating it iin userspace sets a consistent bar, independent of what actual OS/filesystem/block device is in use.
fn validate_direct_io(&self, addr: usize, size: usize, offset: u64) {
// TODO: eventually enable validation in the builds we use in real environments like staging, preprod, and prod.
if !(cfg!(feature = "testing") || cfg!(test)) {
return;
}
if !self.open_options.is_direct() {
return;
}
// Validate buffer memory alignment.
//
// What practically matters as of Linux 6.1 is bdev_dma_alignment()
// which is practically between 512 and 4096.
// On our production systems, the value is 512.
// The IoBuffer/IoBufferMut hard-code that value.
//
// Because the alloctor might return _more_ aligned addresses than requested,
// there is a chance that testing would not catch violations of a runtime requirement stricter than 512.
{
let requirement = 512;
let remainder = addr % requirement;
assert!(
remainder == 0,
"Direct I/O buffer must be aligned: buffer_addr=0x{addr:x} % 0x{requirement:x} = 0x{remainder:x}"
);
}
// Validate offset alignment.
//
// We hard-code 512 throughout the code base.
// So enforce just that and not anything more restrictive.
// Even the shallowest testing will expose more restrictive requirements if those ever arise.
{
let requirement = 512;
let remainder = offset % requirement;
assert!(
remainder == 0,
"Direct I/O offset must be aligned: offset=0x{offset:x} % 0x{requirement:x} = 0x{remainder:x}"
);
}
// Validate buffer size multiple requirement.
//
// The requirement in Linux 6.1 is bdev_logical_block_size().
// On our production systems, that is 512.
{
let requirement = 512;
let remainder = size % requirement;
assert!(
remainder == 0,
"Direct I/O buffer size must be a multiple of {requirement}: size=0x{size:x} % 0x{requirement:x} = 0x{remainder:x}"
);
}
}
}
// Adapted from https://doc.rust-lang.org/1.72.0/src/std/os/unix/fs.rs.html#117-135
@@ -1275,6 +1224,7 @@ mod tests {
use std::sync::Arc;
use owned_buffers_io::io_buf_ext::IoBufExt;
use owned_buffers_io::slice::SliceMutExt;
use rand::seq::SliceRandom;
use rand::{Rng, thread_rng};
@@ -1282,85 +1232,208 @@ mod tests {
use crate::context::DownloadBehavior;
use crate::task_mgr::TaskKind;
enum MaybeVirtualFile {
VirtualFile(VirtualFile),
File(File),
}
impl From<VirtualFile> for MaybeVirtualFile {
fn from(vf: VirtualFile) -> Self {
MaybeVirtualFile::VirtualFile(vf)
}
}
impl MaybeVirtualFile {
async fn read_exact_at(
&self,
mut slice: tokio_epoll_uring::Slice<IoBufferMut>,
offset: u64,
ctx: &RequestContext,
) -> Result<tokio_epoll_uring::Slice<IoBufferMut>, Error> {
match self {
MaybeVirtualFile::VirtualFile(file) => file.read_exact_at(slice, offset, ctx).await,
MaybeVirtualFile::File(file) => {
let rust_slice: &mut [u8] = slice.as_mut_rust_slice_full_zeroed();
file.read_exact_at(rust_slice, offset).map(|()| slice)
}
}
}
async fn write_all_at<Buf: IoBufAligned + Send>(
&self,
buf: FullSlice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> Result<(), Error> {
match self {
MaybeVirtualFile::VirtualFile(file) => {
let (_buf, res) = file.write_all_at(buf, offset, ctx).await;
res
}
MaybeVirtualFile::File(file) => file.write_all_at(&buf[..], offset),
}
}
// Helper function to slurp a portion of a file into a string
async fn read_string_at(
&mut self,
pos: u64,
len: usize,
ctx: &RequestContext,
) -> Result<String, Error> {
let slice = IoBufferMut::with_capacity(len).slice_full();
assert_eq!(slice.bytes_total(), len);
let slice = self.read_exact_at(slice, pos, ctx).await?;
let buf = slice.into_inner();
assert_eq!(buf.len(), len);
Ok(String::from_utf8(buf.to_vec()).unwrap())
}
}
#[tokio::test]
async fn test_virtual_files() -> anyhow::Result<()> {
// The real work is done in the test_files() helper function. This
// allows us to run the same set of tests against a native File, and
// VirtualFile. We trust the native Files and wouldn't need to test them,
// but this allows us to verify that the operations return the same
// results with VirtualFiles as with native Files. (Except that with
// native files, you will run out of file descriptors if the ulimit
// is low enough.)
struct A;
impl Adapter for A {
async fn open(
path: Utf8PathBuf,
opts: OpenOptions,
ctx: &RequestContext,
) -> Result<MaybeVirtualFile, anyhow::Error> {
let vf = VirtualFile::open_with_options_v2(&path, &opts, ctx).await?;
Ok(MaybeVirtualFile::VirtualFile(vf))
}
}
test_files::<A>("virtual_files").await
}
#[tokio::test]
async fn test_physical_files() -> anyhow::Result<()> {
struct B;
impl Adapter for B {
async fn open(
path: Utf8PathBuf,
opts: OpenOptions,
_ctx: &RequestContext,
) -> Result<MaybeVirtualFile, anyhow::Error> {
Ok(MaybeVirtualFile::File({
let owned_fd = opts.open(path.as_std_path()).await?;
File::from(owned_fd)
}))
}
}
test_files::<B>("physical_files").await
}
/// This is essentially a closure which returns a MaybeVirtualFile, but because rust edition
/// 2024 is not yet out with new lifetime capture or outlives rules, this is a async function
/// in trait which benefits from the new lifetime capture rules already.
trait Adapter {
async fn open(
path: Utf8PathBuf,
opts: OpenOptions,
ctx: &RequestContext,
) -> Result<MaybeVirtualFile, anyhow::Error>;
}
async fn test_files<A>(testname: &str) -> anyhow::Result<()>
where
A: Adapter,
{
let ctx =
RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error).with_scope_unit_test();
let testdir = crate::config::PageServerConf::test_repo_dir("test_virtual_files");
let testdir = crate::config::PageServerConf::test_repo_dir(testname);
std::fs::create_dir_all(&testdir)?;
let zeropad512 = |content: &[u8]| {
let mut buf = IoBufferMut::with_capacity_zeroed(512);
buf[..content.len()].copy_from_slice(content);
buf.freeze().slice_len()
};
let path_a = testdir.join("file_a");
let file_a = VirtualFile::open_with_options_v2(
let mut file_a = A::open(
path_a.clone(),
OpenOptions::new()
.read(true)
.write(true)
// set create & truncate flags to ensure when we trigger a reopen later in this test,
// the reopen_options must have masked out those flags; if they don't, then
// the after reopen we will fail to read the `content_a` that we write here.
.create(true)
.truncate(true),
.truncate(true)
.to_owned(),
&ctx,
)
.await?;
let (_, res) = file_a.write_all_at(zeropad512(b"content_a"), 0, &ctx).await;
res?;
file_a
.write_all_at(IoBuffer::from(b"foobar").slice_len(), 0, &ctx)
.await?;
// cannot read from a file opened in write-only mode
let _ = file_a.read_string_at(0, 1, &ctx).await.unwrap_err();
// Close the file and re-open for reading
let mut file_a = A::open(path_a, OpenOptions::new().read(true).to_owned(), &ctx).await?;
// cannot write to a file opened in read-only mode
let _ = file_a
.write_all_at(IoBuffer::from(b"bar").slice_len(), 0, &ctx)
.await
.unwrap_err();
// Try simple read
assert_eq!("foobar", file_a.read_string_at(0, 6, &ctx).await?);
// Create another test file, and try FileExt functions on it.
let path_b = testdir.join("file_b");
let file_b = VirtualFile::open_with_options_v2(
let mut file_b = A::open(
path_b.clone(),
OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(true),
.truncate(true)
.to_owned(),
&ctx,
)
.await?;
let (_, res) = file_b.write_all_at(zeropad512(b"content_b"), 0, &ctx).await;
res?;
file_b
.write_all_at(IoBuffer::from(b"BAR").slice_len(), 3, &ctx)
.await?;
file_b
.write_all_at(IoBuffer::from(b"FOO").slice_len(), 0, &ctx)
.await?;
let assert_first_512_eq = async |vfile: &VirtualFile, expect: &[u8]| {
let buf = vfile
.read_exact_at(IoBufferMut::with_capacity_zeroed(512).slice_full(), 0, &ctx)
.await
.unwrap();
assert_eq!(&buf[..], &zeropad512(expect)[..]);
};
assert_eq!(file_b.read_string_at(2, 3, &ctx).await?, "OBA");
// Open a lot of file descriptors / VirtualFile instances.
// Enough to cause some evictions in the fd cache.
// Open a lot of files, enough to cause some evictions. (Or to be precise,
// open the same file many times. The effect is the same.)
let mut file_b_dupes = Vec::new();
let mut vfiles = Vec::new();
for _ in 0..100 {
let vfile = VirtualFile::open_with_options_v2(
let mut vfile = A::open(
path_b.clone(),
OpenOptions::new().read(true),
OpenOptions::new().read(true).to_owned(),
&ctx,
)
.await?;
assert_first_512_eq(&vfile, b"content_b").await;
file_b_dupes.push(vfile);
assert_eq!("FOOBAR", vfile.read_string_at(0, 6, &ctx).await?);
vfiles.push(vfile);
}
// make sure we opened enough files to definitely cause evictions.
assert!(file_b_dupes.len() > TEST_MAX_FILE_DESCRIPTORS * 2);
assert!(vfiles.len() > TEST_MAX_FILE_DESCRIPTORS * 2);
// The underlying file descriptor for 'file_a' should be closed now. Try to read
// from it again. The VirtualFile reopens the file internally.
assert_first_512_eq(&file_a, b"content_a").await;
// from it again.
assert_eq!("foobar", file_a.read_string_at(0, 6, &ctx).await?);
// Check that all the other FDs still work too. Use them in random order for
// good measure.
file_b_dupes.as_mut_slice().shuffle(&mut thread_rng());
for vfile in file_b_dupes.iter_mut() {
assert_first_512_eq(vfile, b"content_b").await;
vfiles.as_mut_slice().shuffle(&mut thread_rng());
for vfile in vfiles.iter_mut() {
assert_eq!("OOBAR", vfile.read_string_at(1, 5, &ctx).await?);
}
Ok(())
@@ -1391,9 +1464,9 @@ mod tests {
// Open the file many times.
let mut files = Vec::new();
for _ in 0..VIRTUAL_FILES {
let f = VirtualFile::open_with_options_v2(
let f = VirtualFileInner::open_with_options(
&test_file_path,
OpenOptions::new().read(true),
OpenOptions::new().read(true).clone(),
&ctx,
)
.await?;
@@ -1436,6 +1509,8 @@ mod tests {
#[tokio::test]
async fn test_atomic_overwrite_basic() {
let ctx =
RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error).with_scope_unit_test();
let testdir = crate::config::PageServerConf::test_repo_dir("test_atomic_overwrite_basic");
std::fs::create_dir_all(&testdir).unwrap();
@@ -1445,22 +1520,26 @@ mod tests {
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
.await
.unwrap();
let post = std::fs::read_to_string(&path).unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
let post = file.read_string_at(0, 3, &ctx).await.unwrap();
assert_eq!(post, "foo");
assert!(!tmp_path.exists());
drop(file);
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec())
.await
.unwrap();
let post = std::fs::read_to_string(&path).unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
let post = file.read_string_at(0, 3, &ctx).await.unwrap();
assert_eq!(post, "bar");
assert!(!tmp_path.exists());
drop(file);
}
#[tokio::test]
async fn test_atomic_overwrite_preexisting_tmp() {
let ctx =
RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error).with_scope_unit_test();
let testdir =
crate::config::PageServerConf::test_repo_dir("test_atomic_overwrite_preexisting_tmp");
std::fs::create_dir_all(&testdir).unwrap();
@@ -1475,8 +1554,10 @@ mod tests {
.await
.unwrap();
let post = std::fs::read_to_string(&path).unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
let post = file.read_string_at(0, 3, &ctx).await.unwrap();
assert_eq!(post, "foo");
assert!(!tmp_path.exists());
drop(file);
}
}

View File

@@ -111,17 +111,13 @@ pub(crate) fn get() -> IoEngine {
use std::os::unix::prelude::FileExt;
use std::sync::atomic::{AtomicU8, Ordering};
#[cfg(target_os = "linux")]
use {std::time::Duration, tracing::info};
use super::owned_buffers_io::io_buf_ext::FullSlice;
use super::owned_buffers_io::slice::SliceMutExt;
use super::{FileGuard, Metadata};
#[cfg(target_os = "linux")]
pub(super) fn epoll_uring_error_to_std(
e: tokio_epoll_uring::Error<std::io::Error>,
) -> std::io::Error {
fn epoll_uring_error_to_std(e: tokio_epoll_uring::Error<std::io::Error>) -> std::io::Error {
match e {
tokio_epoll_uring::Error::Op(e) => e,
tokio_epoll_uring::Error::System(system) => {
@@ -153,11 +149,7 @@ impl IoEngine {
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
let system = tokio_epoll_uring_ext::thread_local_system().await;
let (resources, res) =
retry_ecanceled_once((file_guard, slice), |(file_guard, slice)| async {
system.read(file_guard, offset, slice).await
})
.await;
let (resources, res) = system.read(file_guard, offset, slice).await;
(resources, res.map_err(epoll_uring_error_to_std))
}
}
@@ -172,10 +164,7 @@ impl IoEngine {
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
let system = tokio_epoll_uring_ext::thread_local_system().await;
let (resources, res) = retry_ecanceled_once(file_guard, |file_guard| async {
system.fsync(file_guard).await
})
.await;
let (resources, res) = system.fsync(file_guard).await;
(resources, res.map_err(epoll_uring_error_to_std))
}
}
@@ -193,10 +182,7 @@ impl IoEngine {
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
let system = tokio_epoll_uring_ext::thread_local_system().await;
let (resources, res) = retry_ecanceled_once(file_guard, |file_guard| async {
system.fdatasync(file_guard).await
})
.await;
let (resources, res) = system.fdatasync(file_guard).await;
(resources, res.map_err(epoll_uring_error_to_std))
}
}
@@ -215,10 +201,7 @@ impl IoEngine {
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
let system = tokio_epoll_uring_ext::thread_local_system().await;
let (resources, res) = retry_ecanceled_once(file_guard, |file_guard| async {
system.statx(file_guard).await
})
.await;
let (resources, res) = system.statx(file_guard).await;
(
resources,
res.map_err(epoll_uring_error_to_std).map(Metadata::from),
@@ -241,7 +224,6 @@ impl IoEngine {
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
// TODO: ftruncate op for tokio-epoll-uring
// Don't forget to use retry_ecanceled_once
let res = file_guard.with_std_file(|std_file| std_file.set_len(len));
(file_guard, res)
}
@@ -263,11 +245,8 @@ impl IoEngine {
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
let system = tokio_epoll_uring_ext::thread_local_system().await;
let ((file_guard, slice), res) = retry_ecanceled_once(
(file_guard, buf.into_raw_slice()),
async |(file_guard, buf)| system.write(file_guard, offset, buf).await,
)
.await;
let ((file_guard, slice), res) =
system.write(file_guard, offset, buf.into_raw_slice()).await;
(
(file_guard, FullSlice::must_new(slice)),
res.map_err(epoll_uring_error_to_std),
@@ -303,56 +282,6 @@ impl IoEngine {
}
}
/// We observe in tests that stop pageserver with SIGTERM immediately after it was ingesting data,
/// occasionally buffered writers fail (and get retried by BufferedWriter) with ECANCELED.
/// The problem is believed to be a race condition in how io_uring handles punted async work (io-wq) and signals.
/// Investigation ticket: <https://github.com/neondatabase/neon/issues/11446>
///
/// This function retries the operation once if it fails with ECANCELED.
/// ONLY USE FOR IDEMPOTENT [`super::VirtualFile`] operations.
#[cfg(target_os = "linux")]
pub(super) async fn retry_ecanceled_once<F, Fut, T, V>(
resources: T,
f: F,
) -> (T, Result<V, tokio_epoll_uring::Error<std::io::Error>>)
where
F: Fn(T) -> Fut,
Fut: std::future::Future<Output = (T, Result<V, tokio_epoll_uring::Error<std::io::Error>>)>,
T: Send,
V: Send,
{
let (resources, res) = f(resources).await;
let Err(e) = res else {
return (resources, res);
};
let tokio_epoll_uring::Error::Op(err) = e else {
return (resources, Err(e));
};
if err.raw_os_error() != Some(nix::libc::ECANCELED) {
return (resources, Err(tokio_epoll_uring::Error::Op(err)));
}
{
static RATE_LIMIT: std::sync::Mutex<utils::rate_limit::RateLimit> =
std::sync::Mutex::new(utils::rate_limit::RateLimit::new(Duration::from_secs(1)));
let mut guard = RATE_LIMIT.lock().unwrap();
guard.call2(|rate_limit_stats| {
info!(
%rate_limit_stats, "ECANCELED observed, assuming it is due to a signal being received by the submitting thread, retrying after a delay; this message is rate-limited"
);
});
drop(guard);
}
tokio::time::sleep(Duration::from_millis(100)).await; // something big enough to beat even heavily overcommitted CI runners
let (resources, res) = f(resources).await;
(resources, res)
}
pub(super) fn panic_operation_must_be_idempotent() {
panic!(
"unsupported; io_engine may retry operations internally and thus needs them to be idempotent (retry_ecanceled_once)"
)
}
pub enum FeatureTestResult {
PlatformPreferred(IoEngineKind),
Worse {

View File

@@ -1,20 +1,13 @@
//! Enum-dispatch to the `OpenOptions` type of the respective [`super::IoEngineKind`];
use std::os::fd::OwnedFd;
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use super::io_engine::IoEngine;
#[derive(Debug, Clone)]
pub struct OpenOptions {
/// We keep a copy of the write() flag we pass to the `inner`` `OptionOptions`
/// to support [`Self::is_write`].
write: bool,
/// We don't expose + pass through a raw `custom_flags()` style API.
/// The only custom flag we support is `O_DIRECT`, which we track here
/// and map to `custom_flags()` in the [`Self::open`] method.
direct: bool,
inner: Inner,
}
#[derive(Debug, Clone)]
@@ -36,7 +29,6 @@ impl Default for OpenOptions {
};
Self {
write: false,
direct: false,
inner,
}
}
@@ -51,11 +43,7 @@ impl OpenOptions {
self.write
}
pub(super) fn is_direct(&self) -> bool {
self.direct
}
pub fn read(mut self, read: bool) -> Self {
pub fn read(&mut self, read: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.read(read);
@@ -68,7 +56,7 @@ impl OpenOptions {
self
}
pub fn write(mut self, write: bool) -> Self {
pub fn write(&mut self, write: bool) -> &mut OpenOptions {
self.write = write;
match &mut self.inner {
Inner::StdFs(x) => {
@@ -82,7 +70,7 @@ impl OpenOptions {
self
}
pub fn create(mut self, create: bool) -> Self {
pub fn create(&mut self, create: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.create(create);
@@ -95,7 +83,7 @@ impl OpenOptions {
self
}
pub fn create_new(mut self, create_new: bool) -> Self {
pub fn create_new(&mut self, create_new: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.create_new(create_new);
@@ -108,7 +96,7 @@ impl OpenOptions {
self
}
pub fn truncate(mut self, truncate: bool) -> Self {
pub fn truncate(&mut self, truncate: bool) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.truncate(truncate);
@@ -121,53 +109,25 @@ impl OpenOptions {
self
}
/// Don't use, `O_APPEND` is not supported.
pub fn append(&mut self, _append: bool) {
super::io_engine::panic_operation_must_be_idempotent();
}
pub(in crate::virtual_file) async fn open(&self, path: &Path) -> std::io::Result<OwnedFd> {
#[cfg_attr(not(target_os = "linux"), allow(unused_mut))]
let mut custom_flags = 0;
if self.direct {
match &self.inner {
Inner::StdFs(x) => x.open(path).map(|file| file.into()),
#[cfg(target_os = "linux")]
{
custom_flags |= nix::libc::O_DIRECT;
}
#[cfg(not(target_os = "linux"))]
{
// Other platforms may be used for development but don't necessarily have a 1:1 equivalent to Linux's O_DIRECT (macOS!).
// Just don't set the flag; to catch alignment bugs typical for O_DIRECT,
// we have a runtime validation layer inside `VirtualFile::write_at` and `VirtualFile::read_at`.
static WARNING: std::sync::Once = std::sync::Once::new();
WARNING.call_once(|| {
let span = tracing::info_span!(parent: None, "open_options");
let _enter = span.enter();
tracing::warn!("your platform is not a supported production platform, ignoing request for O_DIRECT; this could hide alignment bugs; this warning is logged once per process");
});
}
}
match self.inner.clone() {
Inner::StdFs(mut x) => x
.custom_flags(custom_flags)
.open(path)
.map(|file| file.into()),
#[cfg(target_os = "linux")]
Inner::TokioEpollUring(mut x) => {
x.custom_flags(custom_flags);
Inner::TokioEpollUring(x) => {
let system = super::io_engine::tokio_epoll_uring_ext::thread_local_system().await;
let (_, res) = super::io_engine::retry_ecanceled_once((), |()| async {
let res = system.open(path, &x).await;
((), res)
system.open(path, x).await.map_err(|e| match e {
tokio_epoll_uring::Error::Op(e) => e,
tokio_epoll_uring::Error::System(system) => {
std::io::Error::new(std::io::ErrorKind::Other, system)
}
})
.await;
res.map_err(super::io_engine::epoll_uring_error_to_std)
}
}
}
}
pub fn mode(mut self, mode: u32) -> Self {
impl std::os::unix::prelude::OpenOptionsExt for OpenOptions {
fn mode(&mut self, mode: u32) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.mode(mode);
@@ -180,8 +140,16 @@ impl OpenOptions {
self
}
pub fn direct(mut self, direct: bool) -> Self {
self.direct = direct;
fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.custom_flags(flags);
}
#[cfg(target_os = "linux")]
Inner::TokioEpollUring(x) => {
let _ = x.custom_flags(flags);
}
}
self
}
}

View File

@@ -247,19 +247,6 @@ pub enum FlushTaskError {
Cancelled,
}
impl FlushTaskError {
pub fn is_cancel(&self) -> bool {
match self {
FlushTaskError::Cancelled => true,
}
}
pub fn into_anyhow(self) -> anyhow::Error {
match self {
FlushTaskError::Cancelled => anyhow::anyhow!(self),
}
}
}
impl<Buf, W> FlushBackgroundTask<Buf, W>
where
Buf: IoBufAligned + Send + Sync,

View File

@@ -425,12 +425,15 @@ compact_prefetch_buffers(void)
* point inside and outside PostgreSQL.
*
* This still does throw errors when it receives malformed responses from PS.
*
* When we're not called from CHECK_FOR_INTERRUPTS (indicated by
* IsHandlingInterrupts) we also report we've ended prefetch receive work,
* just in case state tracking was lost due to an error in the sync getPage
* response code.
*/
void
communicator_prefetch_pump_state(void)
communicator_prefetch_pump_state(bool IsHandlingInterrupts)
{
START_PREFETCH_RECEIVE_WORK();
while (MyPState->ring_receive != MyPState->ring_flush)
{
NeonResponse *response;
@@ -479,7 +482,9 @@ communicator_prefetch_pump_state(void)
}
}
END_PREFETCH_RECEIVE_WORK();
/* We never pump the prefetch state while handling other pages */
if (!IsHandlingInterrupts)
END_PREFETCH_RECEIVE_WORK();
communicator_reconfigure_timeout_if_needed();
}
@@ -667,10 +672,9 @@ prefetch_wait_for(uint64 ring_index)
Assert(MyPState->ring_unused > ring_index);
START_PREFETCH_RECEIVE_WORK();
while (MyPState->ring_receive <= ring_index)
{
START_PREFETCH_RECEIVE_WORK();
entry = GetPrfSlot(MyPState->ring_receive);
Assert(entry->status == PRFS_REQUESTED);
@@ -679,18 +683,17 @@ prefetch_wait_for(uint64 ring_index)
result = false;
break;
}
END_PREFETCH_RECEIVE_WORK();
CHECK_FOR_INTERRUPTS();
}
if (result)
{
/* Check that slot is actually received (srver can be disconnected in prefetch_pump_state called from CHECK_FOR_INTERRUPTS */
PrefetchRequest *slot = GetPrfSlot(ring_index);
result = slot->status == PRFS_RECEIVED;
return slot->status == PRFS_RECEIVED;
}
END_PREFETCH_RECEIVE_WORK();
return result;
return false;
;
}
@@ -717,7 +720,6 @@ prefetch_read(PrefetchRequest *slot)
Assert(slot->status == PRFS_REQUESTED);
Assert(slot->response == NULL);
Assert(slot->my_ring_index == MyPState->ring_receive);
Assert(readpage_reentrant_guard);
if (slot->status != PRFS_REQUESTED ||
slot->response != NULL ||
@@ -800,7 +802,6 @@ communicator_prefetch_receive(BufferTag tag)
PrfHashEntry *entry;
PrefetchRequest hashkey;
Assert(readpage_reentrant_guard);
hashkey.buftag = tag;
entry = prfh_lookup(MyPState->prf_hash, &hashkey);
if (entry != NULL && prefetch_wait_for(entry->slot->my_ring_index))
@@ -820,12 +821,8 @@ communicator_prefetch_receive(BufferTag tag)
void
prefetch_on_ps_disconnect(void)
{
bool save_readpage_reentrant_guard = readpage_reentrant_guard;
MyPState->ring_flush = MyPState->ring_unused;
/* Prohibit callig of prefetch_pump_state */
START_PREFETCH_RECEIVE_WORK();
while (MyPState->ring_receive < MyPState->ring_unused)
{
PrefetchRequest *slot;
@@ -854,9 +851,6 @@ prefetch_on_ps_disconnect(void)
MyNeonCounters->getpage_prefetch_discards_total += 1;
}
/* Restore guard */
readpage_reentrant_guard = save_readpage_reentrant_guard;
/*
* We can have gone into retry due to network error, so update stats with
* the latest available
@@ -2515,7 +2509,7 @@ communicator_processinterrupts(void)
if (timeout_signaled)
{
if (!readpage_reentrant_guard && readahead_getpage_pull_timeout_ms > 0)
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(true);
timeout_signaled = false;
communicator_reconfigure_timeout_if_needed();

View File

@@ -44,7 +44,7 @@ extern int communicator_read_slru_segment(SlruKind kind, int64 segno,
void *buffer);
extern void communicator_reconfigure_timeout_if_needed(void);
extern void communicator_prefetch_pump_state(void);
extern void communicator_prefetch_pump_state(bool IsHandlingInterrupts);
#endif

View File

@@ -433,6 +433,7 @@ pageserver_connect(shardno_t shard_no, int elevel)
now = GetCurrentTimestamp();
us_since_last_attempt = (int64) (now - shard->last_reconnect_time);
shard->last_reconnect_time = now;
/*
* Make sure we don't do exponential backoff with a constant multiplier
@@ -446,23 +447,14 @@ pageserver_connect(shardno_t shard_no, int elevel)
/*
* If we did other tasks between reconnect attempts, then we won't
* need to wait as long as a full delay.
*
* This is a loop to protect against interrupted sleeps.
*/
while (us_since_last_attempt < shard->delay_us)
if (us_since_last_attempt < shard->delay_us)
{
pg_usleep(shard->delay_us - us_since_last_attempt);
/* At least we should handle cancellations here */
CHECK_FOR_INTERRUPTS();
now = GetCurrentTimestamp();
us_since_last_attempt = (int64) (now - shard->last_reconnect_time);
}
/* update the delay metric */
shard->delay_us = Min(shard->delay_us * 2, MAX_RECONNECT_INTERVAL_USEC);
shard->last_reconnect_time = now;
/*
* Connect using the connection string we got from the

View File

@@ -150,7 +150,7 @@ NeonWALReaderFree(NeonWALReader *state)
* fetched from timeline 'tli'.
*
* Returns NEON_WALREAD_SUCCESS if succeeded, NEON_WALREAD_ERROR if an error
* occurs, in which case 'err' has the description. Error always closes remote
* occurs, in which case 'err' has the desciption. Error always closes remote
* connection, if there was any, so socket subscription should be removed.
*
* NEON_WALREAD_WOULDBLOCK means caller should obtain socket to wait for with

View File

@@ -1179,7 +1179,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
blocknum += iterblocks;
}
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
return false;
}
@@ -1218,7 +1218,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum)
communicator_prefetch_register_bufferv(tag, NULL, 1, NULL);
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
return false;
}
@@ -1262,7 +1262,7 @@ neon_writeback(SMgrRelation reln, ForkNumber forknum,
*/
neon_log(SmgrTrace, "writeback noop");
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
@@ -1315,7 +1315,7 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
}
/* Try to read PS results if they are available */
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
neon_get_request_lsns(InfoFromSMgrRel(reln), forkNum, blkno, &request_lsns, 1);
@@ -1339,7 +1339,7 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
/*
* Try to receive prefetch results once again just to make sure we don't leave the smgr code while the OS might still have buffered bytes.
*/
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
#ifdef DEBUG_COMPARE_LOCAL
if (forkNum == MAIN_FORKNUM && IS_LOCAL_REL(reln))
@@ -1449,7 +1449,7 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
nblocks, PG_IOV_MAX);
/* Try to read PS results if they are available */
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
neon_get_request_lsns(InfoFromSMgrRel(reln), forknum, blocknum,
request_lsns, nblocks);
@@ -1480,7 +1480,7 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
/*
* Try to receive prefetch results once again just to make sure we don't leave the smgr code while the OS might still have buffered bytes.
*/
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
#ifdef DEBUG_COMPARE_LOCAL
if (forknum == MAIN_FORKNUM && IS_LOCAL_REL(reln))
@@ -1665,7 +1665,7 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
lfc_write(InfoFromSMgrRel(reln), forknum, blocknum, buffer);
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
@@ -1727,7 +1727,7 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
lfc_writev(InfoFromSMgrRel(reln), forknum, blkno, buffers, nblocks);
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
@@ -1902,7 +1902,7 @@ neon_immedsync(SMgrRelation reln, ForkNumber forknum)
neon_log(SmgrTrace, "[NEON_SMGR] immedsync noop");
communicator_prefetch_pump_state();
communicator_prefetch_pump_state(false);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
@@ -1989,14 +1989,8 @@ neon_start_unlogged_build(SMgrRelation reln)
neon_log(ERROR, "unknown relpersistence '%c'", reln->smgr_relpersistence);
}
#if PG_MAJORVERSION_NUM >= 17
/*
* We have to disable this check for pg14-16 because sorted build of GIST index requires
* to perform unlogged build several times
*/
if (smgrnblocks(reln, MAIN_FORKNUM) != 0)
neon_log(ERROR, "cannot perform unlogged index build, index is not empty ");
#endif
unlogged_build_rel = reln;
unlogged_build_phase = UNLOGGED_BUILD_PHASE_1;

View File

@@ -124,7 +124,6 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
}
else
{
wp->safekeepers_generation = INVALID_GENERATION;
host = wp->config->safekeepers_list;
}
wp_log(LOG, "safekeepers_generation=%u", wp->safekeepers_generation);
@@ -757,7 +756,7 @@ UpdateMemberSafekeeperPtr(WalProposer *wp, Safekeeper *sk)
{
SafekeeperId *sk_id = &wp->mconf.members.m[i];
if (sk_id->node_id == sk->greetResponse.nodeId)
if (wp->mconf.members.m[i].node_id == sk->greetResponse.nodeId)
{
/*
* If mconf or list of safekeepers to connect to changed (the
@@ -782,7 +781,7 @@ UpdateMemberSafekeeperPtr(WalProposer *wp, Safekeeper *sk)
{
SafekeeperId *sk_id = &wp->mconf.new_members.m[i];
if (sk_id->node_id == sk->greetResponse.nodeId)
if (wp->mconf.new_members.m[i].node_id == sk->greetResponse.nodeId)
{
if (wp->new_members_safekeepers[i] != NULL && wp->new_members_safekeepers[i] != sk)
{
@@ -1072,6 +1071,7 @@ RecvVoteResponse(Safekeeper *sk)
/* ready for elected message */
sk->state = SS_WAIT_ELECTED;
wp->n_votes++;
/* Are we already elected? */
if (wp->state == WPS_CAMPAIGN)
{

View File

@@ -845,6 +845,9 @@ typedef struct WalProposer
/* timeline globally starts at this LSN */
XLogRecPtr timelineStartLsn;
/* number of votes collected from safekeepers */
int n_votes;
/* number of successful connections over the lifetime of walproposer */
int n_connected;

View File

@@ -207,8 +207,15 @@ async fn authenticate(
ctx.set_project(db_info.aux.clone());
info!("woken up a compute node");
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());

View File

@@ -409,22 +409,14 @@ impl JwkCacheEntryLock {
if let Some(exp) = payload.expiration {
if now >= exp + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired));
}
}
if let Some(nbf) = payload.not_before {
if nbf >= now + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
JwtClaimsError::JwtTokenNotYetReadyToUse,
));
}
}
@@ -542,10 +534,10 @@ struct JwtPayload<'a> {
#[serde(rename = "aud", default)]
audience: OneOrMany,
/// Expiration - Time after which the JWT expires
#[serde(rename = "exp", deserialize_with = "numeric_date_opt", default)]
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
expiration: Option<SystemTime>,
/// Not before - Time before which the JWT is not valid
#[serde(rename = "nbf", deserialize_with = "numeric_date_opt", default)]
/// Not before - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
not_before: Option<SystemTime>,
// the following entries are only extracted for the sake of debug logging.
@@ -617,15 +609,8 @@ impl<'de> Deserialize<'de> for OneOrMany {
}
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
<Option<u64>>::deserialize(d)?
.map(|t| {
SystemTime::UNIX_EPOCH
.checked_add(Duration::from_secs(t))
.ok_or_else(|| {
serde::de::Error::custom(format_args!("timestamp out of bounds: {t}"))
})
})
.transpose()
let d = <Option<u64>>::deserialize(d)?;
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
}
struct JwkRenewalPermit<'a> {
@@ -761,11 +746,11 @@ pub enum JwtClaimsError {
#[error("invalid JWT token audience")]
InvalidJwtTokenAudience,
#[error("JWT token has expired (exp={0})")]
JwtTokenHasExpired(u64),
#[error("JWT token has expired")]
JwtTokenHasExpired,
#[error("JWT token is not yet ready to use (nbf={0})")]
JwtTokenNotYetReadyToUse(u64),
#[error("JWT token is not yet ready to use")]
JwtTokenNotYetReadyToUse,
}
#[allow(dead_code, reason = "Debug use only")]
@@ -1248,14 +1233,14 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
"nbf": now + 60,
"aud": "neon",
}},
error: JwtClaimsError::JwtTokenNotYetReadyToUse(now + 60),
error: JwtClaimsError::JwtTokenNotYetReadyToUse,
},
Test {
body: json! {{
"exp": now - 60,
"aud": ["neon"],
}},
error: JwtClaimsError::JwtTokenHasExpired(now - 60),
error: JwtClaimsError::JwtTokenHasExpired,
},
Test {
body: json! {{

View File

@@ -12,9 +12,9 @@ use tracing::{debug, warn};
use crate::auth::password_hack::parse_endpoint_param;
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::metrics::{Metrics, SniKind};
use crate::proxy::NeonOptions;
use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
use crate::serverless::SERVERLESS_DRIVER_SNI;
use crate::types::{EndpointId, RoleName};
#[derive(Debug, Error, PartialEq, Eq, Clone)]
@@ -32,6 +32,12 @@ pub(crate) enum ComputeUserInfoParseError {
option: EndpointId,
},
#[error(
"Common name inferred from SNI ('{}') is not known",
.cn,
)]
UnknownCommonName { cn: String },
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
MalformedProjectName(EndpointId),
}
@@ -60,15 +66,22 @@ impl ComputeUserInfoMaybeEndpoint {
}
}
pub(crate) fn endpoint_sni(sni: &str, common_names: &HashSet<String>) -> Option<EndpointId> {
let (subdomain, common_name) = sni.split_once('.')?;
pub(crate) fn endpoint_sni(
sni: &str,
common_names: &HashSet<String>,
) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {
let Some((subdomain, common_name)) = sni.split_once('.') else {
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
};
if !common_names.contains(common_name) {
return None;
return Err(ComputeUserInfoParseError::UnknownCommonName {
cn: common_name.into(),
});
}
if subdomain == SERVERLESS_DRIVER_SNI || subdomain == AUTH_BROKER_SNI {
return None;
if subdomain == SERVERLESS_DRIVER_SNI {
return Ok(None);
}
Some(EndpointId::from(subdomain))
Ok(Some(EndpointId::from(subdomain)))
}
impl ComputeUserInfoMaybeEndpoint {
@@ -100,8 +113,15 @@ impl ComputeUserInfoMaybeEndpoint {
})
.map(|name| name.into());
let endpoint_from_domain =
sni.and_then(|sni_str| common_names.and_then(|cn| endpoint_sni(sni_str, cn)));
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
endpoint_sni(sni_str, cn)?
} else {
None
}
} else {
None
};
let endpoint = match (endpoint_option, endpoint_from_domain) {
// Invariant: if we have both project name variants, they should match.
@@ -128,23 +148,22 @@ impl ComputeUserInfoMaybeEndpoint {
let metrics = Metrics::get();
debug!(%user, "credentials");
let protocol = ctx.protocol();
let kind = if sni.is_some() {
if sni.is_some() {
debug!("Connection with sni");
SniKind::Sni
metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
} else if endpoint.is_some() {
metrics
.proxy
.accepted_connections_by_sni
.inc(SniKind::NoSni);
debug!("Connection without sni");
SniKind::NoSni
} else {
metrics
.proxy
.accepted_connections_by_sni
.inc(SniKind::PasswordHack);
debug!("Connection with password hack");
SniKind::PasswordHack
};
metrics
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
let options = NeonOptions::parse_params(params);
@@ -405,34 +424,21 @@ mod tests {
}
#[test]
fn parse_unknown_sni() {
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestContext::test();
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.unwrap();
assert!(info.endpoint_id.is_none());
}
#[test]
fn parse_unknown_sni_with_options() {
let options = StartupMessageParams::new([
("user", "john_doe"),
("options", "endpoint=foo-bar-baz-1234"),
]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestContext::test();
let info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.unwrap();
assert_eq!(info.endpoint_id.as_deref(), Some("foo-bar-baz-1234"));
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
}
_ => panic!("bad error: {err:?}"),
}
}
#[test]

View File

@@ -423,8 +423,8 @@ async fn refresh_config_inner(
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
&tls_config.key_path,
&tls_config.cert_path,
None,
false,
)

View File

@@ -1,10 +1,8 @@
//! A stand-alone program that routes connections, e.g. from
//! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
//!
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside. Similar to an ingress controller for HTTPS.
use std::path::Path;
/// A stand-alone program that routes connections, e.g. from
/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
///
/// This allows connecting to pods/services running in the same Kubernetes cluster from
/// the outside. Similar to an ingress controller for HTTPS.
use std::{net::SocketAddr, sync::Arc};
use anyhow::{Context, anyhow, bail, ensure};
@@ -88,7 +86,46 @@ pub async fn run() -> anyhow::Result<()> {
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
(Some(key_path), Some(cert_path)) => parse_tls(key_path.as_ref(), cert_path.as_ref())?,
(Some(key_path), Some(cert_path)) => {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
let mut keys =
rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
PrivateKeyDer::Pkcs8(
keys.pop()
.expect("keys should not be empty")
.context(format!("Failed to read TLS keys at '{key_path}'"))?,
)
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain: Vec<_> = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
(tls_config, tls_server_end_point)
}
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -151,58 +188,7 @@ pub async fn run() -> anyhow::Result<()> {
match signal {}
}
pub(super) fn parse_tls(
key_path: &Path,
cert_path: &Path,
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
PrivateKeyDer::Pkcs8(
keys.pop()
.expect("keys should not be empty")
.context(format!(
"Failed to read TLS keys at '{}'",
key_path.display()
))?,
)
};
let cert_chain_bytes = std::fs::read(cert_path).context(format!(
"Failed to read TLS cert file at '{}.'",
cert_path.display()
))?;
let cert_chain: Vec<_> = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!(
"Failed to read TLS certificate chain from bytes from file at '{}'.",
cert_path.display()
)
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.context("ring should support TLS1.2 and TLS1.3")?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
Ok((tls_config, tls_server_end_point))
}
pub(super) async fn task_main(
async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
compute_tls_config: Option<Arc<rustls::ClientConfig>>,

View File

@@ -1,10 +1,9 @@
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{bail, ensure};
use anyhow::bail;
use arc_swap::ArcSwapOption;
use futures::future::Either;
use remote_storage::RemoteStorageConfig;
@@ -63,18 +62,18 @@ struct ProxyCliArgs {
region: String,
/// listen for incoming client connections on ip:port
#[clap(short, long, default_value = "127.0.0.1:4432")]
proxy: SocketAddr,
proxy: String,
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
auth_backend: AuthBackendType,
/// listen for management callback connection on ip:port
#[clap(short, long, default_value = "127.0.0.1:7000")]
mgmt: SocketAddr,
mgmt: String,
/// listen for incoming http connections (metrics, etc) on ip:port
#[clap(long, default_value = "127.0.0.1:7001")]
http: SocketAddr,
http: String,
/// listen for incoming wss connections on ip:port
#[clap(long)]
wss: Option<SocketAddr>,
wss: Option<String>,
/// redirect unauthenticated users to the given uri in case of console redirect auth
#[clap(short, long, default_value = "http://localhost:3000/psql_session/")]
uri: String,
@@ -100,18 +99,18 @@ struct ProxyCliArgs {
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
#[clap(short = 'k', long, alias = "ssl-key")]
tls_key: Option<PathBuf>,
tls_key: Option<String>,
/// path to TLS cert for client postgres connections
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
#[clap(short = 'c', long, alias = "ssl-cert")]
tls_cert: Option<PathBuf>,
tls_cert: Option<String>,
/// Allow writing TLS session keys to the given file pointed to by the environment variable `SSLKEYLOGFILE`.
#[clap(long, alias = "allow-ssl-keylogfile")]
allow_tls_keylogfile: bool,
/// path to directory with TLS certificates for client postgres connections
#[clap(long)]
certs_dir: Option<PathBuf>,
certs_dir: Option<String>,
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
@@ -230,9 +229,6 @@ struct ProxyCliArgs {
// TODO: rename to `console_redirect_confirmation_timeout`.
#[clap(long, default_value = "2m", value_parser = humantime::parse_duration)]
webauth_confirmation_timeout: std::time::Duration,
#[clap(flatten)]
pg_sni_router: PgSniRouterArgs,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -281,25 +277,6 @@ struct SqlOverHttpArgs {
sql_over_http_max_response_size_bytes: usize,
}
#[derive(clap::Args, Clone, Debug)]
struct PgSniRouterArgs {
/// listen for incoming client connections on ip:port
#[clap(id = "sni-router-listen", long, default_value = "127.0.0.1:4432")]
listen: SocketAddr,
/// listen for incoming client connections on ip:port, requiring TLS to compute
#[clap(id = "sni-router-listen-tls", long, default_value = "127.0.0.1:4433")]
listen_tls: SocketAddr,
/// path to TLS key for client postgres connections
#[clap(id = "sni-router-tls-key", long)]
tls_key: Option<PathBuf>,
/// path to TLS cert for client postgres connections
#[clap(id = "sni-router-tls-cert", long)]
tls_cert: Option<PathBuf>,
/// append this domain zone to the SNI hostname to get the destination address
#[clap(id = "sni-router-destination", long)]
dest: Option<String>,
}
pub async fn run() -> anyhow::Result<()> {
let _logging_guard = crate::logging::init().await?;
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
@@ -330,51 +307,73 @@ pub async fn run() -> anyhow::Result<()> {
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
}
info!("Using region: {}", args.aws_region);
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
None => {
bail!("plain auth requires redis_notifications to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.to_string(),
port,
elasticache::CredentialsProvider::new(
args.aws_region,
args.redis_cluster_name,
args.redis_user_id,
)
.await,
),
),
(None, None) => {
warn!(
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
);
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
}
};
let redis_notifications_client = if let Some(url) = args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url))
} else {
regional_redis_client.clone()
};
// Check that we can bind to address before further initialization
info!("Starting http on {}", args.http);
let http_listener = TcpListener::bind(args.http).await?.into_std()?;
let http_address: SocketAddr = args.http.parse()?;
info!("Starting http on {http_address}");
let http_listener = TcpListener::bind(http_address).await?.into_std()?;
info!("Starting mgmt on {}", args.mgmt);
let mgmt_listener = TcpListener::bind(args.mgmt).await?;
let mgmt_address: SocketAddr = args.mgmt.parse()?;
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_listener = if args.is_auth_broker {
None
} else {
info!("Starting proxy on {}", args.proxy);
Some(TcpListener::bind(args.proxy).await?)
};
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
let sni_router_listeners = {
let args = &args.pg_sni_router;
if args.dest.is_some() {
ensure!(
args.tls_key.is_some(),
"sni-router-tls-key must be provided"
);
ensure!(
args.tls_cert.is_some(),
"sni-router-tls-cert must be provided"
);
info!(
"Starting pg-sni-router on {} and {}",
args.listen, args.listen_tls
);
Some((
TcpListener::bind(args.listen).await?,
TcpListener::bind(args.listen_tls).await?,
))
} else {
None
}
Some(TcpListener::bind(proxy_address).await?)
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
@@ -459,37 +458,6 @@ pub async fn run() -> anyhow::Result<()> {
}
}
// spawn pg-sni-router mode.
if let Some((listen, listen_tls)) = sni_router_listeners {
let args = args.pg_sni_router;
let dest = args.dest.expect("already asserted it is set");
let key_path = args.tls_key.expect("already asserted it is set");
let cert_path = args.tls_cert.expect("already asserted it is set");
let (tls_config, tls_server_end_point) =
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let dest = Arc::new(dest);
client_tasks.spawn(super::pg_sni_router::task_main(
dest.clone(),
tls_config.clone(),
None,
tls_server_end_point,
listen,
cancellation_token.clone(),
));
client_tasks.spawn(super::pg_sni_router::task_main(
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
tls_server_end_point,
listen_tls,
cancellation_token.clone(),
));
}
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
@@ -597,7 +565,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
cert_path,
args.certs_dir.as_deref(),
args.certs_dir.as_ref(),
args.allow_tls_keylogfile,
)?),
(None, None) => None,
@@ -843,60 +811,6 @@ fn build_auth_backend(
}
}
async fn configure_redis(
args: &ProxyCliArgs,
) -> anyhow::Result<(
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
None => {
bail!("plain auth requires redis_notifications to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.to_string(),
port,
elasticache::CredentialsProvider::new(
args.aws_region.clone(),
args.redis_cluster_name.clone(),
args.redis_user_id.clone(),
)
.await,
),
),
(None, None) => {
// todo: upgrade to error?
warn!(
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
);
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
}
};
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
};
Ok((regional_redis_client, redis_notifications_client))
}
#[cfg(test)]
mod tests {
use std::time::Duration;

View File

@@ -115,8 +115,8 @@ pub struct ProxyMetrics {
#[metric(metadata = Thresholds::with_buckets([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0]))]
pub allowed_vpc_endpoint_ids: Histogram<10>,
/// Number of connections, by the method we used to determine the endpoint.
pub accepted_connections_by_sni: CounterVec<SniSet>,
/// Number of connections (per sni).
pub accepted_connections_by_sni: CounterVec<StaticLabelSet<SniKind>>,
/// Number of connection failures (per kind).
pub connection_failures_total: CounterVec<StaticLabelSet<ConnectionFailureKind>>,
@@ -342,20 +342,11 @@ pub enum LatencyExclusions {
ClientCplaneComputeRetry,
}
#[derive(LabelGroup)]
#[label(set = SniSet)]
pub struct SniGroup {
pub protocol: Protocol,
pub kind: SniKind,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "kind")]
pub enum SniKind {
/// Domain name based routing. SNI for libpq/websockets. Host for HTTP
Sni,
/// Metadata based routing. `options` for libpq/websockets. Header for HTTP
NoSni,
/// Metadata based routing, using the password field.
PasswordHack,
}

View File

@@ -24,6 +24,9 @@ pub(crate) enum HandshakeError {
#[error("protocol violation")]
ProtocolViolation,
#[error("missing certificate")]
MissingCertificate,
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
@@ -39,6 +42,10 @@ impl ReportableError for HandshakeError {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
HandshakeError::MissingCertificate => crate::error::ErrorKind::Service,
HandshakeError::StreamUpgradeError(upgrade) => match upgrade {
StreamUpgradeError::AlreadyTls => crate::error::ErrorKind::Service,
StreamUpgradeError::Io(_) => crate::error::ErrorKind::ClientDisconnect,
@@ -139,7 +146,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// try parse endpoint
let ep = conn_info
.server_name()
.and_then(|sni| endpoint_sni(sni, &tls.common_names));
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
if let Some(ep) = ep {
ctx.set_endpoint_id(ep);
}
@@ -154,8 +161,10 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(conn_info.server_name())
.ok_or(HandshakeError::MissingCertificate)?;
stream = PqStream {
framed: Framed {

View File

@@ -98,7 +98,8 @@ fn generate_tls_config<'a>(
.with_no_client_auth()
.with_single_cert(vec![cert.clone()], key.clone_key())?;
let cert_resolver = CertResolver::new(key, vec![cert])?;
let mut cert_resolver = CertResolver::new();
cert_resolver.add_cert(key, vec![cert], true)?;
let common_names = cert_resolver.get_common_names();

View File

@@ -41,7 +41,7 @@ use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
pub(crate) const EXT_VERSION: &str = "0.3.1";
pub(crate) const EXT_VERSION: &str = "0.3.0";
pub(crate) const EXT_SCHEMA: &str = "auth";
#[derive(Clone)]

View File

@@ -56,7 +56,6 @@ use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";
pub async fn task_main(
config: &'static ProxyConfig,

View File

@@ -38,7 +38,7 @@ use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::metrics::{HttpDirection, Metrics};
use crate::proxy::{NeonOptions, run_until_cancelled};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
@@ -199,7 +199,8 @@ fn get_conn_info(
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
endpoint_sni(hostname, &tls.common_names)?
.ok_or(ConnInfoError::MalformedEndpoint)?
} else {
hostname
.split_once('.')
@@ -227,32 +228,6 @@ fn get_conn_info(
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)

View File

@@ -1,12 +1,10 @@
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Arc;
use anyhow::{Context, bail};
use itertools::Itertools;
use rustls::crypto::ring::{self, sign};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::sign::CertifiedKey;
use x509_cert::der::{Reader, SliceReader};
use super::{PG_ALPN_PROTOCOL, TlsServerEndPoint};
@@ -22,13 +20,15 @@ pub struct TlsConfig {
/// Configure TLS for the main endpoint.
pub fn configure_tls(
key_path: &Path,
cert_path: &Path,
certs_dir: Option<&Path>,
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
allow_tls_keylogfile: bool,
) -> anyhow::Result<TlsConfig> {
let mut cert_resolver = CertResolver::new();
// add default certificate
let mut cert_resolver = CertResolver::parse_new(key_path, cert_path)?;
cert_resolver.add_cert_path(key_path, cert_path, true)?;
// add extra certificates
if let Some(certs_dir) = certs_dir {
@@ -40,7 +40,11 @@ pub fn configure_tls(
let key_path = path.join("tls.key");
let cert_path = path.join("tls.crt");
if key_path.exists() && cert_path.exists() {
cert_resolver.add_cert_path(&key_path, &cert_path)?;
cert_resolver.add_cert_path(
&key_path.to_string_lossy(),
&cert_path.to_string_lossy(),
false,
)?;
}
}
}
@@ -79,42 +83,92 @@ pub fn configure_tls(
})
}
#[derive(Debug)]
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
default: (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint),
default: Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
}
impl CertResolver {
fn parse_new(key_path: &Path, cert_path: &Path) -> anyhow::Result<Self> {
let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
Self::new(priv_key, cert_chain)
pub fn new() -> Self {
Self::default()
}
pub fn new(
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
) -> anyhow::Result<Self> {
let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
fn add_cert_path(
&mut self,
key_path: &str,
cert_path: &str,
is_default: bool,
) -> anyhow::Result<()> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
};
let mut certs = HashMap::new();
let default = (cert.clone(), tls_server_end_point);
certs.insert(common_name, (cert, tls_server_end_point));
Ok(Self { certs, default })
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
})?
};
self.add_cert(priv_key, cert_chain, is_default)
}
fn add_cert_path(&mut self, key_path: &Path, cert_path: &Path) -> anyhow::Result<()> {
let (priv_key, cert_chain) = parse_key_cert(key_path, cert_path)?;
self.add_cert(priv_key, cert_chain)
}
fn add_cert(
pub fn add_cert(
&mut self,
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
is_default: bool,
) -> anyhow::Result<()> {
let (common_name, cert, tls_server_end_point) = process_key_cert(priv_key, cert_chain)?;
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let certificate = SliceReader::new(first_cert)
.context("Failed to parse cerficiate")?
.decode::<x509_cert::Certificate>()
.context("Failed to parse cerficiate")?;
let common_name = certificate.tbs_certificate.subject.to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
if is_default {
self.default = Some((cert.clone(), tls_server_end_point));
}
self.certs.insert(common_name, (cert, tls_server_end_point));
Ok(())
}
@@ -123,85 +177,12 @@ impl CertResolver {
}
}
fn parse_key_cert(
key_path: &Path,
cert_path: &Path,
) -> anyhow::Result<(PrivateKeyDer<'static>, Vec<CertificateDer<'static>>)> {
let priv_key = {
let key_bytes = std::fs::read(key_path)
.with_context(|| format!("Failed to read TLS keys at '{}'", key_path.display()))?;
rustls_pemfile::private_key(&mut &key_bytes[..])
.with_context(|| format!("Failed to parse TLS keys at '{}'", key_path.display()))?
.with_context(|| format!("Failed to parse TLS keys at '{}'", key_path.display()))?
};
let cert_chain_bytes = std::fs::read(cert_path).context(format!(
"Failed to read TLS cert file at '{}.'",
cert_path.display()
))?;
let cert_chain = {
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.try_collect()
.with_context(|| {
format!(
"Failed to read TLS certificate chain from bytes from file at '{}'.",
cert_path.display()
)
})?
};
Ok((priv_key, cert_chain))
}
fn process_key_cert(
priv_key: PrivateKeyDer<'static>,
cert_chain: Vec<CertificateDer<'static>>,
) -> anyhow::Result<(String, Arc<CertifiedKey>, TlsServerEndPoint)> {
let key = sign::any_supported_type(&priv_key).context("invalid private key")?;
let first_cert = &cert_chain[0];
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let certificate = SliceReader::new(first_cert)
.context("Failed to parse cerficiate")?
.decode::<x509_cert::Certificate>()
.context("Failed to parse cerficiate")?;
let common_name = certificate.tbs_certificate.subject.to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Redirect proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
Ok((common_name, cert, tls_server_end_point))
}
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
Some(self.resolve(client_hello.server_name()).0)
self.resolve(client_hello.server_name()).map(|x| x.0)
}
}
@@ -209,7 +190,7 @@ impl CertResolver {
pub fn resolve(
&self,
server_name: Option<&str>,
) -> (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint) {
) -> Option<(Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)> {
// loop here and cut off more and more subdomains until we find
// a match to get a proper wildcard support. OTOH, we now do not
// use nested domains, so keep this simple for now.
@@ -219,17 +200,12 @@ impl CertResolver {
if let Some(mut sni_name) = server_name {
loop {
if let Some(cert) = self.certs.get(sni_name) {
return cert.clone();
return Some(cert.clone());
}
if let Some((_, rest)) = sni_name.split_once('.') {
sni_name = rest;
} else {
// The customer has some custom DNS mapping - just return
// a default certificate.
//
// This will error if the customer uses anything stronger
// than sslmode=require. That's a choice they can make.
return self.default.clone();
return None;
}
}
} else {

View File

@@ -121,20 +121,6 @@ impl Client {
resp.json().await.map_err(Error::ReceiveBody)
}
pub async fn switch_timeline_membership(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
req: &models::TimelineMembershipSwitchRequest,
) -> Result<models::TimelineMembershipSwitchResponse> {
let uri = format!(
"{}/v1/tenant/{}/timeline/{}/membership",
self.mgmt_api_endpoint, tenant_id, timeline_id
);
let resp = self.put(&uri, req).await?;
resp.json().await.map_err(Error::ReceiveBody)
}
pub async fn delete_tenant(&self, tenant_id: TenantId) -> Result<models::TenantDeleteResult> {
let uri = format!("{}/v1/tenant/{}", self.mgmt_api_endpoint, tenant_id);
let resp = self

View File

@@ -243,7 +243,8 @@ async fn timeline_pull_handler(mut request: Request<Body>) -> Result<Response<Bo
let resp =
pull_timeline::handle_request(data, conf.sk_auth_token.clone(), ca_certs, global_timelines)
.await?;
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, resp)
}

View File

@@ -7,7 +7,6 @@ use bytes::Bytes;
use camino::Utf8PathBuf;
use chrono::{DateTime, Utc};
use futures::{SinkExt, StreamExt, TryStreamExt};
use http_utils::error::ApiError;
use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo};
use reqwest::Certificate;
use safekeeper_api::Term;
@@ -31,7 +30,7 @@ use utils::pausable_failpoint;
use crate::control_file::CONTROL_FILE_NAME;
use crate::state::{EvictionState, TimelinePersistentState};
use crate::timeline::{Timeline, TimelineError, WalResidentTimeline};
use crate::timeline::{Timeline, WalResidentTimeline};
use crate::timelines_global_map::{create_temp_timeline_dir, validate_temp_timeline};
use crate::wal_storage::open_wal_file;
use crate::{GlobalTimelines, debug_dump, wal_backup};
@@ -396,7 +395,7 @@ pub async fn handle_request(
sk_auth_token: Option<SecretString>,
ssl_ca_certs: Vec<Certificate>,
global_timelines: Arc<GlobalTimelines>,
) -> Result<PullTimelineResponse, ApiError> {
) -> Result<PullTimelineResponse> {
let existing_tli = global_timelines.get(TenantTimelineId::new(
request.tenant_id,
request.timeline_id,
@@ -412,9 +411,7 @@ pub async fn handle_request(
for ssl_ca_cert in ssl_ca_certs {
http_client = http_client.add_root_certificate(ssl_ca_cert);
}
let http_client = http_client
.build()
.map_err(|e| ApiError::InternalServerError(e.into()))?;
let http_client = http_client.build()?;
let http_hosts = request.http_hosts.clone();
@@ -446,10 +443,10 @@ pub async fn handle_request(
// offline and C comes online. Then we want a pull on C with A and B as hosts to work.
let min_required_successful = (http_hosts.len() - 1).max(1);
if statuses.len() < min_required_successful {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
bail!(
"only got {} successful status responses. required: {min_required_successful}",
statuses.len()
)));
)
}
// Find the most advanced safekeeper
@@ -468,32 +465,14 @@ pub async fn handle_request(
assert!(status.tenant_id == request.tenant_id);
assert!(status.timeline_id == request.timeline_id);
let check_tombstone = !request.ignore_tombstone.unwrap_or_default();
match pull_timeline(
pull_timeline(
status,
safekeeper_host,
sk_auth_token,
http_client,
global_timelines,
check_tombstone,
)
.await
{
Ok(resp) => Ok(resp),
Err(e) => {
match e.downcast_ref::<TimelineError>() {
Some(TimelineError::AlreadyExists(_)) => Ok(PullTimelineResponse {
safekeeper_host: None,
}),
Some(TimelineError::CreationInProgress(_)) => {
// We don't return success here because creation might still fail.
Err(ApiError::Conflict("Creation in progress".to_owned()))
}
_ => Err(ApiError::InternalServerError(e)),
}
}
}
}
async fn pull_timeline(
@@ -502,7 +481,6 @@ async fn pull_timeline(
sk_auth_token: Option<SecretString>,
http_client: reqwest::Client,
global_timelines: Arc<GlobalTimelines>,
check_tombstone: bool,
) -> Result<PullTimelineResponse> {
let ttid = TenantTimelineId::new(status.tenant_id, status.timeline_id);
info!(
@@ -574,7 +552,7 @@ async fn pull_timeline(
// Finally, load the timeline.
let _tli = global_timelines
.load_temp_timeline(ttid, &tli_dir_path, check_tombstone)
.load_temp_timeline(ttid, &tli_dir_path, false)
.await?;
Ok(PullTimelineResponse {

View File

@@ -98,23 +98,6 @@ impl SafekeeperClient {
)
}
#[allow(unused)]
pub(crate) async fn switch_timeline_membership(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
req: &models::TimelineMembershipSwitchRequest,
) -> Result<models::TimelineMembershipSwitchResponse> {
measured_request!(
"switch_timeline_membership",
crate::metrics::Method::Put,
&self.node_id_label,
self.inner
.switch_timeline_membership(tenant_id, timeline_id, req)
.await
)
}
pub(crate) async fn delete_tenant(
&self,
tenant_id: TenantId,

View File

@@ -3886,10 +3886,10 @@ impl Service {
None
} else if safekeepers {
// Note that for imported timelines, we do not create the timeline on the safekeepers
// straight away. Instead, we do it once the import finalized such that we know what
// start LSN to provide for the safekeepers. This is done in
// [`Self::finalize_timeline_import`].
// Note that we do not support creating the timeline on the safekeepers
// for imported timelines. The `start_lsn` of the timeline is not known
// until the import finshes.
// https://github.com/neondatabase/neon/issues/11569
let res = self
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info)
.instrument(tracing::info_span!("timeline_create_safekeepers", %tenant_id, timeline_id=%timeline_info.timeline_id))
@@ -3966,22 +3966,11 @@ impl Service {
let active = self.timeline_active_on_all_shards(&import).await?;
match active {
Some(timeline_info) => {
true => {
tracing::info!("Timeline became active on all shards");
if self.config.timelines_onto_safekeepers {
// Now that we know the start LSN of this timeline, create it on the
// safekeepers.
self.tenant_timeline_create_safekeepers_until_success(
import.tenant_id,
timeline_info,
)
.await?;
}
break;
}
None => {
false => {
tracing::info!("Timeline not active on all shards yet");
tokio::select! {
@@ -4015,6 +4004,9 @@ impl Service {
.range_mut(TenantShardId::tenant_range(import.tenant_id))
.for_each(|(_id, shard)| shard.importing = TimelineImportState::Idle);
// TODO(vlad): Timeline creations in import mode do not return a correct initdb lsn,
// so we can't create the timeline on the safekeepers. Fix by moving creation here.
// https://github.com/neondatabase/neon/issues/11569
tracing::info!(%import_failed, "Timeline import complete");
Ok(())
@@ -4029,16 +4021,10 @@ impl Service {
.await;
}
/// If the timeline is active on all shards, returns the [`TimelineInfo`]
/// collected from shard 0.
///
/// An error is returned if the shard layout has changed during the import.
/// This is guarded against within the storage controller and the pageserver,
/// and, therefore, unexpected.
async fn timeline_active_on_all_shards(
self: &Arc<Self>,
import: &TimelineImport,
) -> anyhow::Result<Option<TimelineInfo>> {
) -> anyhow::Result<bool> {
let targets = {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
@@ -4062,17 +4048,13 @@ impl Service {
.expect("Pageservers may not be deleted while referenced");
targets.push((*tenant_shard_id, node.clone()));
} else {
return Ok(None);
return Ok(false);
}
}
targets
};
if targets.is_empty() {
anyhow::bail!("No shards found to finalize import for");
}
let results = self
.tenant_for_shards_api(
targets,
@@ -4088,17 +4070,10 @@ impl Service {
)
.await;
let all_active = results.iter().all(|res| match res {
Ok(results.into_iter().all(|res| match res {
Ok(info) => info.state == TimelineState::Active,
Err(_) => false,
});
if all_active {
// Both unwraps are validated above
Ok(Some(results.into_iter().next().unwrap().unwrap()))
} else {
Ok(None)
}
}))
}
pub(crate) async fn tenant_timeline_archival_config(
@@ -5206,8 +5181,7 @@ impl Service {
}
// We don't expect any new_shard_count shards to exist here, but drop them just in case
tenants
.retain(|id, s| !(id.tenant_id == *tenant_id && s.shard.count == *new_shard_count));
tenants.retain(|_id, s| s.shard.count != *new_shard_count);
detach_locations
};
@@ -8510,7 +8484,7 @@ impl Service {
// By default, live migrations are generous about the wait time for getting
// the secondary location up to speed. When draining, give up earlier in order
// to not stall the operation when a cold secondary is encountered.
const SECONDARY_WARMUP_TIMEOUT: Duration = Duration::from_secs(30);
const SECONDARY_WARMUP_TIMEOUT: Duration = Duration::from_secs(20);
const SECONDARY_DOWNLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal)
.secondary_warmup_timeout(SECONDARY_WARMUP_TIMEOUT)
@@ -8843,7 +8817,7 @@ impl Service {
node_id: NodeId,
cancel: CancellationToken,
) -> Result<(), OperationError> {
const SECONDARY_WARMUP_TIMEOUT: Duration = Duration::from_secs(30);
const SECONDARY_WARMUP_TIMEOUT: Duration = Duration::from_secs(20);
const SECONDARY_DOWNLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal)
.secondary_warmup_timeout(SECONDARY_WARMUP_TIMEOUT)

View File

@@ -1,9 +1,4 @@
use std::{
collections::HashMap,
str::FromStr,
sync::{Arc, atomic::AtomicU64},
time::Duration,
};
use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
use clashmap::{ClashMap, Entry};
use safekeeper_api::models::PullTimelineRequest;
@@ -174,17 +169,10 @@ pub(crate) struct ScheduleRequest {
pub(crate) kind: SafekeeperTimelineOpKind,
}
/// A way to keep ongoing/queued reconcile requests apart
#[derive(Copy, Clone, PartialEq, Eq)]
struct TokenId(u64);
type OngoingTokens = ClashMap<(TenantId, Option<TimelineId>), (CancellationToken, TokenId)>;
/// Handle to per safekeeper reconciler.
struct ReconcilerHandle {
tx: UnboundedSender<(ScheduleRequest, CancellationToken, TokenId)>,
ongoing_tokens: Arc<OngoingTokens>,
token_id_counter: AtomicU64,
tx: UnboundedSender<(ScheduleRequest, CancellationToken)>,
ongoing_tokens: Arc<ClashMap<(TenantId, Option<TimelineId>), CancellationToken>>,
cancel: CancellationToken,
}
@@ -197,28 +185,24 @@ impl ReconcilerHandle {
&self,
tenant_id: TenantId,
timeline_id: Option<TimelineId>,
) -> (CancellationToken, TokenId) {
let token_id = self
.token_id_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let token_id = TokenId(token_id);
) -> CancellationToken {
let entry = self.ongoing_tokens.entry((tenant_id, timeline_id));
if let Entry::Occupied(entry) = &entry {
let (cancel, _) = entry.get();
let cancel: &CancellationToken = entry.get();
cancel.cancel();
}
entry.insert((self.cancel.child_token(), token_id)).clone()
entry.insert(self.cancel.child_token()).clone()
}
/// Cancel an ongoing reconciliation
fn cancel_reconciliation(&self, tenant_id: TenantId, timeline_id: Option<TimelineId>) {
if let Some((_, (cancel, _id))) = self.ongoing_tokens.remove(&(tenant_id, timeline_id)) {
if let Some((_, cancel)) = self.ongoing_tokens.remove(&(tenant_id, timeline_id)) {
cancel.cancel();
}
}
fn schedule_reconcile(&self, req: ScheduleRequest) {
let (cancel, token_id) = self.new_token_slot(req.tenant_id, req.timeline_id);
let cancel = self.new_token_slot(req.tenant_id, req.timeline_id);
let hostname = req.safekeeper.skp.host.clone();
if let Err(err) = self.tx.send((req, cancel, token_id)) {
if let Err(err) = self.tx.send((req, cancel)) {
tracing::info!("scheduling request onto {hostname} returned error: {err}");
}
}
@@ -227,14 +211,13 @@ impl ReconcilerHandle {
pub(crate) struct SafekeeperReconciler {
inner: SafekeeperReconcilerInner,
concurrency_limiter: Arc<Semaphore>,
rx: UnboundedReceiver<(ScheduleRequest, CancellationToken, TokenId)>,
rx: UnboundedReceiver<(ScheduleRequest, CancellationToken)>,
cancel: CancellationToken,
}
/// Thin wrapper over `Service` to not clutter its inherent functions
#[derive(Clone)]
struct SafekeeperReconcilerInner {
ongoing_tokens: Arc<OngoingTokens>,
service: Arc<Service>,
}
@@ -243,20 +226,15 @@ impl SafekeeperReconciler {
// We hold the ServiceInner lock so we don't want to make sending to the reconciler channel to be blocking.
let (tx, rx) = mpsc::unbounded_channel();
let concurrency = service.config.safekeeper_reconciler_concurrency;
let ongoing_tokens = Arc::new(ClashMap::new());
let mut reconciler = SafekeeperReconciler {
inner: SafekeeperReconcilerInner {
service,
ongoing_tokens: ongoing_tokens.clone(),
},
inner: SafekeeperReconcilerInner { service },
rx,
concurrency_limiter: Arc::new(Semaphore::new(concurrency)),
cancel: cancel.clone(),
};
let handle = ReconcilerHandle {
tx,
ongoing_tokens,
token_id_counter: AtomicU64::new(0),
ongoing_tokens: Arc::new(ClashMap::new()),
cancel,
};
tokio::spawn(async move { reconciler.run().await });
@@ -268,9 +246,7 @@ impl SafekeeperReconciler {
req = self.rx.recv() => req,
_ = self.cancel.cancelled() => break,
};
let Some((req, req_cancel, req_token_id)) = req else {
break;
};
let Some((req, req_cancel)) = req else { break };
let permit_res = tokio::select! {
req = self.concurrency_limiter.clone().acquire_owned() => req,
@@ -289,7 +265,7 @@ impl SafekeeperReconciler {
let timeline_id = req.timeline_id;
let node_id = req.safekeeper.skp.id;
inner
.reconcile_one(req, req_cancel, req_token_id)
.reconcile_one(req, req_cancel)
.instrument(tracing::info_span!(
"reconcile_one",
?kind,
@@ -304,14 +280,8 @@ impl SafekeeperReconciler {
}
impl SafekeeperReconcilerInner {
async fn reconcile_one(
&self,
req: ScheduleRequest,
req_cancel: CancellationToken,
req_token_id: TokenId,
) {
async fn reconcile_one(&self, req: ScheduleRequest, req_cancel: CancellationToken) {
let req_host = req.safekeeper.skp.host.clone();
let success;
match req.kind {
SafekeeperTimelineOpKind::Pull => {
let Some(timeline_id) = req.timeline_id else {
@@ -331,24 +301,20 @@ impl SafekeeperReconcilerInner {
http_hosts,
tenant_id: req.tenant_id,
timeline_id,
ignore_tombstone: Some(false),
};
success = self
.reconcile_inner(
&req,
async |client| client.pull_timeline(&pull_req).await,
|resp| {
if let Some(host) = resp.safekeeper_host {
tracing::info!("pulled timeline from {host} onto {req_host}");
} else {
tracing::info!(
"timeline already present on safekeeper on {req_host}"
);
}
},
req_cancel,
)
.await;
self.reconcile_inner(
req,
async |client| client.pull_timeline(&pull_req).await,
|resp| {
if let Some(host) = resp.safekeeper_host {
tracing::info!("pulled timeline from {host} onto {req_host}");
} else {
tracing::info!("timeline already present on safekeeper on {req_host}");
}
},
req_cancel,
)
.await;
}
SafekeeperTimelineOpKind::Exclude => {
// TODO actually exclude instead of delete here
@@ -359,23 +325,22 @@ impl SafekeeperReconcilerInner {
);
return;
};
success = self
.reconcile_inner(
&req,
async |client| client.delete_timeline(tenant_id, timeline_id).await,
|_resp| {
tracing::info!("deleted timeline from {req_host}");
},
req_cancel,
)
.await;
self.reconcile_inner(
req,
async |client| client.delete_timeline(tenant_id, timeline_id).await,
|_resp| {
tracing::info!("deleted timeline from {req_host}");
},
req_cancel,
)
.await;
}
SafekeeperTimelineOpKind::Delete => {
let tenant_id = req.tenant_id;
if let Some(timeline_id) = req.timeline_id {
success = self
let deleted = self
.reconcile_inner(
&req,
req,
async |client| client.delete_timeline(tenant_id, timeline_id).await,
|_resp| {
tracing::info!("deleted timeline from {req_host}");
@@ -383,13 +348,13 @@ impl SafekeeperReconcilerInner {
req_cancel,
)
.await;
if success {
if deleted {
self.delete_timeline_from_db(tenant_id, timeline_id).await;
}
} else {
success = self
let deleted = self
.reconcile_inner(
&req,
req,
async |client| client.delete_tenant(tenant_id).await,
|_resp| {
tracing::info!(%tenant_id, "deleted tenant from {req_host}");
@@ -397,21 +362,12 @@ impl SafekeeperReconcilerInner {
req_cancel,
)
.await;
if success {
if deleted {
self.delete_tenant_timelines_from_db(tenant_id).await;
}
}
}
}
if success {
self.ongoing_tokens.remove_if(
&(req.tenant_id, req.timeline_id),
|_ttid, (_cancel, token_id)| {
// Ensure that this request is indeed the request we just finished and not a new one
req_token_id == *token_id
},
);
}
}
async fn delete_timeline_from_db(&self, tenant_id: TenantId, timeline_id: TimelineId) {
match self
@@ -465,10 +421,10 @@ impl SafekeeperReconcilerInner {
self.delete_timeline_from_db(tenant_id, timeline_id).await;
}
}
/// Returns whether the reconciliation happened successfully (or we got cancelled)
/// Returns whether the reconciliation happened successfully
async fn reconcile_inner<T, F, U>(
&self,
req: &ScheduleRequest,
req: ScheduleRequest,
closure: impl Fn(SafekeeperClient) -> F,
log_success: impl FnOnce(T) -> U,
req_cancel: CancellationToken,

View File

@@ -323,42 +323,6 @@ impl Service {
})
}
pub(crate) async fn tenant_timeline_create_safekeepers_until_success(
self: &Arc<Self>,
tenant_id: TenantId,
timeline_info: TimelineInfo,
) -> anyhow::Result<()> {
const BACKOFF: Duration = Duration::from_secs(5);
loop {
if self.cancel.is_cancelled() {
anyhow::bail!("Shut down requested while finalizing import");
}
let res = self
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info)
.await;
match res {
Ok(_) => {
tracing::info!("Timeline created on safekeepers");
break;
}
Err(err) => {
tracing::error!("Failed to create timeline on safekeepers: {err}");
tokio::select! {
_ = self.cancel.cancelled() => {
anyhow::bail!("Shut down requested while finalizing import");
},
_ = tokio::time::sleep(BACKOFF) => {}
};
}
}
}
Ok(())
}
/// Directly insert the timeline into the database without reconciling it with safekeepers.
///
/// Useful if the timeline already exists on the specified safekeepers,

View File

@@ -1,4 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::SystemTime;
use futures_util::StreamExt;
@@ -55,7 +56,7 @@ impl TimelineAnalysis {
pub(crate) async fn branch_cleanup_and_check_errors(
remote_client: &GenericRemoteStorage,
id: &TenantShardTimelineId,
tenant_objects: &mut TenantObjectListing,
tenant_objects: Arc<tokio::sync::Mutex<TenantObjectListing>>,
s3_active_branch: Option<&BranchData>,
console_branch: Option<BranchData>,
s3_data: Option<RemoteTimelineBlobData>,
@@ -150,7 +151,11 @@ pub(crate) async fn branch_cleanup_and_check_errors(
))
}
if !tenant_objects.check_ref(id.timeline_id, &layer, &metadata) {
if !tenant_objects
.lock()
.await
.check_ref(id.timeline_id, &layer, &metadata)
{
let path = remote_layer_path(
&id.tenant_shard_id.tenant_id,
&id.timeline_id,
@@ -165,17 +170,16 @@ pub(crate) async fn branch_cleanup_and_check_errors(
.head_object(&path, &CancellationToken::new())
.await;
if let Err(e) = response {
if response.is_err() {
// Object is not present.
let is_l0 = LayerMap::is_l0(layer.key_range(), layer.is_delta());
let msg = format!(
"index_part.json contains a layer {}{} (shard {}) that is not present in remote storage (layer_is_l0: {}) with error: {}",
"index_part.json contains a layer {}{} (shard {}) that is not present in remote storage (layer_is_l0: {})",
layer,
metadata.generation.get_suffix(),
metadata.shard,
is_l0,
e,
);
if is_l0 || ignore_error {
@@ -355,7 +359,6 @@ pub(crate) async fn list_timeline_blobs(
match res {
ListTimelineBlobsResult::Ready(data) => Ok(data),
ListTimelineBlobsResult::MissingIndexPart(_) => {
tracing::warn!("listing raced with removal of an index, retrying");
// Retry if listing raced with removal of an index
let data = list_timeline_blobs_impl(remote_client, id, root_target)
.await?
@@ -442,7 +445,7 @@ async fn list_timeline_blobs_impl(
}
if index_part_keys.is_empty() && s3_layers.is_empty() {
tracing::info!("Timeline is empty: expected post-deletion state.");
tracing::debug!("Timeline is empty: expected post-deletion state.");
if initdb_archive {
tracing::info!("Timeline is post deletion but initdb archive is still present.");
}

View File

@@ -73,8 +73,12 @@ enum Command {
node_kind: NodeKind,
#[arg(short, long, default_value_t = false)]
json: bool,
/// If provided, only these tenants will be listed from the remote storage.
#[arg(long = "tenant-id", num_args = 0..)]
tenant_ids: Vec<TenantShardId>,
/// If provided, we will list all tenants, but then filter with the prefix.
#[arg(long = "tenant-id-prefix")]
tenant_id_prefix: Option<TenantId>,
#[arg(long = "post", default_value_t = false)]
post_to_storcon: bool,
#[arg(long, default_value = None)]
@@ -178,6 +182,7 @@ async fn main() -> anyhow::Result<()> {
Command::ScanMetadata {
json,
tenant_ids,
tenant_id_prefix,
node_kind,
post_to_storcon,
dump_db_connstr,
@@ -186,6 +191,9 @@ async fn main() -> anyhow::Result<()> {
verbose,
} => {
if let NodeKind::Safekeeper = node_kind {
if tenant_id_prefix.is_some() {
bail!("`tenant_id_prefix` is not supported for safekeeper node_kind");
}
let db_or_list = match (timeline_lsns, dump_db_connstr) {
(Some(timeline_lsns), _) => {
let timeline_lsns = serde_json::from_str(&timeline_lsns)
@@ -227,6 +235,7 @@ async fn main() -> anyhow::Result<()> {
bucket_config,
controller_client.as_ref(),
tenant_ids,
tenant_id_prefix,
json,
post_to_storcon,
verbose,
@@ -338,6 +347,7 @@ pub async fn run_cron_job(
bucket_config,
controller_client,
Vec::new(),
None,
true,
post_to_storcon,
false, // default to non-verbose mode
@@ -384,10 +394,12 @@ pub async fn pageserver_physical_gc_cmd(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub async fn scan_pageserver_metadata_cmd(
bucket_config: BucketConfig,
controller_client: Option<&control_api::Client>,
tenant_shard_ids: Vec<TenantShardId>,
tenant_id_prefix: Option<TenantId>,
json: bool,
post_to_storcon: bool,
verbose: bool,
@@ -398,7 +410,14 @@ pub async fn scan_pageserver_metadata_cmd(
"Posting pageserver scan health status to storage controller requires `--controller-api` and `--controller-jwt` to run"
));
}
match scan_pageserver_metadata(bucket_config.clone(), tenant_shard_ids, verbose).await {
match scan_pageserver_metadata(
bucket_config.clone(),
tenant_shard_ids,
tenant_id_prefix,
verbose,
)
.await
{
Err(e) => {
tracing::error!("Failed: {e}");
Err(e)

View File

@@ -137,10 +137,11 @@ struct TenantRefAccumulator {
impl TenantRefAccumulator {
fn update(&mut self, ttid: TenantShardTimelineId, index_part: &IndexPart) {
let this_shard_idx = ttid.tenant_shard_id.to_index();
self.shards_seen
(*self
.shards_seen
.entry(ttid.tenant_shard_id.tenant_id)
.or_default()
.insert(this_shard_idx);
.or_default())
.insert(this_shard_idx);
let mut ancestor_refs = Vec::new();
for (layer_name, layer_metadata) in &index_part.layer_metadata {
@@ -593,7 +594,6 @@ async fn gc_timeline(
index_part_snapshot_time: _,
} => (index_part, *index_part_generation, data.unused_index_keys),
BlobDataParseResult::Relic => {
tracing::info!("Skipping timeline {ttid}, it is a relic");
// Post-deletion tenant location: don't try and GC it.
return Ok(summary);
}
@@ -767,13 +767,10 @@ pub async fn pageserver_physical_gc(
stream_tenant_timelines(remote_client_ref, target_ref, tenant_shard_id).await?,
);
Ok(try_stream! {
let mut cnt = 0;
while let Some(ttid_res) = timelines.next().await {
let ttid = ttid_res?;
cnt += 1;
yield (ttid, tenant_manifest_arc.clone());
}
tracing::info!(%tenant_shard_id, "Found {} timelines", cnt);
})
}
});
@@ -793,7 +790,6 @@ pub async fn pageserver_physical_gc(
&accumulator,
tenant_manifest_arc,
)
.instrument(info_span!("gc_timeline", %ttid))
});
let timelines = timelines.try_buffered(CONCURRENCY);
let mut timelines = std::pin::pin!(timelines);

View File

@@ -1,5 +1,7 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use futures::SinkExt;
use futures_util::{StreamExt, TryStreamExt};
use pageserver::tenant::remote_timeline_client::remote_layer_path;
use pageserver_api::controller_api::MetadataHealthUpdateRequest;
@@ -7,6 +9,7 @@ use pageserver_api::shard::TenantShardId;
use remote_storage::GenericRemoteStorage;
use serde::Serialize;
use tracing::{Instrument, info_span};
use utils::generation::Generation;
use utils::id::TenantId;
use utils::shard::ShardCount;
@@ -14,10 +17,12 @@ use crate::checks::{
BlobDataParseResult, RemoteTimelineBlobData, TenantObjectListing, TimelineAnalysis,
branch_cleanup_and_check_errors, list_timeline_blobs,
};
use crate::metadata_stream::{stream_tenant_timelines, stream_tenants};
use crate::metadata_stream::{
stream_tenant_timelines, stream_tenants, stream_tenants_maybe_prefix,
};
use crate::{BucketConfig, NodeKind, RootTarget, TenantShardTimelineId, init_remote};
#[derive(Serialize, Default)]
#[derive(Serialize, Default, Clone)]
pub struct MetadataSummary {
tenant_count: usize,
timeline_count: usize,
@@ -102,13 +107,13 @@ impl MetadataSummary {
format!(
"Tenants: {}
Timelines: {}
Timeline-shards: {}
With errors: {}
With warnings: {}
With orphan layers: {}
Index versions: {version_summary}
",
Timelines: {}
Timeline-shards: {}
With errors: {}
With warnings: {}
With orphan layers: {}
Index versions: {version_summary}
",
self.tenant_count,
self.timeline_count,
self.timeline_shard_count,
@@ -138,27 +143,243 @@ Index versions: {version_summary}
pub async fn scan_pageserver_metadata(
bucket_config: BucketConfig,
tenant_ids: Vec<TenantShardId>,
tenant_id_prefix: Option<TenantId>,
verbose: bool,
) -> anyhow::Result<MetadataSummary> {
let (remote_client, target) = init_remote(bucket_config, NodeKind::Pageserver).await?;
let tenants = if tenant_ids.is_empty() {
futures::future::Either::Left(stream_tenants(&remote_client, &target))
} else {
futures::future::Either::Right(futures::stream::iter(tenant_ids.into_iter().map(Ok)))
};
if !tenant_ids.is_empty() && tenant_id_prefix.is_some() {
anyhow::bail!("`tenant_id_prefix` is not supported when `tenant_ids` is provided");
}
// How many tenants to process in parallel. We need to be mindful of pageservers
// accessing the same per tenant prefixes, so use a lower setting than pageservers.
const CONCURRENCY: usize = 32;
// Generate a stream of TenantTimelineId
let timelines = tenants.map_ok(|t| {
tracing::info!("Found tenant: {}", t);
stream_tenant_timelines(&remote_client, &target, t)
let (mut list_tenants_tx, list_tenants_rx) = futures::channel::mpsc::channel(1);
let remote_client_inner = remote_client.clone();
let target_inner = target.clone();
let list_tenants = tokio::spawn(async move {
let mut cnt = 0;
if tenant_ids.is_empty() {
if let Some(tenant_id_prefix) = tenant_id_prefix {
let stream = stream_tenants_maybe_prefix(
&remote_client_inner,
&target_inner,
Some(tenant_id_prefix.to_string()),
);
let mut stream = Box::pin(stream);
while let Some(tenant) = stream.next().await {
let tenant = tenant?;
list_tenants_tx.send(tenant).await?;
cnt += 1;
}
} else {
let stream = stream_tenants(&remote_client_inner, &target_inner);
let mut stream = Box::pin(stream);
while let Some(tenant) = stream.next().await {
let tenant = tenant?;
list_tenants_tx.send(tenant).await?;
cnt += 1;
}
}
} else {
for tenant_id in tenant_ids {
list_tenants_tx.send(tenant_id).await?;
cnt += 1;
}
}
tracing::info!("list_tenants: collected {} tenants", cnt);
Ok::<_, anyhow::Error>(())
});
let (mut list_timelines_tx, list_timelines_rx) = futures::channel::mpsc::channel(1);
let remote_client_inner = remote_client.clone();
let target_inner = target.clone();
let list_timelines = tokio::spawn(async move {
let stream = list_tenants_rx
.map(|tenant_id| {
stream_tenant_timelines(&remote_client_inner, &target_inner, tenant_id)
})
.buffered(8)
.try_flatten();
let mut stream = Box::pin(stream);
while let Some(item) = stream.next().await {
let item = item?;
list_timelines_tx.send(item).await?;
}
Ok::<_, anyhow::Error>(())
});
let (mut read_timelines_tx, read_timelines_rx) = futures::channel::mpsc::channel(1);
let remote_client_inner = remote_client.clone();
let target_inner = target.clone();
let read_timelines = tokio::spawn(async move {
let stream = list_timelines_rx
.map(|ttid| report_on_timeline(&remote_client_inner, &target_inner, ttid))
.buffered(32);
let mut stream = Box::pin(stream);
while let Some(item) = stream.next().await {
let item = item?;
read_timelines_tx.send(item).await?;
}
Ok::<_, anyhow::Error>(())
});
let summary = Arc::new(tokio::sync::Mutex::new(MetadataSummary::new()));
let summary_inner = summary.clone();
let (mut consolidate_tenants_tx, consolidate_tenants_rx) = futures::channel::mpsc::channel(32);
let consolidate_tenants = tokio::spawn(async move {
// We must gather all the TenantShardTimelineId->S3TimelineBlobData for each tenant, because different
// shards in the same tenant might refer to one anothers' keys if a shard split has happened.
let mut tenant_id = None;
let mut tenant_objects = TenantObjectListing::default();
let mut tenant_timeline_results = Vec::new();
// Iterate through all the timeline results. These are in key-order, so
// all results for the same tenant will be adjacent. We accumulate these,
// and then call `analyze_tenant` to flush, when we see the next tenant ID.
let mut highest_shard_count = ShardCount::MIN;
let mut read_timelines_rx = read_timelines_rx;
while let Some(i) = read_timelines_rx.next().await {
let (ttid, data) = i;
{
let mut guard = summary_inner.lock().await;
guard.update_data(&data);
}
match tenant_id {
Some(prev_tenant_id) => {
if prev_tenant_id != ttid.tenant_shard_id.tenant_id {
// New tenant: analyze this tenant's timelines, clear accumulated tenant_timeline_results
let tenant_objects = std::mem::take(&mut tenant_objects);
let timelines = std::mem::take(&mut tenant_timeline_results);
analyze_tenant(
summary_inner.clone(),
Arc::new(tokio::sync::Mutex::new(tenant_objects)),
timelines,
highest_shard_count,
&mut consolidate_tenants_tx,
)
.await?;
tenant_id = Some(ttid.tenant_shard_id.tenant_id);
highest_shard_count = ttid.tenant_shard_id.shard_count;
} else {
highest_shard_count =
highest_shard_count.max(ttid.tenant_shard_id.shard_count);
}
}
None => {
tenant_id = Some(ttid.tenant_shard_id.tenant_id);
highest_shard_count = highest_shard_count.max(ttid.tenant_shard_id.shard_count);
}
}
match &data.blob_data {
BlobDataParseResult::Parsed {
index_part: _,
index_part_generation: _index_part_generation,
s3_layers,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} => {
tenant_objects.push(ttid, s3_layers.clone());
}
BlobDataParseResult::Relic => (),
BlobDataParseResult::Incorrect {
errors: _,
s3_layers,
} => {
tenant_objects.push(ttid, s3_layers.clone());
}
}
tenant_timeline_results.push((ttid, data));
}
if !tenant_timeline_results.is_empty() {
analyze_tenant(
summary_inner.clone(),
Arc::new(tokio::sync::Mutex::new(tenant_objects)),
tenant_timeline_results,
highest_shard_count,
&mut consolidate_tenants_tx,
)
.await?;
}
Ok::<_, anyhow::Error>(())
});
let remote_client_inner = remote_client.clone();
let summary_inner = summary.clone();
let analyze_tenants = tokio::spawn(async move {
let stream = consolidate_tenants_rx
.map(|(ttid, tenant_objects, data)| {
let remote_client_inner = remote_client_inner.clone();
async move {
let generation = if let BlobDataParseResult::Parsed {
index_part: _,
index_part_generation,
s3_layers: _,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} = &data.blob_data
{
Some(*index_part_generation)
} else {
None
};
let res = branch_cleanup_and_check_errors(
&remote_client_inner,
&ttid,
tenant_objects.clone(),
None,
None,
Some(data),
)
.await;
(ttid, tenant_objects.clone(), generation, res)
}
})
.buffered(32);
let mut last_tenant = None;
let mut last_tenant_objects = None;
let mut timeline_generations = HashMap::new();
let mut stream = Box::pin(stream);
while let Some((ttid, tenant_objects, generation, res)) = stream.next().await {
if last_tenant != Some(ttid) {
if let Some(tenant_id) = last_tenant {
let timeline_generations = std::mem::take(&mut timeline_generations);
identify_orphans(
tenant_id.tenant_shard_id.tenant_id,
last_tenant_objects.take().unwrap(),
summary_inner.clone(),
&timeline_generations,
)
.await;
}
last_tenant = Some(ttid);
last_tenant_objects = Some(tenant_objects);
}
if let Some(generation) = generation {
timeline_generations.insert(ttid, generation);
}
{
let mut guard = summary_inner.lock().await;
guard.update_analysis(&ttid, &res, verbose);
}
}
if let Some(tenant_id) = last_tenant {
identify_orphans(
tenant_id.tenant_shard_id.tenant_id,
last_tenant_objects.take().unwrap(),
summary_inner.clone(),
&timeline_generations,
)
.await;
}
Ok::<_, anyhow::Error>(())
});
let timelines = timelines.try_buffered(CONCURRENCY);
let timelines = timelines.try_flatten();
// Generate a stream of S3TimelineBlobData
async fn report_on_timeline(
@@ -166,93 +387,94 @@ pub async fn scan_pageserver_metadata(
target: &RootTarget,
ttid: TenantShardTimelineId,
) -> anyhow::Result<(TenantShardTimelineId, RemoteTimelineBlobData)> {
tracing::info!("listing blobs for timeline: {}", ttid);
let data = list_timeline_blobs(remote_client, ttid, target).await?;
Ok((ttid, data))
}
let timelines = timelines.map_ok(|ttid| report_on_timeline(&remote_client, &target, ttid));
let mut timelines = std::pin::pin!(timelines.try_buffered(CONCURRENCY));
// We must gather all the TenantShardTimelineId->S3TimelineBlobData for each tenant, because different
// shards in the same tenant might refer to one anothers' keys if a shard split has happened.
let mut tenant_id = None;
let mut tenant_objects = TenantObjectListing::default();
let mut tenant_timeline_results = Vec::new();
// DO NOT call any long-running tasks in this function; always route them through the channel and let
// other tokio tasks handle them.
async fn analyze_tenant(
remote_client: &GenericRemoteStorage,
tenant_id: TenantId,
summary: &mut MetadataSummary,
mut tenant_objects: TenantObjectListing,
summary: Arc<tokio::sync::Mutex<MetadataSummary>>,
tenant_objects: Arc<tokio::sync::Mutex<TenantObjectListing>>,
timelines: Vec<(TenantShardTimelineId, RemoteTimelineBlobData)>,
highest_shard_count: ShardCount,
verbose: bool,
) {
summary.tenant_count += 1;
let mut timeline_ids = HashSet::new();
let mut timeline_generations = HashMap::new();
for (ttid, data) in timelines {
async {
if ttid.tenant_shard_id.shard_count == highest_shard_count {
// Only analyze `TenantShardId`s with highest shard count.
// Stash the generation of each timeline, for later use identifying orphan layers
if let BlobDataParseResult::Parsed {
index_part,
index_part_generation,
s3_layers: _,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} = &data.blob_data
{
if index_part.deleted_at.is_some() {
// skip deleted timeline.
tracing::info!(
"Skip analysis of {} b/c timeline is already deleted",
ttid
);
return;
}
timeline_generations.insert(ttid, *index_part_generation);
}
// Apply checks to this timeline shard's metadata, and in the process update `tenant_objects`
// reference counts for layers across the tenant.
let analysis = branch_cleanup_and_check_errors(
remote_client,
&ttid,
&mut tenant_objects,
None,
None,
Some(data),
)
.await;
summary.update_analysis(&ttid, &analysis, verbose);
timeline_ids.insert(ttid.timeline_id);
} else {
tracing::info!(
"Skip analysis of {} b/c a lower shard count than {}",
ttid,
highest_shard_count.0,
);
}
}
.instrument(
info_span!("analyze-timeline", shard = %ttid.tenant_shard_id.shard_slug(), timeline = %ttid.timeline_id),
)
.await
output_tx: &mut futures::channel::mpsc::Sender<(
TenantShardTimelineId,
Arc<tokio::sync::Mutex<TenantObjectListing>>,
RemoteTimelineBlobData,
)>,
) -> anyhow::Result<()> {
{
let mut guard = summary.lock().await;
guard.tenant_count += 1;
}
summary.timeline_count += timeline_ids.len();
let mut timeline_ids = HashSet::new();
for (ttid, data) in timelines {
async {
if ttid.tenant_shard_id.shard_count == highest_shard_count {
// Only analyze `TenantShardId`s with highest shard count.
// Stash the generation of each timeline, for later use identifying orphan layers
if let BlobDataParseResult::Parsed {
index_part,
index_part_generation: _,
s3_layers: _,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} = &data.blob_data
{
if index_part.deleted_at.is_some() {
// skip deleted timeline.
tracing::info!("Skip analysis of {} b/c timeline is already deleted", ttid);
return Ok(());
}
}
// Apply checks to this timeline shard's metadata, and in the process update `tenant_objects`
// reference counts for layers across the tenant.
output_tx.send((ttid, tenant_objects.clone(), data)).await?;
timeline_ids.insert(ttid.timeline_id);
} else {
tracing::info!(
"Skip analysis of {} b/c a lower shard count than {}",
ttid,
highest_shard_count.0,
);
}
Ok::<_, anyhow::Error>(())
}.instrument(
info_span!("analyze-timeline", shard = %ttid.tenant_shard_id.shard_slug(), timeline = %ttid.timeline_id),
)
.await?;
}
{
let mut guard = summary.lock().await;
guard.timeline_count += timeline_ids.len();
}
Ok(())
}
async fn identify_orphans(
tenant_id: TenantId,
tenant_objects: Arc<tokio::sync::Mutex<TenantObjectListing>>,
summary: Arc<tokio::sync::Mutex<MetadataSummary>>,
timeline_generations: &HashMap<TenantShardTimelineId, Generation>,
) {
// Identifying orphan layers must be done on a tenant-wide basis, because individual
// shards' layers may be referenced by other shards.
//
// Orphan layers are not a corruption, and not an indication of a problem. They are just
// consuming some space in remote storage, and may be cleaned up at leisure.
for (shard_index, timeline_id, layer_file, generation) in tenant_objects.get_orphans() {
let orphans = { tenant_objects.lock().await.get_orphans() };
for (shard_index, timeline_id, layer_file, generation) in orphans {
let ttid = TenantShardTimelineId {
tenant_shard_id: TenantShardId {
tenant_id,
@@ -282,83 +504,20 @@ pub async fn scan_pageserver_metadata(
tracing::info!("Orphan layer detected: {orphan_path}");
summary.notify_timeline_orphan(&ttid);
{
let mut guard = summary.lock().await;
guard.notify_timeline_orphan(&ttid);
}
}
}
// Iterate through all the timeline results. These are in key-order, so
// all results for the same tenant will be adjacent. We accumulate these,
// and then call `analyze_tenant` to flush, when we see the next tenant ID.
let mut summary = MetadataSummary::new();
let mut highest_shard_count = ShardCount::MIN;
while let Some(i) = timelines.next().await {
let (ttid, data) = i?;
summary.update_data(&data);
// TODO: bail out early if any of the tasks fail
list_tenants.await??;
list_timelines.await??;
read_timelines.await??;
consolidate_tenants.await??;
analyze_tenants.await??;
match tenant_id {
Some(prev_tenant_id) => {
if prev_tenant_id != ttid.tenant_shard_id.tenant_id {
// New tenant: analyze this tenant's timelines, clear accumulated tenant_timeline_results
let tenant_objects = std::mem::take(&mut tenant_objects);
let timelines = std::mem::take(&mut tenant_timeline_results);
analyze_tenant(
&remote_client,
prev_tenant_id,
&mut summary,
tenant_objects,
timelines,
highest_shard_count,
verbose,
)
.instrument(info_span!("analyze-tenant", tenant = %prev_tenant_id))
.await;
tenant_id = Some(ttid.tenant_shard_id.tenant_id);
highest_shard_count = ttid.tenant_shard_id.shard_count;
} else {
highest_shard_count = highest_shard_count.max(ttid.tenant_shard_id.shard_count);
}
}
None => {
tenant_id = Some(ttid.tenant_shard_id.tenant_id);
highest_shard_count = highest_shard_count.max(ttid.tenant_shard_id.shard_count);
}
}
match &data.blob_data {
BlobDataParseResult::Parsed {
index_part: _,
index_part_generation: _index_part_generation,
s3_layers,
index_part_last_modified_time: _,
index_part_snapshot_time: _,
} => {
tenant_objects.push(ttid, s3_layers.clone());
}
BlobDataParseResult::Relic => (),
BlobDataParseResult::Incorrect {
errors: _,
s3_layers,
} => {
tenant_objects.push(ttid, s3_layers.clone());
}
}
tenant_timeline_results.push((ttid, data));
}
if !tenant_timeline_results.is_empty() {
let tenant_id = tenant_id.expect("Must be set if results are present");
analyze_tenant(
&remote_client,
tenant_id,
&mut summary,
tenant_objects,
tenant_timeline_results,
highest_shard_count,
verbose,
)
.instrument(info_span!("analyze-tenant", tenant = %tenant_id))
.await;
}
Ok(summary)
let summary = summary.lock().await;
Ok(summary.clone())
}

View File

@@ -24,6 +24,7 @@ pub struct SnapshotDownloader {
remote_client: GenericRemoteStorage,
#[allow(dead_code)]
target: RootTarget,
bucket_config: BucketConfig,
tenant_id: TenantId,
output_path: Utf8PathBuf,
concurrency: usize,
@@ -42,6 +43,7 @@ impl SnapshotDownloader {
Ok(Self {
remote_client,
target,
bucket_config,
tenant_id,
output_path,
concurrency,
@@ -216,9 +218,11 @@ impl SnapshotDownloader {
}
pub async fn download(&self) -> anyhow::Result<()> {
let (remote_client, target) =
init_remote(self.bucket_config.clone(), NodeKind::Pageserver).await?;
// Generate a stream of TenantShardId
let shards =
stream_tenant_shards(&self.remote_client, &self.target, self.tenant_id).await?;
let shards = stream_tenant_shards(&remote_client, &target, self.tenant_id).await?;
let shards: Vec<TenantShardId> = shards.try_collect().await?;
// Only read from shards that have the highest count: avoids redundantly downloading
@@ -236,8 +240,7 @@ impl SnapshotDownloader {
for shard in shards.into_iter().filter(|s| s.shard_count == shard_count) {
// Generate a stream of TenantTimelineId
let timelines =
stream_tenant_timelines(&self.remote_client, &self.target, shard).await?;
let timelines = stream_tenant_timelines(&remote_client, &target, shard).await?;
// Generate a stream of S3TimelineBlobData
async fn load_timeline_index(
@@ -248,8 +251,8 @@ impl SnapshotDownloader {
let data = list_timeline_blobs(remote_client, ttid, target).await?;
Ok((ttid, data))
}
let timelines = timelines
.map_ok(|ttid| load_timeline_index(&self.remote_client, &self.target, ttid));
let timelines =
timelines.map_ok(|ttid| load_timeline_index(&remote_client, &target, ttid));
let mut timelines = std::pin::pin!(timelines.try_buffered(8));
while let Some(i) = timelines.next().await {

Some files were not shown because too many files have changed in this diff Show More