Compare commits

..

4 Commits

Author SHA1 Message Date
Tristan Partin
d7d3fb332f Remove notion of ParsedSpec
Signed-off-by: Tristan Partin <tristan@neon.tech>
2025-06-09 11:33:46 -05:00
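A condensed sketch of the shape of this change, pieced together from the compute.rs and spec.rs hunks below (fields abbreviated, not the full definitions):

// Before: the raw ComputeSpec was validated into a separate ParsedSpec
// wrapper, which re-derived storage fields from GUCs for backwards
// compatibility and sat next to the spec in ComputeState.
pub struct ParsedSpec {
    pub spec: ComputeSpec,
    pub tenant_id: TenantId,
    pub timeline_id: TimelineId,
    pub pageserver_connstr: String,
    pub safekeeper_connstrings: Vec<String>,
    // ...
}

// After: the same information lives directly on ComputeSpec as required,
// strongly typed fields, so callers read `state.spec` instead of going
// through `state.pspec.as_ref().unwrap().spec`.
pub struct ComputeSpec {
    pub tenant_id: TenantId,
    pub timeline_id: TimelineId,
    pub pageservers: Vec<Pageserver>,
    pub safekeepers: Vec<Safekeeper>,
    // ...
}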
Tristan Partin
c37ce9b69c Clean up implementation of ComputeNode::has_feature()
Option::is_some_and() is perfect for what this function does.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2025-06-09 11:03:51 -05:00
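For reference, Option::is_some_and() returns false for None and otherwise applies the predicate to the contained value — exactly the `if let Some(..) { .. } else { false }` pattern being replaced in the compute.rs hunk below. A minimal standalone illustration (hypothetical values, not from the codebase):

fn main() {
    let enabled: Option<Vec<&str>> = Some(vec!["prewarm"]);
    let empty: Option<Vec<&str>> = None;

    // Equivalent to: if let Some(f) = enabled.as_ref() { f.contains(&"prewarm") } else { false }
    assert!(enabled.as_ref().is_some_and(|f| f.contains(&"prewarm")));
    assert!(!empty.as_ref().is_some_and(|f| f.contains(&"prewarm")));
}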
Tristan Partin
681edf3983 Mark compute_ctl::main as #[tokio::main]
Historically, this was not the case, but we use async code extensively
within compute_ctl, so we might as well make it easy for people to add
more async code in the future.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2025-06-09 11:03:51 -05:00
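Condensed from the main.rs hunk below, the before/after shape of this change; #[tokio::main] generates essentially the same multi-threaded runtime setup that was previously written out by hand:

// Before: a synchronous main() owning a hand-built runtime.
fn main() -> Result<()> {
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?;
    let _rt_guard = runtime.enter();
    runtime.block_on(init())?;
    // ... synchronous control flow, with block_on() at each async call site
    Ok(())
}

// After: an async main(), so async calls can simply be awaited.
#[tokio::main]
async fn main() -> Result<()> {
    init().await?;
    // ... async control flow
    Ok(())
}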
Tristan Partin
ab898e40b0 Move get_config() to a method of Cli
We were already using it as a method in all but name; every input to the
function was derived from the CLI arguments anyway.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2025-06-09 11:03:51 -05:00
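The shape of the refactor, condensed from the main.rs hunk below: the free function's only parameter was &Cli, so it becomes an inherent method and the call site reads as method syntax:

// Before: a free function whose every input came from the parsed CLI.
fn get_config(cli: &Cli) -> Result<ComputeConfig> { /* ... */ }
let config = get_config(&cli)?;

// After: the same logic as a method on Cli.
impl Cli {
    pub fn get_config(&self) -> Result<ComputeConfig> { /* ... */ }
}
let config = cli.get_config()?;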
101 changed files with 1181 additions and 2125 deletions

47
Cargo.lock generated
View File

@@ -753,7 +753,6 @@ dependencies = [
"axum",
"axum-core",
"bytes",
"form_urlencoded",
"futures-util",
"headers",
"http 1.1.0",
@@ -762,8 +761,6 @@ dependencies = [
"mime",
"pin-project-lite",
"serde",
"serde_html_form",
"serde_path_to_error",
"tower 0.5.2",
"tower-layer",
"tower-service",
@@ -903,6 +900,12 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5"
[[package]]
name = "base64"
version = "0.21.7"
@@ -1275,10 +1278,13 @@ dependencies = [
"chrono",
"indexmap 2.9.0",
"jsonwebtoken",
"postgres",
"regex",
"remote_storage",
"serde",
"serde_json",
"tokio-postgres",
"url",
"utils",
]
@@ -1294,7 +1300,7 @@ dependencies = [
"aws-smithy-types",
"axum",
"axum-extra",
"base64 0.22.1",
"base64 0.13.1",
"bytes",
"camino",
"cfg-if",
@@ -1420,7 +1426,7 @@ name = "control_plane"
version = "0.1.0"
dependencies = [
"anyhow",
"base64 0.22.1",
"base64 0.13.1",
"camino",
"clap",
"comfy-table",
@@ -4812,7 +4818,7 @@ dependencies = [
name = "postgres-protocol2"
version = "0.1.0"
dependencies = [
"base64 0.22.1",
"base64 0.20.0",
"byteorder",
"bytes",
"fallible-iterator",
@@ -5184,7 +5190,7 @@ dependencies = [
"aws-config",
"aws-sdk-iam",
"aws-sigv4",
"base64 0.22.1",
"base64 0.13.1",
"bstr",
"bytes",
"camino",
@@ -6419,19 +6425,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "serde_html_form"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4"
dependencies = [
"form_urlencoded",
"indexmap 2.9.0",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_json"
version = "1.0.125"
@@ -6488,17 +6481,15 @@ dependencies = [
[[package]]
name = "serde_with"
version = "3.12.0"
version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa"
checksum = "07ff71d2c147a7b57362cead5e22f772cd52f6ab31cfcd9edcd7f6aeb2a0afbe"
dependencies = [
"base64 0.22.1",
"base64 0.13.1",
"chrono",
"hex",
"indexmap 1.9.3",
"indexmap 2.9.0",
"serde",
"serde_derive",
"serde_json",
"serde_with_macros",
"time",
@@ -6506,9 +6497,9 @@ dependencies = [
[[package]]
name = "serde_with_macros"
version = "3.12.0"
version = "2.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e"
checksum = "881b6f881b17d13214e5d494c939ebab463d01264ce1811e9d4ac3a882e7695f"
dependencies = [
"darling",
"proc-macro2",
@@ -6759,6 +6750,7 @@ dependencies = [
"chrono",
"clap",
"clashmap",
"compute_api",
"control_plane",
"cron",
"diesel",
@@ -8579,6 +8571,7 @@ dependencies = [
"anyhow",
"axum",
"axum-core",
"base64 0.13.1",
"base64 0.21.7",
"base64ct",
"bytes",

View File

@@ -71,8 +71,8 @@ aws-credential-types = "1.2.0"
aws-sigv4 = { version = "1.2", features = ["sign-http"] }
aws-types = "1.3"
axum = { version = "0.8.1", features = ["ws"] }
axum-extra = { version = "0.10.0", features = ["typed-header", "query"] }
base64 = "0.22"
axum-extra = { version = "0.10.0", features = ["typed-header"] }
base64 = "0.13.0"
bincode = "1.3"
bindgen = "0.71"
bit_field = "0.10.2"
@@ -171,7 +171,7 @@ sentry = { version = "0.37", default-features = false, features = ["backtrace",
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
serde_path_to_error = "0.1"
serde_with = { version = "3", features = [ "base64" ] }
serde_with = { version = "2.0", features = [ "base64" ] }
serde_assert = "0.5.0"
sha2 = "0.10.2"
signal-hook = "0.3"

View File

@@ -145,28 +145,47 @@ impl Cli {
}
}
fn main() -> Result<()> {
impl Cli {
pub fn get_config(&self) -> Result<ComputeConfig> {
// First, read the config from the path if provided
if let Some(ref config) = self.config {
let file = File::open(config)?;
return Ok(serde_json::from_reader(&file)?);
}
// If the config wasn't provided in the CLI arguments, then retrieve it from
// the control plane
match get_config_from_control_plane(
self.control_plane_uri.as_ref().unwrap(),
&self.compute_id,
) {
Ok(config) => Ok(config),
Err(e) => {
error!(
"cannot get response from control plane: {}\n\
neither spec nor confirmation that compute is in the Empty state was received",
e
);
Err(e)
}
}
}
}
#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
let scenario = failpoint_support::init();
// For historical reasons, the main thread that processes the config and launches postgres
// is synchronous, but we always have this tokio runtime available and we "enter" it so
// that you can use tokio::spawn() and tokio::runtime::Handle::current().block_on(...)
// from all parts of compute_ctl.
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let _rt_guard = runtime.enter();
runtime.block_on(init())?;
init().await?;
// enable core dumping for all child processes
setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
let connstr = Url::parse(&cli.connstr).context("cannot parse connstr as a URL")?;
let config = get_config(&cli)?;
let config = cli.get_config()?;
let compute_node = ComputeNode::new(
ComputeNodeParams {
@@ -191,7 +210,7 @@ fn main() -> Result<()> {
config,
)?;
let exit_code = compute_node.run()?;
let exit_code = compute_node.run().await?;
scenario.teardown();
@@ -213,28 +232,6 @@ async fn init() -> Result<()> {
Ok(())
}
fn get_config(cli: &Cli) -> Result<ComputeConfig> {
// First, read the config from the path if provided
if let Some(ref config) = cli.config {
let file = File::open(config)?;
return Ok(serde_json::from_reader(&file)?);
}
// If the config wasn't provided in the CLI arguments, then retrieve it from
// the control plane
match get_config_from_control_plane(cli.control_plane_uri.as_ref().unwrap(), &cli.compute_id) {
Ok(config) => Ok(config),
Err(e) => {
error!(
"cannot get response from control plane: {}\n\
neither spec nor confirmation that compute is in the Empty state was received",
e
);
Err(e)
}
}
}
fn deinit_and_exit(exit_code: Option<i32>) -> ! {
// Shutdown trace pipeline gracefully, so that it has a chance to send any
// pending traces before we exit. Shutting down OTEL tracing provider may

View File

@@ -15,12 +15,8 @@ use itertools::Itertools;
use nix::sys::signal::{Signal, kill};
use nix::unistd::Pid;
use once_cell::sync::Lazy;
use postgres;
use postgres::NoTls;
use postgres::error::SqlState;
use remote_storage::{DownloadError, RemotePath};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::os::unix::fs::{PermissionsExt, symlink};
use std::path::Path;
use std::process::{Command, Stdio};
@@ -30,9 +26,9 @@ use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::spawn;
use tokio_postgres::{NoTls, error::SqlState};
use tracing::{Instrument, debug, error, info, instrument, warn};
use url::Url;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::measured_stream::MeasuredReader;
@@ -144,7 +140,7 @@ pub struct ComputeState {
/// Compute spec. This can be received from the CLI or - more likely -
/// passed by the control plane with a /configure HTTP request.
pub pspec: Option<ParsedSpec>,
pub spec: Option<ComputeSpec>,
/// If the spec is passed by a /configure request, 'startup_span' is the
/// /configure request's tracing span. The main thread enters it when it
@@ -171,7 +167,7 @@ impl ComputeState {
status: ComputeStatus::Empty,
last_active: None,
error: None,
pspec: None,
spec: None,
startup_span: None,
metrics: ComputeMetrics::default(),
lfc_prewarm_state: LfcPrewarmState::default(),
@@ -203,94 +199,6 @@ impl Default for ComputeState {
}
}
#[derive(Clone, Debug)]
pub struct ParsedSpec {
pub spec: ComputeSpec,
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub pageserver_connstr: String,
pub safekeeper_connstrings: Vec<String>,
pub storage_auth_token: Option<String>,
pub endpoint_storage_addr: Option<SocketAddr>,
pub endpoint_storage_token: Option<String>,
}
impl TryFrom<ComputeSpec> for ParsedSpec {
type Error = String;
fn try_from(spec: ComputeSpec) -> Result<Self, String> {
// Extract the options from the spec file that are needed to connect to
// the storage system.
//
// For backwards-compatibility, the top-level fields in the spec file
// may be empty. In that case, we need to dig them from the GUCs in the
// cluster.settings field.
let pageserver_connstr = spec
.pageserver_connstring
.clone()
.or_else(|| spec.cluster.settings.find("neon.pageserver_connstring"))
.ok_or("pageserver connstr should be provided")?;
let safekeeper_connstrings = if spec.safekeeper_connstrings.is_empty() {
if matches!(spec.mode, ComputeMode::Primary) {
spec.cluster
.settings
.find("neon.safekeepers")
.ok_or("safekeeper connstrings should be provided")?
.split(',')
.map(|str| str.to_string())
.collect()
} else {
vec![]
}
} else {
spec.safekeeper_connstrings.clone()
};
let storage_auth_token = spec.storage_auth_token.clone();
let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id {
tenant_id
} else {
spec.cluster
.settings
.find("neon.tenant_id")
.ok_or("tenant id should be provided")
.map(|s| TenantId::from_str(&s))?
.or(Err("invalid tenant id"))?
};
let timeline_id: TimelineId = if let Some(timeline_id) = spec.timeline_id {
timeline_id
} else {
spec.cluster
.settings
.find("neon.timeline_id")
.ok_or("timeline id should be provided")
.map(|s| TimelineId::from_str(&s))?
.or(Err("invalid timeline id"))?
};
let endpoint_storage_addr: Option<SocketAddr> = spec
.endpoint_storage_addr
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_addr"))
.unwrap_or_default()
.parse()
.ok();
let endpoint_storage_token = spec
.endpoint_storage_token
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_token"));
Ok(ParsedSpec {
spec,
pageserver_connstr,
safekeeper_connstrings,
storage_auth_token,
tenant_id,
timeline_id,
endpoint_storage_addr,
endpoint_storage_token,
})
}
}
/// If we are a VM, returns a [`Command`] that will run in the `neon-postgres`
/// cgroup. Otherwise returns the default `Command::new(cmd)`
///
@@ -368,10 +276,7 @@ impl ComputeNode {
tokio_conn_conf.options(&options);
let mut new_state = ComputeState::new();
if let Some(spec) = config.spec {
let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?;
new_state.pspec = Some(pspec);
}
new_state.spec = config.spec;
Ok(ComputeNode {
params,
@@ -386,10 +291,10 @@ impl ComputeNode {
/// Top-level control flow of compute_ctl. Returns a process exit code we should
/// exit with.
pub fn run(self) -> Result<Option<i32>> {
pub async fn run(self) -> Result<Option<i32>> {
let this = Arc::new(self);
let cli_spec = this.state.lock().unwrap().pspec.clone();
let cli_spec = this.state.lock().unwrap().spec.clone();
// If this is a pooled VM, prewarm before starting HTTP server and becoming
// available for binding. Prewarming helps Postgres start quicker later,
@@ -425,7 +330,7 @@ impl ComputeNode {
// If we got a spec from the CLI already, use that. Otherwise wait for the
// control plane to pass it to us with a /configure HTTP request
let pspec = if let Some(cli_spec) = cli_spec {
let spec = if let Some(cli_spec) = cli_spec {
cli_spec
} else {
this.wait_spec()?
@@ -438,11 +343,11 @@ impl ComputeNode {
let mut vm_monitor = None;
let mut pg_process: Option<PostgresHandle> = None;
match this.start_compute(&mut pg_process) {
match this.start_compute(&mut pg_process).await {
Ok(()) => {
// Success! Launch remaining services (just vm-monitor currently)
vm_monitor =
Some(this.start_vm_monitor(pspec.spec.disable_lfc_resizing.unwrap_or(false)));
Some(this.start_vm_monitor(spec.disable_lfc_resizing.unwrap_or(false)));
}
Err(err) => {
// Something went wrong with the startup. Log it and expose the error to
@@ -487,7 +392,7 @@ impl ComputeNode {
}
// Reap the postgres process
delay_exit |= this.cleanup_after_postgres_exit()?;
delay_exit |= this.cleanup_after_postgres_exit().await?;
// If launch failed, keep serving HTTP requests for a while, so the cloud
// control plane can get the actual error.
@@ -498,7 +403,7 @@ impl ComputeNode {
Ok(exit_code)
}
pub fn wait_spec(&self) -> Result<ParsedSpec> {
pub fn wait_spec(&self) -> Result<ComputeSpec> {
info!("no compute spec provided, waiting");
let mut state = self.state.lock().unwrap();
while state.status != ComputeStatus::ConfigurationPending {
@@ -506,7 +411,7 @@ impl ComputeNode {
}
info!("got spec, continue configuration");
let spec = state.pspec.as_ref().unwrap().clone();
let spec = state.spec.as_ref().unwrap().clone();
// Record for how long we slept waiting for the spec.
let now = Utc::now();
@@ -539,7 +444,7 @@ impl ComputeNode {
///
/// Note that this is in the critical path of a compute cold start. Keep this fast.
/// Try to do things concurrently, to hide the latencies.
fn start_compute(self: &Arc<Self>, pg_handle: &mut Option<PostgresHandle>) -> Result<()> {
async fn start_compute(self: &Arc<Self>, pg_handle: &mut Option<PostgresHandle>) -> Result<()> {
let compute_state: ComputeState;
let start_compute_span;
@@ -574,18 +479,17 @@ impl ComputeNode {
compute_state = state_guard.clone()
}
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
let spec = compute_state.spec.as_ref().expect("spec must be set");
info!(
"starting compute for project {}, operation {}, tenant {}, timeline {}, project {}, branch {}, endpoint {}, features {:?}, spec.remote_extensions {:?}",
pspec.spec.cluster.cluster_id.as_deref().unwrap_or("None"),
pspec.spec.operation_uuid.as_deref().unwrap_or("None"),
pspec.tenant_id,
pspec.timeline_id,
pspec.spec.project_id.as_deref().unwrap_or("None"),
pspec.spec.branch_id.as_deref().unwrap_or("None"),
pspec.spec.endpoint_id.as_deref().unwrap_or("None"),
pspec.spec.features,
pspec.spec.remote_extensions,
"starting compute for operation {}, tenant {}, timeline {}, project {}, branch {}, endpoint {}, features {:?}, spec.remote_extensions {:?}",
spec.operation_uuid.as_deref().unwrap_or("None"),
spec.tenant_id,
spec.timeline_id,
spec.project_id,
spec.branch_id,
spec.endpoint_id,
spec.features,
spec.remote_extensions,
);
////// PRE-STARTUP PHASE: things that need to be finished before we start the Postgres process
@@ -606,8 +510,8 @@ impl ComputeNode {
let tls_config = self.tls_config(&pspec.spec);
// If there are any remote extensions in shared_preload_libraries, start downloading them
if pspec.spec.remote_extensions.is_some() {
let (this, spec) = (self.clone(), pspec.spec.clone());
if spec.remote_extensions.is_some() {
let (this, spec) = (self.clone(), spec.clone());
pre_tasks.spawn(async move {
this.download_preload_extensions(&spec)
.in_current_span()
@@ -618,13 +522,11 @@ impl ComputeNode {
// Prepare pgdata directory. This downloads the basebackup, among other things.
{
let (this, cs) = (self.clone(), compute_state.clone());
pre_tasks.spawn_blocking_child(move || this.prepare_pgdata(&cs));
pre_tasks.spawn(async move { this.prepare_pgdata(&cs).await });
}
// Resize swap to the desired size if the compute spec says so
if let (Some(size_bytes), true) =
(pspec.spec.swap_size_bytes, self.params.resize_swap_on_bind)
{
if let (Some(size_bytes), true) = (spec.swap_size_bytes, self.params.resize_swap_on_bind) {
pre_tasks.spawn_blocking_child(move || {
// To avoid 'swapoff' hitting postgres startup, we need to run resize-swap to completion
// *before* starting postgres.
@@ -642,7 +544,7 @@ impl ComputeNode {
// Set disk quota if the compute spec says so
if let (Some(disk_quota_bytes), Some(disk_quota_fs_mountpoint)) = (
pspec.spec.disk_quota_bytes,
spec.disk_quota_bytes,
self.params.set_disk_quota_for_fs.as_ref(),
) {
let disk_quota_fs_mountpoint = disk_quota_fs_mountpoint.clone();
@@ -657,7 +559,7 @@ impl ComputeNode {
}
// tune pgbouncer
if let Some(pgbouncer_settings) = &pspec.spec.pgbouncer_settings {
if let Some(pgbouncer_settings) = &spec.pgbouncer_settings {
info!("tuning pgbouncer");
let pgbouncer_settings = pgbouncer_settings.clone();
@@ -675,7 +577,7 @@ impl ComputeNode {
}
// configure local_proxy
if let Some(local_proxy) = &pspec.spec.local_proxy_config {
if let Some(local_proxy) = &spec.local_proxy_config {
info!("configuring local_proxy");
// Spawn a background task to do the configuration,
@@ -693,7 +595,7 @@ impl ComputeNode {
}
// Configure and start rsyslog for compliance audit logging
match pspec.spec.audit_log_level {
match spec.audit_log_level {
ComputeAudit::Hipaa | ComputeAudit::Extended | ComputeAudit::Full => {
let remote_endpoint =
std::env::var("AUDIT_LOGGING_ENDPOINT").unwrap_or("".to_string());
@@ -704,16 +606,10 @@ impl ComputeNode {
let log_directory_path = Path::new(&self.params.pgdata).join("log");
let log_directory_path = log_directory_path.to_string_lossy().to_string();
// Add project_id,endpoint_id to identify the logs.
//
// These ids are passed from cplane,
let endpoint_id = pspec.spec.endpoint_id.as_deref().unwrap_or("");
let project_id = pspec.spec.project_id.as_deref().unwrap_or("");
configure_audit_rsyslog(
log_directory_path.clone(),
endpoint_id,
project_id,
&spec.endpoint_id,
&spec.project_id,
&remote_endpoint,
)?;
@@ -724,7 +620,7 @@ impl ComputeNode {
}
// Configure and start rsyslog for Postgres logs export
let conf = PostgresLogsRsyslogConfig::new(pspec.spec.logs_export_host.as_deref());
let conf = PostgresLogsRsyslogConfig::new(spec.logs_export_host.as_deref());
configure_postgres_logs_export(conf)?;
// Launch remaining service threads
@@ -732,21 +628,20 @@ impl ComputeNode {
let _configurator_handle = launch_configurator(self);
// Wait for all the pre-tasks to finish before starting postgres
let rt = tokio::runtime::Handle::current();
while let Some(res) = rt.block_on(pre_tasks.join_next()) {
while let Some(res) = pre_tasks.join_next().await {
res??;
}
////// START POSTGRES
let start_time = Utc::now();
let pg_process = self.start_postgres(pspec.storage_auth_token.clone())?;
let pg_process = self.start_postgres(spec.storage_auth_token.clone())?;
let postmaster_pid = pg_process.pid();
*pg_handle = Some(pg_process);
// If this is a primary endpoint, perform some post-startup configuration before
// opening it up for the world.
let config_time = Utc::now();
if pspec.spec.mode == ComputeMode::Primary {
if spec.mode == ComputeMode::Primary {
self.configure_as_primary(&compute_state)?;
let conf = self.get_tokio_conn_conf(None);
@@ -785,8 +680,9 @@ impl ComputeNode {
self.spawn_extension_stats_task();
if pspec.spec.autoprewarm {
self.prewarm_lfc(None);
self.prewarm_lfc();
}
Ok(())
}
@@ -862,14 +758,14 @@ impl ComputeNode {
}
}
fn cleanup_after_postgres_exit(&self) -> Result<bool> {
async fn cleanup_after_postgres_exit(&self) -> Result<bool> {
// Maybe sync safekeepers again, to speed up next startup
let compute_state = self.state.lock().unwrap().clone();
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
if matches!(pspec.spec.mode, compute_api::spec::ComputeMode::Primary) {
let spec = compute_state.spec.as_ref().expect("spec must be set");
if matches!(spec.mode, compute_api::spec::ComputeMode::Primary) {
info!("syncing safekeepers on shutdown");
let storage_auth_token = pspec.storage_auth_token.clone();
let lsn = self.sync_safekeepers(storage_auth_token)?;
let storage_auth_token = spec.storage_auth_token.clone();
let lsn = self.sync_safekeepers(storage_auth_token).await?;
info!("synced safekeepers at lsn {lsn}");
}
@@ -892,13 +788,12 @@ impl ComputeNode {
/// Check that compute node has corresponding feature enabled.
pub fn has_feature(&self, feature: ComputeFeature) -> bool {
let state = self.state.lock().unwrap();
if let Some(s) = state.pspec.as_ref() {
s.spec.features.contains(&feature)
} else {
false
}
self.state
.lock()
.unwrap()
.spec
.as_ref()
.is_some_and(|spec| spec.features.contains(&feature))
}
pub fn set_status(&self, status: ComputeStatus) {
@@ -915,13 +810,15 @@ impl ComputeNode {
self.state.lock().unwrap().status
}
pub fn get_timeline_id(&self) -> Option<TimelineId> {
pub fn get_timeline_id(&self) -> TimelineId {
self.state
.lock()
.unwrap()
.pspec
.spec
.as_ref()
.map(|s| s.timeline_id)
.unwrap()
.timeline_id
.clone()
}
// Remove `pgdata` directory and create it again with right permissions.
@@ -940,11 +837,10 @@ impl ComputeNode {
// unarchive it to `pgdata` directory overriding all its previous content.
#[instrument(skip_all, fields(%lsn))]
fn try_get_basebackup(&self, compute_state: &ComputeState, lsn: Lsn) -> Result<()> {
let spec = compute_state.pspec.as_ref().expect("spec must be set");
let spec = compute_state.spec.as_ref().expect("spec must be set");
let start_time = Instant::now();
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
let mut config = postgres::Config::from_str(shard0_connstr)?;
let mut config = postgres::Config::from(&spec.pageservers[0]);
// Use the storage auth token from the config file, if given.
// Note: this overrides any password set in the connection string.
@@ -956,20 +852,17 @@ impl ComputeNode {
}
config.application_name("compute_ctl");
if let Some(spec) = &compute_state.pspec {
config.options(&format!(
"-c neon.compute_mode={}",
spec.spec.mode.to_type_str()
));
if let Some(spec) = &compute_state.spec {
config.options(&format!("-c neon.compute_mode={}", spec.mode.to_type_str()));
}
// Connect to pageserver
let mut client = config.connect(NoTls)?;
let mut client = config.connect(postgres::NoTls)?;
let pageserver_connect_micros = start_time.elapsed().as_micros() as u64;
let basebackup_cmd = match lsn {
Lsn(0) => {
if spec.spec.mode != ComputeMode::Primary {
if spec.mode != ComputeMode::Primary {
format!(
"basebackup {} {} --gzip --replica",
spec.tenant_id, spec.timeline_id
@@ -979,7 +872,7 @@ impl ComputeNode {
}
}
_ => {
if spec.spec.mode != ComputeMode::Primary {
if spec.mode != ComputeMode::Primary {
format!(
"basebackup {} {} {} --gzip --replica",
spec.tenant_id, spec.timeline_id, lsn
@@ -1055,35 +948,34 @@ impl ComputeNode {
compute_state: &ComputeState,
) -> Result<Option<Lsn>> {
// Construct a connection config for each safekeeper
let pspec: ParsedSpec = compute_state
.pspec
let spec = compute_state
.spec
.as_ref()
.expect("spec must be set")
.clone();
let sk_connstrs: Vec<String> = pspec.safekeeper_connstrings.clone();
let sk_configs = sk_connstrs.into_iter().map(|connstr| {
// Format connstr
let id = connstr.clone();
let connstr = format!("postgresql://no_user@{}", connstr);
let options = format!(
"-c timeline_id={} tenant_id={}",
pspec.timeline_id, pspec.tenant_id
);
let safekeepers = spec
.safekeepers
.iter()
.map(|s| {
let mut config = tokio_postgres::Config::from(s);
// Construct client
let mut config = tokio_postgres::Config::from_str(&connstr).unwrap();
config.options(&options);
if let Some(storage_auth_token) = pspec.storage_auth_token.clone() {
config.password(storage_auth_token);
}
config.user("no_user");
config.options(&format!(
"-c timeline_id={} tenant_id={}",
spec.timeline_id, spec.tenant_id
));
if let Some(storage_auth_token) = &spec.storage_auth_token {
config.password(storage_auth_token);
}
(id, config)
});
(format!("{}:{}", s.host, s.port), config)
})
.collect::<Vec<(String, tokio_postgres::Config)>>();
// Create task set to query all safekeepers
let mut tasks = FuturesUnordered::new();
let quorum = sk_configs.len() / 2 + 1;
for (id, config) in sk_configs {
let quorum = safekeepers.len() / 2 + 1;
for (id, config) in safekeepers {
let timeout = tokio::time::Duration::from_millis(100);
let task = tokio::time::timeout(timeout, ping_safekeeper(id, config));
tasks.push(tokio::spawn(task));
@@ -1128,11 +1020,13 @@ impl ComputeNode {
// Fast path for sync_safekeepers. If they're already synced we get the lsn
// in one roundtrip. If not, we should do a full sync_safekeepers.
#[instrument(skip_all)]
pub fn check_safekeepers_synced(&self, compute_state: &ComputeState) -> Result<Option<Lsn>> {
pub async fn check_safekeepers_synced(
&self,
compute_state: &ComputeState,
) -> Result<Option<Lsn>> {
let start_time = Utc::now();
let rt = tokio::runtime::Handle::current();
let result = rt.block_on(self.check_safekeepers_synced_async(compute_state));
let result = self.check_safekeepers_synced_async(compute_state).await;
// Record runtime
self.state.lock().unwrap().metrics.sync_sk_check_ms = Utc::now()
@@ -1146,7 +1040,7 @@ impl ComputeNode {
// Run `postgres` in a special mode with `--sync-safekeepers` argument
// and return the reported LSN back to the caller.
#[instrument(skip_all)]
pub fn sync_safekeepers(&self, storage_auth_token: Option<String>) -> Result<Lsn> {
pub async fn sync_safekeepers(&self, storage_auth_token: Option<String>) -> Result<Lsn> {
let start_time = Utc::now();
let mut sync_handle = maybe_cgexec(&self.params.pgbin)
@@ -1178,8 +1072,8 @@ impl ComputeNode {
SYNC_SAFEKEEPERS_PID.store(0, Ordering::SeqCst);
// Process has exited, so we can join the logs thread.
let _ = tokio::runtime::Handle::current()
.block_on(logs_handle)
let _ = logs_handle
.await
.map_err(|e| tracing::error!("log task panicked: {:?}", e));
if !sync_output.status.success() {
@@ -1205,9 +1099,8 @@ impl ComputeNode {
/// Do all the preparations like PGDATA directory creation, configuration,
/// safekeepers sync, basebackup, etc.
#[instrument(skip_all)]
pub fn prepare_pgdata(&self, compute_state: &ComputeState) -> Result<()> {
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
let spec = &pspec.spec;
pub async fn prepare_pgdata(&self, compute_state: &ComputeState) -> Result<()> {
let spec = compute_state.spec.as_ref().expect("spec must be set");
let pgdata_path = Path::new(&self.params.pgdata);
let tls_config = self.tls_config(&pspec.spec);
@@ -1216,7 +1109,7 @@ impl ComputeNode {
self.create_pgdata()?;
config::write_postgres_conf(
pgdata_path,
&pspec.spec,
spec,
self.params.internal_http_port,
tls_config,
)?;
@@ -1227,11 +1120,13 @@ impl ComputeNode {
let lsn = match spec.mode {
ComputeMode::Primary => {
info!("checking if safekeepers are synced");
let lsn = if let Ok(Some(lsn)) = self.check_safekeepers_synced(compute_state) {
let lsn = if let Ok(Some(lsn)) = self.check_safekeepers_synced(compute_state).await
{
lsn
} else {
info!("starting safekeepers syncing");
self.sync_safekeepers(pspec.storage_auth_token.clone())
self.sync_safekeepers(spec.storage_auth_token.clone())
.await
.with_context(|| "failed to sync safekeepers")?
};
info!("safekeepers synced at LSN {}", lsn);
@@ -1248,13 +1143,13 @@ impl ComputeNode {
};
info!(
"getting basebackup@{} from pageserver {}",
lsn, &pspec.pageserver_connstr
"getting basebackup@{} from pageserver {}:{}",
lsn, spec.pageservers[0].host, spec.pageservers[0].port
);
self.get_basebackup(compute_state, lsn).with_context(|| {
format!(
"failed to get basebackup@{} from pageserver {}",
lsn, &pspec.pageserver_connstr
"failed to get basebackup@{} from pageserver {}:{}",
lsn, spec.pageservers[0].host, spec.pageservers[0].port
)
})?;
@@ -1536,10 +1431,9 @@ impl ComputeNode {
let conf = Arc::new(conf);
let spec = Arc::new(
compute_state
.pspec
.spec
.as_ref()
.expect("spec must be set")
.spec
.clone(),
);
@@ -1608,7 +1502,7 @@ impl ComputeNode {
/// as it's used to reconfigure a previously started and configured Postgres node.
#[instrument(skip_all)]
pub fn reconfigure(&self) -> Result<()> {
let spec = self.state.lock().unwrap().pspec.clone().unwrap().spec;
let spec = self.state.lock().unwrap().spec.as_ref().unwrap().clone();
let tls_config = self.tls_config(&spec);
@@ -1690,10 +1584,10 @@ impl ComputeNode {
#[instrument(skip_all)]
pub fn configure_as_primary(&self, compute_state: &ComputeState) -> Result<()> {
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
let spec = compute_state.spec.as_ref().expect("spec must be set");
assert!(pspec.spec.mode == ComputeMode::Primary);
if !pspec.spec.skip_pg_catalog_updates {
assert!(spec.mode == ComputeMode::Primary);
if !spec.skip_pg_catalog_updates {
let pgdata_path = Path::new(&self.params.pgdata);
// temporarily reset max_cluster_size in config
// to avoid the possibility of hitting the limit, while we are applying config:
@@ -2189,24 +2083,23 @@ LIMIT 100",
/// the pageserver connection strings has changed.
///
/// The operation will time out after a specified duration.
pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
pub fn wait_timeout_while_pageservers_unchanged(&self, duration: Duration) {
let state = self.state.lock().unwrap();
let old_pageserver_connstr = state
.pspec
let old_pageservers = state
.spec
.as_ref()
.expect("spec must be set")
.pageserver_connstr
.pageservers
.clone();
let mut unchanged = true;
let _ = self
.state_changed
.wait_timeout_while(state, duration, |s| {
let pageserver_connstr = &s
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_connstr;
unchanged = pageserver_connstr == &old_pageserver_connstr;
let current_pageservers = &s.spec.as_ref().expect("spec must be set").pageservers;
unchanged = current_pageservers
.iter()
.zip(&old_pageservers)
.all(|(c, o)| c == o);
unchanged
})
.unwrap();

View File

@@ -3,6 +3,7 @@ use anyhow::{Context, Result, bail};
use async_compression::tokio::bufread::{ZstdDecoder, ZstdEncoder};
use compute_api::responses::LfcOffloadState;
use compute_api::responses::LfcPrewarmState;
use compute_api::spec::ComputeSpec;
use http::StatusCode;
use reqwest::Client;
use std::sync::Arc;
@@ -25,29 +26,30 @@ struct EndpointStoragePair {
}
const KEY: &str = "lfc_state";
impl EndpointStoragePair {
/// endpoint_id is set to None while prewarming from other endpoint, see replica promotion
/// If not None, takes precedence over pspec.spec.endpoint_id
fn from_spec_and_endpoint(
pspec: &crate::compute::ParsedSpec,
endpoint_id: Option<String>,
) -> Result<Self> {
let endpoint_id = endpoint_id.as_ref().or(pspec.spec.endpoint_id.as_ref());
let Some(ref endpoint_id) = endpoint_id else {
bail!("pspec.endpoint_id missing, other endpoint_id not provided")
impl TryFrom<&ComputeSpec> for EndpointStoragePair {
type Error = anyhow::Error;
fn try_from(spec: &ComputeSpec) -> Result<Self, Self::Error> {
let Some(ref addr) = spec.endpoint_storage_addr else {
bail!("spec.endpoint_storage_addr missing")
};
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
bail!("pspec.endpoint_storage_addr missing")
};
let tenant_id = pspec.tenant_id;
let timeline_id = pspec.timeline_id;
let url = format!("http://{base_uri}/{tenant_id}/{timeline_id}/{endpoint_id}/{KEY}");
let Some(ref token) = pspec.endpoint_storage_token else {
bail!("pspec.endpoint_storage_token missing")
let url = format!(
"http://{addr}/{tenant_id}/{timeline_id}/{endpoint_id}/{key}",
addr = addr,
tenant_id = spec.tenant_id,
timeline_id = spec.timeline_id,
endpoint_id = spec.endpoint_id,
key = KEY
);
let Some(ref token) = spec.endpoint_storage_token else {
bail!("spec.endpoint_storage_token missing")
};
let token = token.clone();
Ok(EndpointStoragePair { url, token })
Ok(EndpointStoragePair {
url,
token: token.clone(),
})
}
}
@@ -89,7 +91,7 @@ impl ComputeNode {
}
/// Returns false if there is a prewarm request ongoing, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
pub fn prewarm_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_PREWARM_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -102,7 +104,7 @@ impl ComputeNode {
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
let Err(err) = cloned.prewarm_impl().await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
@@ -114,14 +116,13 @@ impl ComputeNode {
true
}
/// from_endpoint: None for endpoint managed by this compute_ctl
fn endpoint_storage_pair(&self, from_endpoint: Option<String>) -> Result<EndpointStoragePair> {
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
let state = self.state.lock().unwrap();
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
state.spec.as_ref().unwrap().try_into()
}
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
async fn prewarm_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
@@ -179,7 +180,7 @@ impl ComputeNode {
}
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();

View File

@@ -56,13 +56,24 @@ pub fn write_postgres_conf(
// Add options for connecting to storage
writeln!(file, "# Neon storage settings")?;
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
if !spec.pageservers.is_empty() {
writeln!(
file,
"neon.pageserver_connstring={}",
escape_conf_value(
&spec
.pageservers
.iter()
.map(|p| format!("host={} port={}", p.host, p.port))
.collect::<Vec<_>>()
.join(",")
)
)?;
}
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if !spec.safekeeper_connstrings.is_empty() {
if !spec.safekeepers.is_empty() {
let mut neon_safekeepers_value = String::new();
tracing::info!(
"safekeepers_connstrings is not zero, gen: {:?}",
@@ -72,32 +83,45 @@ pub fn write_postgres_conf(
if let Some(generation) = spec.safekeepers_generation {
write!(neon_safekeepers_value, "g#{}:", generation)?;
}
neon_safekeepers_value.push_str(&spec.safekeeper_connstrings.join(","));
neon_safekeepers_value.push_str(
&spec
.safekeepers
.iter()
.map(|s| format!("{}:{}", s.host.to_string(), s.port))
.collect::<Vec<_>>()
.join(","),
);
writeln!(
file,
"neon.safekeepers={}",
escape_conf_value(&neon_safekeepers_value)
)?;
}
if let Some(s) = &spec.tenant_id {
writeln!(file, "neon.tenant_id={}", escape_conf_value(&s.to_string()))?;
}
if let Some(s) = &spec.timeline_id {
writeln!(
file,
"neon.timeline_id={}",
escape_conf_value(&s.to_string())
)?;
}
if let Some(s) = &spec.project_id {
writeln!(file, "neon.project_id={}", escape_conf_value(s))?;
}
if let Some(s) = &spec.branch_id {
writeln!(file, "neon.branch_id={}", escape_conf_value(s))?;
}
if let Some(s) = &spec.endpoint_id {
writeln!(file, "neon.endpoint_id={}", escape_conf_value(s))?;
}
writeln!(
file,
"neon.tenant_id={}",
escape_conf_value(&spec.tenant_id.to_string())
)?;
writeln!(
file,
"neon.timeline_id={}",
escape_conf_value(&spec.timeline_id.to_string())
)?;
writeln!(
file,
"neon.project_id={}",
escape_conf_value(&spec.project_id)
)?;
writeln!(
file,
"neon.branch_id={}",
escape_conf_value(&spec.branch_id)
)?;
writeln!(
file,
"neon.endpoint_id={}",
escape_conf_value(&spec.endpoint_id)
)?;
// tls
if let Some(tls_config) = tls_config {

View File

@@ -8,7 +8,7 @@ use http::StatusCode;
use tokio::task;
use tracing::info;
use crate::compute::{ComputeNode, ParsedSpec};
use crate::compute::ComputeNode;
use crate::http::JsonResponse;
use crate::http::extract::Json;
@@ -22,11 +22,6 @@ pub(in crate::http) async fn configure(
State(compute): State<Arc<ComputeNode>>,
request: Json<ConfigurationRequest>,
) -> Response {
let pspec = match ParsedSpec::try_from(request.0.spec) {
Ok(p) => p,
Err(e) => return JsonResponse::error(StatusCode::BAD_REQUEST, e),
};
// XXX: wrap state update under lock in a code block. Otherwise, we will try
// to `Send` `mut state` into the spawned thread bellow, which will cause
// the following rustc error:
@@ -43,7 +38,7 @@ pub(in crate::http) async fn configure(
// configure request for tracing purposes.
state.startup_span = Some(tracing::Span::current());
state.pspec = Some(pspec);
state.spec = Some(request.spec.clone());
state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed);
drop(state);
}

View File

@@ -31,8 +31,7 @@ pub(in crate::http) async fn download_extension(
let ext = {
let state = compute.state.lock().unwrap();
let pspec = state.pspec.as_ref().unwrap();
let spec = &pspec.spec;
let spec = &state.spec.as_ref().unwrap();
let remote_extensions = match spec.remote_extensions.as_ref() {
Some(r) => r,

View File

@@ -2,7 +2,6 @@ use crate::compute_prewarm::LfcPrewarmStateWithProgress;
use crate::http::JsonResponse;
use axum::response::{IntoResponse, Response};
use axum::{Json, http::StatusCode};
use axum_extra::extract::OptionalQuery;
use compute_api::responses::LfcOffloadState;
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
@@ -17,16 +16,8 @@ pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadS
Json(compute.lfc_offload_state())
}
#[derive(serde::Deserialize)]
pub struct PrewarmQuery {
pub from_endpoint: String,
}
pub(in crate::http) async fn prewarm(
compute: Compute,
OptionalQuery(query): OptionalQuery<PrewarmQuery>,
) -> Response {
if compute.prewarm_lfc(query.map(|q| q.from_endpoint)) {
pub(in crate::http) async fn prewarm(compute: Compute) -> Response {
if compute.prewarm_lfc() {
StatusCode::ACCEPTED.into_response()
} else {
JsonResponse::error(

View File

@@ -21,14 +21,8 @@ impl From<&ComputeState> for ComputeStatusResponse {
fn from(state: &ComputeState) -> Self {
ComputeStatusResponse {
start_time: state.start_time,
tenant: state
.pspec
.as_ref()
.map(|pspec| pspec.tenant_id.to_string()),
timeline: state
.pspec
.as_ref()
.map(|pspec| pspec.timeline_id.to_string()),
tenant: state.spec.as_ref().map(|spec| spec.tenant_id.to_string()),
timeline: state.spec.as_ref().map(|spec| spec.timeline_id.to_string()),
status: state.status,
last_active: state.last_active,
error: state.error.clone(),

View File

@@ -18,8 +18,8 @@ use crate::compute::ComputeNode;
pub fn launch_lsn_lease_bg_task_for_static(compute: &Arc<ComputeNode>) {
let (tenant_id, timeline_id, lsn) = {
let state = compute.state.lock().unwrap();
let spec = state.pspec.as_ref().expect("Spec must be set");
match spec.spec.mode {
let spec = state.spec.as_ref().expect("Spec must be set");
match spec.mode {
ComputeMode::Static(lsn) => (spec.tenant_id, spec.timeline_id, lsn),
_ => return,
}
@@ -58,7 +58,7 @@ fn lsn_lease_bg_task(
"Request succeeded, sleeping for {} seconds",
sleep_duration.as_secs()
);
compute.wait_timeout_while_pageserver_connstr_unchanged(sleep_duration);
compute.wait_timeout_while_pageservers_unchanged(sleep_duration);
}
}
@@ -79,18 +79,11 @@ fn acquire_lsn_lease_with_retry(
let configs = {
let state = compute.state.lock().unwrap();
let spec = state.pspec.as_ref().expect("spec must be set");
let spec = state.spec.as_ref().expect("spec must be set");
let conn_strings = spec.pageserver_connstr.split(',');
conn_strings
.map(|connstr| {
let mut config = postgres::Config::from_str(connstr).expect("Invalid connstr");
if let Some(storage_auth_token) = &spec.storage_auth_token {
config.password(storage_auth_token.clone());
}
config
})
spec.pageservers
.iter()
.map(|p| postgres::Config::from(p))
.collect::<Vec<_>>()
};
@@ -105,7 +98,7 @@ fn acquire_lsn_lease_with_retry(
Err(e) => {
warn!("Failed to acquire lsn lease: {e} (attempt {attempts})");
compute.wait_timeout_while_pageserver_connstr_unchanged(Duration::from_millis(
compute.wait_timeout_while_pageservers_unchanged(Duration::from_millis(
retry_period_ms as u64,
));
retry_period_ms *= 1.5;

View File

@@ -4,7 +4,7 @@ use std::future::Future;
use std::iter::{empty, once};
use std::sync::Arc;
use anyhow::{Context, Result};
use anyhow::Result;
use compute_api::responses::ComputeStatus;
use compute_api::spec::{ComputeAudit, ComputeSpec, Database, PgIdent, Role};
use futures::future::join_all;
@@ -74,7 +74,7 @@ impl ComputeNode {
let mut drop_subscriptions_done = false;
if spec.drop_subscriptions_before_start {
let timeline_id = self.get_timeline_id().context("timeline_id must be set")?;
let timeline_id = self.get_timeline_id();
info!("Checking if drop subscription operation was already performed for timeline_id: {}", timeline_id);

View File

@@ -37,7 +37,7 @@ pub async fn ping_safekeeper(
// Parse result
info!("done with {}", id);
if let postgres::SimpleQueryMessage::Row(row) = &result[0] {
if let tokio_postgres::SimpleQueryMessage::Row(row) = &result[0] {
use std::str::FromStr;
let response = TimelineStatusResponse::Ok(TimelineStatusOkResponse {
flush_lsn: Lsn::from_str(row.get("flush_lsn").unwrap())?,

View File

@@ -1493,7 +1493,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
let conf = env.get_pageserver_conf(pageserver_id).unwrap();
let parsed = parse_host_port(&conf.listen_pg_addr).expect("Bad config");
(
vec![(parsed.0, parsed.1.unwrap_or(5432))],
vec![compute_api::spec::Pageserver {
host: parsed.0,
port: parsed.1.unwrap_or(5432),
}],
// If caller is telling us what pageserver to use, this is not a tenant which is
// full managed by storage controller, therefore not sharded.
DEFAULT_STRIPE_SIZE,
@@ -1516,11 +1519,11 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.await?;
}
anyhow::Ok((
Host::parse(&shard.listen_pg_addr)
anyhow::Ok(compute_api::spec::Pageserver {
host: Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
shard.listen_pg_port,
))
port: shard.listen_pg_port,
})
}),
)
.await?;
@@ -1576,10 +1579,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id {
let pageserver = PageServerNode::from_env(env, env.get_pageserver_conf(ps_id)?);
vec![(
pageserver.pg_connection_config.host().clone(),
pageserver.pg_connection_config.port(),
)]
vec![compute_api::spec::Pageserver {
host: pageserver.pg_connection_config.host().clone(),
port: pageserver.pg_connection_config.port(),
}]
} else {
let storage_controller = StorageController::from_env(env);
storage_controller
@@ -1587,12 +1590,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.await?
.shards
.into_iter()
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported malformed host"),
shard.listen_pg_port,
)
.map(|shard| compute_api::spec::Pageserver {
host: Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported malformed host"),
port: shard.listen_pg_port,
})
.collect::<Vec<_>>()
};

View File

@@ -45,8 +45,6 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use compute_api::requests::{
COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest,
};
@@ -54,8 +52,8 @@ use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TlsConfig,
};
use compute_api::spec::{
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PgIdent,
RemoteExtSpec, Role,
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, Pageserver, PgIdent,
RemoteExtSpec, Role, Safekeeper,
};
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
@@ -166,7 +164,7 @@ impl ComputeControlPlane {
public_key_use: Some(PublicKeyUse::Signature),
key_operations: Some(vec![KeyOperations::Verify]),
key_algorithm: Some(KeyAlgorithm::EdDSA),
key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)),
key_id: Some(base64::encode_config(key_hash, base64::URL_SAFE_NO_PAD)),
x509_url: None::<String>,
x509_chain: None::<Vec<String>>,
x509_sha1_fingerprint: None::<String>,
@@ -175,7 +173,7 @@ impl ComputeControlPlane {
algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters {
key_type: OctetKeyPairType::OctetKeyPair,
curve: EllipticCurve::Ed25519,
x: BASE64_URL_SAFE_NO_PAD.encode(public_key),
x: base64::encode_config(public_key, base64::URL_SAFE_NO_PAD),
}),
}],
})
@@ -608,29 +606,25 @@ impl Endpoint {
}
}
fn build_pageserver_connstr(pageservers: &[(Host, u16)]) -> String {
pageservers
.iter()
.map(|(host, port)| format!("postgresql://no_user@{host}:{port}"))
.collect::<Vec<_>>()
.join(",")
}
fn safekeepers_from_nodes(&self, ids: Vec<NodeId>) -> Result<Vec<Safekeeper>> {
let mut s = Vec::new();
/// Map safekeepers ids to the actual connection strings.
fn build_safekeepers_connstrs(&self, sk_ids: Vec<NodeId>) -> Result<Vec<String>> {
let mut safekeeper_connstrings = Vec::new();
if self.mode == ComputeMode::Primary {
for sk_id in sk_ids {
for id in ids {
let sk = self
.env
.safekeepers
.iter()
.find(|node| node.id == sk_id)
.ok_or_else(|| anyhow!("safekeeper {sk_id} does not exist"))?;
safekeeper_connstrings.push(format!("127.0.0.1:{}", sk.get_compute_port()));
.find(|node| node.id == id)
.ok_or_else(|| anyhow!("safekeeper {id} does not exist"))?;
s.push(Safekeeper {
host: Host::parse("127.0.0.1")?,
port: sk.get_compute_port(),
});
}
}
Ok(safekeeper_connstrings)
Ok(s)
}
/// Generate a JWT with the correct claims.
@@ -656,7 +650,7 @@ impl Endpoint {
endpoint_storage_addr: String,
safekeepers_generation: Option<SafekeeperGeneration>,
safekeepers: Vec<NodeId>,
pageservers: Vec<(Host, u16)>,
pageservers: Vec<Pageserver>,
remote_ext_base_url: Option<&String>,
shard_stripe_size: usize,
create_test_user: bool,
@@ -674,11 +668,6 @@ impl Endpoint {
std::fs::remove_dir_all(self.pgdata())?;
}
let pageserver_connstring = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstring.is_empty());
let safekeeper_connstrings = self.build_safekeepers_connstrs(safekeepers)?;
// check for file remote_extensions_spec.json
// if it is present, read it and pass to compute_ctl
let remote_extensions_spec_path = self.endpoint_path().join("remote_extensions_spec.json");
@@ -729,15 +718,34 @@ impl Endpoint {
postgresql_conf: Some(postgresql_conf.clone()),
},
delta_operations: None,
tenant_id: Some(self.tenant_id),
timeline_id: Some(self.timeline_id),
project_id: None,
branch_id: None,
endpoint_id: Some(self.endpoint_id.clone()),
tenant_id: self.tenant_id.clone(),
timeline_id: self.timeline_id.clone(),
project_id: self.tenant_id.to_string(),
branch_id: self.timeline_id.to_string(),
endpoint_id: self.endpoint_id.clone(),
mode: self.mode,
pageserver_connstring: Some(pageserver_connstring),
pageservers,
safekeepers: {
let mut s = Vec::new();
if self.mode == ComputeMode::Primary {
for id in safekeepers {
let sk = self
.env
.safekeepers
.iter()
.find(|node| node.id == id)
.ok_or_else(|| anyhow!("safekeeper {id} does not exist"))?;
s.push(Safekeeper {
host: Host::parse("127.0.0.1")?,
port: sk.get_compute_port(),
});
}
}
s
},
safekeepers_generation: safekeepers_generation.map(|g| g.into_inner()),
safekeeper_connstrings,
storage_auth_token: auth_token.clone(),
remote_extensions,
pgbouncer_settings: None,
@@ -941,7 +949,7 @@ impl Endpoint {
pub async fn reconfigure(
&self,
mut pageservers: Vec<(Host, u16)>,
pageservers: Vec<Pageserver>,
stripe_size: Option<ShardStripeSize>,
safekeepers: Option<Vec<NodeId>>,
) -> Result<()> {
@@ -960,30 +968,24 @@ impl Endpoint {
if pageservers.is_empty() {
let storage_controller = StorageController::from_env(&self.env);
let locate_result = storage_controller.tenant_locate(self.tenant_id).await?;
pageservers = locate_result
spec.pageservers = locate_result
.shards
.into_iter()
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
shard.listen_pg_port,
)
.map(|shard| Pageserver {
host: Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
port: shard.listen_pg_port,
})
.collect::<Vec<_>>();
}
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstr.is_empty());
spec.pageserver_connstring = Some(pageserver_connstr);
if stripe_size.is_some() {
spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
}
// If safekeepers are not specified, don't change them.
if let Some(safekeepers) = safekeepers {
let safekeeper_connstrings = self.build_safekeepers_connstrs(safekeepers)?;
spec.safekeeper_connstrings = safekeeper_connstrings;
spec.safekeepers = self.safekeepers_from_nodes(safekeepers)?;
}
let client = reqwest::Client::builder()

View File

@@ -9,8 +9,11 @@ anyhow.workspace = true
chrono.workspace = true
indexmap.workspace = true
jsonwebtoken.workspace = true
postgres.workspace = true
serde.workspace = true
serde_json.workspace = true
tokio-postgres.workspace = true
url.workspace = true
regex.workspace = true
utils = { path = "../utils" }

View File

@@ -9,7 +9,8 @@ use indexmap::IndexMap;
use regex::Regex;
use remote_storage::RemotePath;
use serde::{Deserialize, Serialize};
use utils::id::{TenantId, TimelineId};
use url::Host;
use utils::id::{BranchId, EndpointId, ProjectId, TenantId, TimelineId};
use utils::lsn::Lsn;
use crate::responses::TlsConfig;
@@ -21,13 +22,77 @@ pub type PgIdent = String;
/// String type alias representing Postgres extension version
pub type ExtVersion = String;
/// Pageserver settings.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Pageserver {
/// Hostname of the pageserver.
pub host: Host,
/// Port that the safekeeper listens on.
pub port: u16,
}
impl From<&Pageserver> for postgres::Config {
fn from(ps: &Pageserver) -> Self {
let mut config = postgres::Config::new();
config.host(&ps.host.to_string());
config.port(ps.port);
config
}
}
impl From<&Pageserver> for tokio_postgres::Config {
fn from(ps: &Pageserver) -> Self {
let mut config = tokio_postgres::Config::new();
config.host(&ps.host.to_string());
config.port(ps.port);
config
}
}
/// Safekeeper settings.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Safekeeper {
/// Hostname of the safekeeper.
pub host: Host,
/// Port that the safekeeper listens on.
pub port: u16,
}
impl From<&Safekeeper> for postgres::Config {
fn from(sk: &Safekeeper) -> Self {
let mut config = postgres::Config::new();
config.host(&sk.host.to_string());
config.port(sk.port);
config
}
}
impl From<&Safekeeper> for tokio_postgres::Config {
fn from(sk: &Safekeeper) -> Self {
let mut config = tokio_postgres::Config::new();
config.host(&sk.host.to_string());
config.port(sk.port);
config
}
}
fn default_reconfigure_concurrency() -> usize {
1
}
/// Cluster spec or configuration represented as an optional number of
/// delta operations + final cluster state description.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ComputeSpec {
pub format_version: f32,
@@ -90,25 +155,13 @@ pub struct ComputeSpec {
// Information needed to connect to the storage layer.
//
// `tenant_id`, `timeline_id` and `pageserver_connstring` are always needed.
//
// Depending on `mode`, this can be a primary read-write node, a read-only
// replica, or a read-only node pinned at an older LSN.
// `safekeeper_connstrings` must be set for a primary.
//
// For backwards compatibility, the control plane may leave out all of
// these, and instead set the "neon.tenant_id", "neon.timeline_id",
// etc. GUCs in cluster.settings. TODO: Once the control plane has been
// updated to fill these fields, we can make these non optional.
pub tenant_id: Option<TenantId>,
pub timeline_id: Option<TimelineId>,
pub pageserver_connstring: Option<String>,
pub pageservers: Vec<Pageserver>,
// More neon ids that we expose to the compute_ctl
// and to postgres as neon extension GUCs.
pub project_id: Option<String>,
pub branch_id: Option<String>,
pub endpoint_id: Option<String>,
#[serde(default)]
pub safekeepers_generation: Option<u32>,
/// Safekeeper membership config generation. It is put in
/// neon.safekeepers GUC and serves two purposes:
@@ -120,9 +173,18 @@ pub struct ComputeSpec {
/// Note: it could be SafekeeperGeneration, but this needs linking
/// compute_ctl with postgres_ffi.
#[serde(default)]
pub safekeepers_generation: Option<u32>,
#[serde(default)]
pub safekeeper_connstrings: Vec<String>,
pub safekeepers: Vec<Safekeeper>,
/// The Neon tenant ID. Exposed to Postgres as `neon.tenant_id`.
pub tenant_id: TenantId,
/// The Neon timeline ID. Exposed to Postgres as `neon.timeline_id`.
pub timeline_id: TimelineId,
/// The Neon project ID. Exposed to Postgres as `neon.project_id`.
pub project_id: ProjectId,
/// The Neon branch ID. Exposed to Postgres as `neon.branch_id`.
pub branch_id: BranchId,
/// The Neon endpoint ID. Exposed to Postgres as `neon.endpoint_id`.
pub endpoint_id: EndpointId,
#[serde(default)]
pub mode: ComputeMode,

View File

@@ -9,7 +9,7 @@ use utils::id::{NodeId, TimelineId};
use crate::controller_api::NodeRegisterRequest;
use crate::models::{LocationConfigMode, ShardImportStatus};
use crate::shard::{ShardStripeSize, TenantShardId};
use crate::shard::TenantShardId;
/// Upcall message sent by the pageserver to the configured `control_plane_api` on
/// startup.
@@ -36,10 +36,6 @@ pub struct ReAttachResponseTenant {
/// Default value only for backward compat: this field should be set
#[serde(default = "default_mode")]
pub mode: LocationConfigMode,
// Default value only for backward compat: this field should be set
#[serde(default = "ShardStripeSize::default")]
pub stripe_size: ShardStripeSize,
}
#[derive(Serialize, Deserialize)]
pub struct ReAttachResponse {

View File

@@ -55,16 +55,9 @@ impl FeatureResolverBackgroundLoop {
continue;
}
};
let project_id = this.posthog_client.config.project_id.parse::<u64>().ok();
match FeatureStore::new_with_flags(resp.flags, project_id) {
Ok(feature_store) => {
this.feature_store.store(Arc::new(feature_store));
tracing::info!("Feature flag updated");
}
Err(e) => {
tracing::warn!("Cannot process feature flag spec: {}", e);
}
}
let feature_store = FeatureStore::new_with_flags(resp.flags);
this.feature_store.store(Arc::new(feature_store));
tracing::info!("Feature flag updated");
}
tracing::info!("PostHog feature resolver stopped");
}

View File

@@ -39,9 +39,6 @@ pub struct LocalEvaluationResponse {
#[derive(Deserialize)]
pub struct LocalEvaluationFlag {
#[allow(dead_code)]
id: u64,
team_id: u64,
key: String,
filters: LocalEvaluationFlagFilters,
active: bool,
@@ -110,32 +107,17 @@ impl FeatureStore {
}
}
pub fn new_with_flags(
flags: Vec<LocalEvaluationFlag>,
project_id: Option<u64>,
) -> Result<Self, &'static str> {
pub fn new_with_flags(flags: Vec<LocalEvaluationFlag>) -> Self {
let mut store = Self::new();
store.set_flags(flags, project_id)?;
Ok(store)
store.set_flags(flags);
store
}
pub fn set_flags(
&mut self,
flags: Vec<LocalEvaluationFlag>,
project_id: Option<u64>,
) -> Result<(), &'static str> {
pub fn set_flags(&mut self, flags: Vec<LocalEvaluationFlag>) {
self.flags.clear();
for flag in flags {
if let Some(project_id) = project_id {
if flag.team_id != project_id {
return Err(
"Retrieved a spec with different project id, wrong config? Discarding the feature flags.",
);
}
}
self.flags.insert(flag.key.clone(), flag);
}
Ok(())
}
/// Generate a consistent hash for a user ID (e.g., tenant ID).
@@ -552,13 +534,6 @@ impl PostHogClient {
})
}
/// Check if the server API key is a feature flag secure API key. This key can only be
/// used to fetch the feature flag specs and can only be used on a undocumented API
/// endpoint.
fn is_feature_flag_secure_api_key(&self) -> bool {
self.config.server_api_key.starts_with("phs_")
}
/// Fetch the feature flag specs from the server.
///
/// This is unfortunately an undocumented API at:
@@ -572,22 +547,10 @@ impl PostHogClient {
) -> anyhow::Result<LocalEvaluationResponse> {
// BASE_URL/api/projects/:project_id/feature_flags/local_evaluation
// with a bearer token of self.server_api_key
// OR
// BASE_URL/api/feature_flag/local_evaluation/
// with a bearer token of the feature-flag-specific self.server_api_key
let url = if self.is_feature_flag_secure_api_key() {
// The new feature local evaluation secure API token
format!(
"{}/api/feature_flag/local_evaluation",
self.config.private_api_url
)
} else {
// The old personal API token
format!(
"{}/api/projects/{}/feature_flags/local_evaluation",
self.config.private_api_url, self.config.project_id
)
};
let url = format!(
"{}/api/projects/{}/feature_flags/local_evaluation",
self.config.private_api_url, self.config.project_id
);
let response = self
.client
.get(url)
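
Put together, the surviving code path amounts to the following standalone sketch (reqwest assumed; the free-standing parameters are stand-ins for fields of PostHogClient):

async fn fetch_local_evaluation(
    client: &reqwest::Client,
    private_api_url: &str,
    project_id: &str,
    server_api_key: &str,
) -> anyhow::Result<LocalEvaluationResponse> {
    let url = format!(
        "{}/api/projects/{}/feature_flags/local_evaluation",
        private_api_url, project_id
    );
    let response = client
        .get(url)
        .bearer_auth(server_api_key) // the personal API token
        .send()
        .await?
        .error_for_status()?
        .json::<LocalEvaluationResponse>()
        .await?;
    Ok(response)
}
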
@@ -840,7 +803,7 @@ mod tests {
fn evaluate_multivariate() {
let mut store = FeatureStore::new();
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
store.set_flags(response.flags, None).unwrap();
store.set_flags(response.flags);
// This lacks the required properties and cannot be evaluated.
let variant =
@@ -910,7 +873,7 @@ mod tests {
let mut store = FeatureStore::new();
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
store.set_flags(response.flags, None).unwrap();
store.set_flags(response.flags);
// This lacks the required properties and cannot be evaluated.
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new());
@@ -966,7 +929,7 @@ mod tests {
let mut store = FeatureStore::new();
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
store.set_flags(response.flags, None).unwrap();
store.set_flags(response.flags);
// This lacks the required properties and cannot be evaluated.
let variant =

View File

@@ -5,7 +5,7 @@ edition = "2024"
license = "MIT/Apache-2.0"
[dependencies]
base64.workspace = true
base64 = "0.20"
byteorder.workspace = true
bytes.workspace = true
fallible-iterator.workspace = true
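
The version pin matters because the base64 crate changed its public API: 0.20 and earlier expose free functions at the crate root, while 0.21+ routes everything through an Engine, which is exactly the churn visible in the scram diffs below. Both halves shown together for comparison, though each compiles only against its own crate version:

fn base64_api_comparison() {
    // base64 <= 0.20: free functions at the crate root.
    let old_style = base64::encode(b"hello");
    let _ = base64::decode(&old_style).unwrap();

    // base64 >= 0.21: the same operations go through an Engine.
    use base64::Engine as _;
    use base64::prelude::BASE64_STANDARD;
    let new_style = BASE64_STANDARD.encode(b"hello");
    let _ = BASE64_STANDARD.decode(&new_style).unwrap();
}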

View File

@@ -3,8 +3,6 @@
use std::fmt::Write;
use std::{io, iter, mem, str};
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use rand::{self, Rng};
use sha2::digest::FixedOutput;
@@ -228,7 +226,7 @@ impl ScramSha256 {
let (client_key, server_key) = match password {
Credentials::Password(password) => {
let salt = match BASE64_STANDARD.decode(parsed.salt) {
let salt = match base64::decode(parsed.salt) {
Ok(salt) => salt,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};
@@ -257,7 +255,7 @@ impl ScramSha256 {
let mut cbind_input = vec![];
cbind_input.extend(channel_binding.gs2_header().as_bytes());
cbind_input.extend(channel_binding.cbind_data());
let cbind_input = BASE64_STANDARD.encode(&cbind_input);
let cbind_input = base64::encode(&cbind_input);
self.message.clear();
write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
@@ -274,12 +272,7 @@ impl ScramSha256 {
*proof ^= signature;
}
write!(
&mut self.message,
",p={}",
BASE64_STANDARD.encode(client_proof)
)
.unwrap();
write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
self.state = State::Finish {
server_key,
@@ -313,7 +306,7 @@ impl ScramSha256 {
ServerFinalMessage::Verifier(verifier) => verifier,
};
let verifier = match BASE64_STANDARD.decode(verifier) {
let verifier = match base64::decode(verifier) {
Ok(verifier) => verifier,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};

View File

@@ -6,8 +6,6 @@
//! side. This is good because it ensures the cleartext password won't
//! end up in logs, pg_stat displays, etc.
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::digest::FixedOutput;
@@ -85,8 +83,8 @@ pub(crate) async fn scram_sha_256_salt(
format!(
"SCRAM-SHA-256${}:{}${}:{}",
SCRAM_DEFAULT_ITERATIONS,
BASE64_STANDARD.encode(salt),
BASE64_STANDARD.encode(stored_key),
BASE64_STANDARD.encode(server_key)
base64::encode(salt),
base64::encode(stored_key),
base64::encode(server_key)
)
}
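
The string assembled here is the SCRAM verifier that ends up in pg_authid.rolpassword. Its layout, per the PostgreSQL documentation (values abbreviated for illustration, not real output):

// SCRAM-SHA-256$<iterations>:<base64(salt)>$<base64(StoredKey)>:<base64(ServerKey)>
//
// e.g. SCRAM-SHA-256$4096:c2FsdHNhbHQ=$c3RvcmVkLWtleQ==:c2VydmVyLWtleQ==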

View File

@@ -10,7 +10,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::{env, io};
use anyhow::{Context, Result, anyhow};
use anyhow::{Context, Result};
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
use azure_storage::StorageCredentials;
@@ -37,7 +37,6 @@ use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
use crate::{
ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
Version, VersionKind,
};
pub struct AzureBlobStorage {
@@ -406,39 +405,6 @@ impl AzureBlobStorage {
pub fn container_name(&self) -> &str {
&self.container_name
}
async fn list_versions_with_permit(
&self,
_permit: &tokio::sync::SemaphorePermit<'_>,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
// We do not return this info back to `VersionListing` yet.
builder = builder.include_deleted(true);
builder
};
let kind = RequestKind::ListVersions;
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
}
trait ListingCollector {
@@ -522,10 +488,27 @@ impl RemoteStorage for AzureBlobStorage {
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> std::result::Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
builder
};
let kind = RequestKind::ListVersions;
let permit = self.permit(kind, cancel).await?;
self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
.await
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
async fn head_object(
@@ -820,159 +803,14 @@ impl RemoteStorage for AzureBlobStorage {
async fn time_travel_recover(
&self,
prefix: Option<&RemotePath>,
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
_complexity_limit: Option<NonZeroU32>,
_prefix: Option<&RemotePath>,
_timestamp: SystemTime,
_done_if_after: SystemTime,
_cancel: &CancellationToken,
) -> Result<(), TimeTravelError> {
let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
.to_string()
+ "for some specific files. If a file gets deleted but then overwritten and we want to recover "
+ "to the time during the file was not present, this functionality will recover the file. Only "
+ "use the functionality for services that can tolerate this. For example, recovering a state of the "
+ "pageserver tenants.";
tracing::error!("{}", msg);
let kind = RequestKind::TimeTravel;
let permit = self.permit(kind, cancel).await?;
let mode = ListingMode::NoDelimiter;
let version_listing = self
.list_versions_with_permit(&permit, prefix, mode, None, cancel)
.await
.map_err(|err| match err {
DownloadError::Other(e) => TimeTravelError::Other(e),
DownloadError::Cancelled => TimeTravelError::Cancelled,
other => TimeTravelError::Other(other.into()),
})?;
let versions_and_deletes = version_listing.versions;
tracing::info!(
"Built list for time travel with {} versions and deletions",
versions_and_deletes.len()
);
// Work on the list of references instead of the objects directly,
// otherwise we get lifetime errors in the sort_by_key call below.
let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
for vd in &versions_and_deletes {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"
)));
}
tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
vds_for_key.entry(key).or_default().push(vd);
}
let warn_threshold = 3;
let max_retries = 10;
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
for (key, versions) in vds_for_key {
let last_vd = versions.last().unwrap();
let key = self.relative_path_to_name(key);
if last_vd.last_modified > done_if_after {
tracing::debug!("Key {key} has version later than done_if_after, skipping");
continue;
}
// the version we want to restore to.
let version_to_restore_to =
match versions.binary_search_by_key(&timestamp, |tpl| tpl.last_modified) {
Ok(v) => v,
Err(e) => e,
};
if version_to_restore_to == versions.len() {
tracing::debug!("Key {key} has no changes since timestamp, skipping");
continue;
}
let mut do_delete = false;
if version_to_restore_to == 0 {
// All versions more recent, so the key didn't exist at the specified time point.
tracing::debug!(
"All {} versions more recent for {key}, deleting",
versions.len()
);
do_delete = true;
} else {
match &versions[version_to_restore_to - 1] {
Version {
kind: VersionKind::Version(version_id),
..
} => {
let source_url = format!(
"{}/{}?versionid={}",
self.client
.url()
.map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
key,
version_id.0
);
tracing::debug!(
"Promoting old version {} for {key} at {}...",
version_id.0,
source_url
);
backoff::retry(
|| async {
let blob_client = self.client.blob_client(key.clone());
let op = blob_client.copy(Url::from_str(&source_url).unwrap());
tokio::select! {
res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
_ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
}
},
is_permanent,
warn_threshold,
max_retries,
"copying object version for time_travel_recover",
cancel,
)
.await
.ok_or_else(|| TimeTravelError::Cancelled)
.and_then(|x| x)?;
tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
}
Version {
kind: VersionKind::DeletionMarker,
..
} => {
do_delete = true;
}
}
};
if do_delete {
if matches!(last_vd.kind, VersionKind::DeletionMarker) {
// Key has since been deleted (but there was some history), no need to do anything
tracing::debug!("Key {key} already deleted, skipping.");
} else {
tracing::debug!("Deleting {key}...");
self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
.await
.map_err(|e| {
// delete_oid0 will use TimeoutOrCancel
if TimeoutOrCancel::caused_by_cancel(&e) {
TimeTravelError::Cancelled
} else {
TimeTravelError::Other(e)
}
})?;
}
}
}
Ok(())
// TODO use Azure point in time recovery feature for this
// https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
Err(TimeTravelError::Unimplemented)
}
}
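
For context, the heart of the removed recovery logic: per key, with versions sorted by modification time, find the newest version at or before the target timestamp; if none exists the key must be deleted (it did not exist yet), and if the chosen entry is a deletion marker the key is deleted as well. A simplified sketch with the types reduced to what the decision needs:

use std::time::SystemTime;

enum Action {
    PromoteVersion(String), // copy this version id back over the live blob
    Delete,                 // key did not exist, or was deleted, at `timestamp`
    Skip,                   // no changes since `timestamp`
}

fn restore_action(versions: &[Version], timestamp: SystemTime) -> Action {
    // Index of the first version strictly newer than `timestamp`.
    let idx = versions.partition_point(|v| v.last_modified <= timestamp);
    if idx == versions.len() {
        return Action::Skip; // nothing changed after the target time
    }
    if idx == 0 {
        return Action::Delete; // every version is newer than the target time
    }
    match &versions[idx - 1].kind {
        VersionKind::Version(id) => Action::PromoteVersion(id.0.clone()),
        VersionKind::DeletionMarker => Action::Delete,
    }
}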

View File

@@ -440,7 +440,6 @@ pub trait RemoteStorage: Send + Sync + 'static {
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError>;
}
@@ -652,23 +651,22 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError> {
match self {
Self::LocalFs(s) => {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
Self::AwsS3(s) => {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
Self::AzureBlob(s) => {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
Self::Unreliable(s) => {
s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
s.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
}

View File

@@ -610,7 +610,6 @@ impl RemoteStorage for LocalFs {
_timestamp: SystemTime,
_done_if_after: SystemTime,
_cancel: &CancellationToken,
_complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError> {
Err(TimeTravelError::Unimplemented)
}

View File

@@ -981,16 +981,22 @@ impl RemoteStorage for S3Bucket {
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError> {
let kind = RequestKind::TimeTravel;
let permit = self.permit(kind, cancel).await?;
tracing::trace!("Target time: {timestamp:?}, done_if_after {done_if_after:?}");
// Limit the number of version deletions, mostly so that we don't
// keep requesting forever if the list is too long, as we'd hold the
// whole list in RAM.
// Building a list of 100k entries that reaches the limit takes roughly
// 40 seconds, and corresponds to tenants of about 2 TiB physical size.
const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000);
let mode = ListingMode::NoDelimiter;
let version_listing = self
.list_versions_with_permit(&permit, prefix, mode, complexity_limit, cancel)
.list_versions_with_permit(&permit, prefix, mode, COMPLEXITY_LIMIT, cancel)
.await
.map_err(|err| match err {
DownloadError::Other(e) => TimeTravelError::Other(e),
@@ -1016,7 +1022,6 @@ impl RemoteStorage for S3Bucket {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
// TODO: check the behavior of using the SDK on a non-versioned container
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"

View File

@@ -240,12 +240,11 @@ impl RemoteStorage for UnreliableWrapper {
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
complexity_limit: Option<NonZeroU32>,
) -> Result<(), TimeTravelError> {
self.attempt(RemoteOp::TimeTravelRecover(prefix.map(|p| p.to_owned())))
.map_err(TimeTravelError::Other)?;
self.inner
.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit)
.time_travel_recover(prefix, timestamp, done_if_after, cancel)
.await
}
}

View File

@@ -157,7 +157,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
// No changes after recovery to t2 (no-op)
let t_final = time_point().await;
ctx.client
.time_travel_recover(None, t2, t_final, &cancel, None)
.time_travel_recover(None, t2, t_final, &cancel)
.await?;
let t2_files_recovered = list_files(&ctx.client, &cancel).await?;
println!("after recovery to t2: {t2_files_recovered:?}");
@@ -173,7 +173,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
// after recovery to t1: path1 is back, path2 has the old content
let t_final = time_point().await;
ctx.client
.time_travel_recover(None, t1, t_final, &cancel, None)
.time_travel_recover(None, t1, t_final, &cancel)
.await?;
let t1_files_recovered = list_files(&ctx.client, &cancel).await?;
println!("after recovery to t1: {t1_files_recovered:?}");
@@ -189,7 +189,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
// after recovery to t0: everything is gone except for path1
let t_final = time_point().await;
ctx.client
.time_travel_recover(None, t0, t_final, &cancel, None)
.time_travel_recover(None, t0, t_final, &cancel)
.await?;
let t0_files_recovered = list_files(&ctx.client, &cancel).await?;
println!("after recovery to t0: {t0_files_recovered:?}");

View File

@@ -295,7 +295,11 @@ pub struct TenantId(Id);
id_newtype!(TenantId);
/// If needed, reuse small string from proxy/src/types.rs
/// Type representing a project ID.
pub type ProjectId = String;
/// Type representing a branch ID.
pub type BranchId = String;
/// Type representing an endpoint ID.
pub type EndpointId = String;
// A pair uniquely identifying Neon instance.

View File

@@ -176,11 +176,9 @@ async fn main() -> anyhow::Result<()> {
let config = RemoteStorageConfig::from_toml_str(&cmd.config_toml_str)?;
let storage = remote_storage::GenericRemoteStorage::from_config(&config).await;
let cancel = CancellationToken::new();
// Complexity limit: as we are running this command locally, we should have a lot of memory available, and we do not
// need to limit the number of versions we are going to delete.
storage
.unwrap()
.time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel, None)
.time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel)
.await?;
}
Commands::Key(dkc) => dkc.execute(),

View File

@@ -1,6 +1,5 @@
use std::{collections::HashMap, sync::Arc};
use anyhow::Context;
use async_compression::tokio::write::GzipEncoder;
use camino::{Utf8Path, Utf8PathBuf};
use metrics::core::{AtomicU64, GenericCounter};
@@ -168,17 +167,14 @@ impl BasebackupCache {
.join(Self::entry_filename(tenant_id, timeline_id, lsn))
}
fn tmp_dir(&self) -> Utf8PathBuf {
self.data_dir.join("tmp")
}
fn entry_tmp_path(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Utf8PathBuf {
self.tmp_dir()
self.data_dir
.join("tmp")
.join(Self::entry_filename(tenant_id, timeline_id, lsn))
}
@@ -198,18 +194,15 @@ impl BasebackupCache {
Some((tenant_id, timeline_id, lsn))
}
// Recreate the tmp directory to clear all files in it.
async fn clean_tmp_dir(&self) -> anyhow::Result<()> {
let tmp_dir = self.tmp_dir();
if tmp_dir.exists() {
tokio::fs::remove_dir_all(&tmp_dir).await?;
}
tokio::fs::create_dir_all(&tmp_dir).await?;
Ok(())
}
async fn cleanup(&self) -> anyhow::Result<()> {
self.clean_tmp_dir().await?;
// Cleanup tmp directory.
let tmp_dir = self.data_dir.join("tmp");
let mut tmp_dir = tokio::fs::read_dir(&tmp_dir).await?;
while let Some(dir_entry) = tmp_dir.next_entry().await? {
if let Err(e) = tokio::fs::remove_file(dir_entry.path()).await {
tracing::warn!("Failed to remove basebackup cache tmp file: {:#}", e);
}
}
// Remove outdated entries.
let entries_old = self.entries.lock().unwrap().clone();
@@ -248,14 +241,16 @@ impl BasebackupCache {
}
async fn on_startup(&self) -> anyhow::Result<()> {
// Create data_dir if it does not exist.
tokio::fs::create_dir_all(&self.data_dir)
// Create data_dir and tmp directory if they do not exist.
tokio::fs::create_dir_all(&self.data_dir.join("tmp"))
.await
.context("Failed to create basebackup cache data directory")?;
self.clean_tmp_dir()
.await
.context("Failed to clean tmp directory")?;
.map_err(|e| {
anyhow::anyhow!(
"Failed to create basebackup cache data_dir {:?}: {:?}",
self.data_dir,
e
)
})?;
// Read existing entries from the data_dir and add them to in-memory state.
let mut entries = HashMap::new();
@@ -413,19 +408,6 @@ impl BasebackupCache {
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id)?;
let feature_flag = tenant
.feature_resolver
.evaluate_boolean("enable-basebackup-cache", tenant_shard_id.tenant_id);
if feature_flag.is_err() {
tracing::info!(
tenant_id = %tenant_shard_id.tenant_id,
"Basebackup cache is disabled for tenant by feature flag, skipping basebackup",
);
self.prepare_skip_count.inc();
return Ok(());
}
let tenant_state = tenant.current_state();
if tenant_state != TenantState::Active {
anyhow::bail!(
@@ -469,11 +451,6 @@ impl BasebackupCache {
}
// Move the tmp file to the final location atomically.
// The tmp file is fsynced, so it's guaranteed that we will not have a partial file
// in the main directory.
// It's not necessary to fsync the inode after renaming: in the worst case, the rename
// is rolled back on a disk failure, the entry disappears from the main directory, and
// the next access to it is simply a cache miss.
let entry_path = self.entry_path(tenant_shard_id.tenant_id, timeline_id, req_lsn);
tokio::fs::rename(&entry_tmp_path, &entry_path).await?;
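
The write-to-tmp, fsync, rename sequence described in those comments, in miniature (tokio and camino assumed; the helper is a hypothetical distillation, not the actual cache code):

use camino::Utf8Path;
use tokio::io::AsyncWriteExt;

async fn publish_atomically(
    tmp_path: &Utf8Path,
    final_path: &Utf8Path,
    bytes: &[u8],
) -> std::io::Result<()> {
    let mut file = tokio::fs::File::create(tmp_path).await?;
    file.write_all(bytes).await?;
    // fsync the tmp file so the rename can never expose a partial entry.
    file.sync_all().await?;
    // Worst case on crash: the rename is lost and the lookup is a cache miss.
    tokio::fs::rename(tmp_path, final_path).await
}
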
@@ -491,17 +468,16 @@ impl BasebackupCache {
}
/// Prepares a basebackup in a temporary file.
/// Guarantees that the tmp file is fsynced before returning.
async fn prepare_basebackup_tmp(
&self,
entry_tmp_path: &Utf8Path,
emptry_tmp_path: &Utf8Path,
timeline: &Arc<Timeline>,
req_lsn: Lsn,
) -> anyhow::Result<()> {
let ctx = RequestContext::new(TaskKind::BasebackupCache, DownloadBehavior::Download);
let ctx = ctx.with_scope_timeline(timeline);
let file = tokio::fs::File::create(entry_tmp_path).await?;
let file = tokio::fs::File::create(emptry_tmp_path).await?;
let mut writer = BufWriter::new(file);
let mut encoder = GzipEncoder::with_quality(

View File

@@ -573,8 +573,7 @@ fn start_pageserver(
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
let tenant_manager = mgr::init(
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
conf,
background_purges.clone(),
TenantSharedResources {
@@ -585,10 +584,10 @@ fn start_pageserver(
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
);
))?;
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),

View File

@@ -1,6 +1,5 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
@@ -87,35 +86,7 @@ impl FeatureResolver {
}
}
}
// TODO: move this to a background task so that we don't block startup in case of slow disk
let metadata_path = conf.metadata_path();
match std::fs::read_to_string(&metadata_path) {
Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) {
Ok(metadata) => {
properties.insert(
"hostname".to_string(),
PostHogFlagFilterPropertyValue::String(metadata.http_host),
);
if let Some(cplane_region) = metadata.other.get("region_id") {
if let Some(cplane_region) = cplane_region.as_str() {
// This region contains the cell number
properties.insert(
"neon_region".to_string(),
PostHogFlagFilterPropertyValue::String(
cplane_region.to_string(),
),
);
}
}
}
Err(e) => {
tracing::warn!("Failed to parse metadata.json: {}", e);
}
},
Err(e) => {
tracing::warn!("Failed to read metadata.json: {}", e);
}
}
// TODO: add pageserver URL.
Arc::new(properties)
};
let fake_tenants = {

View File

@@ -73,7 +73,6 @@ use crate::tenant::remote_timeline_client::{
use crate::tenant::secondary::SecondaryController;
use crate::tenant::size::ModelInputs;
use crate::tenant::storage_layer::{IoConcurrency, LayerAccessStatsReset, LayerName};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
use crate::tenant::timeline::offload::{OffloadError, offload_timeline};
use crate::tenant::timeline::{
CompactFlags, CompactOptions, CompactRequest, CompactionError, MarkInvisibleRequest, Timeline,
@@ -1452,10 +1451,7 @@ async fn timeline_layer_scan_disposable_keys(
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download)
.with_scope_timeline(&timeline);
let guard = timeline
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = timeline.layers.read().await;
let Some(layer) = guard.try_get_from_key(&layer_name.clone().into()) else {
return Err(ApiError::NotFound(
anyhow::anyhow!("Layer {tenant_shard_id}/{timeline_id}/{layer_name} not found").into(),

View File

@@ -51,7 +51,6 @@ use secondary::heatmap::{HeatMapTenant, HeatMapTimeline};
use storage_broker::BrokerClientChannel;
use timeline::compaction::{CompactionOutcome, GcCompactionQueue};
use timeline::import_pgdata::ImportingTimeline;
use timeline::layer_manager::LayerManagerLockHolder;
use timeline::offload::{OffloadError, offload_timeline};
use timeline::{
CompactFlags, CompactOptions, CompactionError, PreviousHeatmap, ShutdownMode, import_pgdata,
@@ -1316,7 +1315,7 @@ impl TenantShard {
ancestor.is_some()
|| timeline
.layers
.read(LayerManagerLockHolder::LoadLayerMap)
.read()
.await
.layer_map()
.expect(
@@ -2644,7 +2643,7 @@ impl TenantShard {
}
let layer_names = tline
.layers
.read(LayerManagerLockHolder::Testing)
.read()
.await
.layer_map()
.unwrap()
@@ -3159,12 +3158,7 @@ impl TenantShard {
for timeline in &compact {
// Collect L0 counts. Can't await while holding lock above.
if let Ok(lm) = timeline
.layers
.read(LayerManagerLockHolder::Compaction)
.await
.layer_map()
{
if let Ok(lm) = timeline.layers.read().await.layer_map() {
l0_counts.insert(timeline.timeline_id, lm.level0_deltas().len());
}
}
@@ -4906,7 +4900,7 @@ impl TenantShard {
}
let layer_names = tline
.layers
.read(LayerManagerLockHolder::Testing)
.read()
.await
.layer_map()
.unwrap()
@@ -6976,7 +6970,7 @@ mod tests {
.await?;
make_some_layers(tline.as_ref(), Lsn(0x20), &ctx).await?;
let layer_map = tline.layers.read(LayerManagerLockHolder::Testing).await;
let layer_map = tline.layers.read().await;
let level0_deltas = layer_map
.layer_map()?
.level0_deltas()
@@ -7212,7 +7206,7 @@ mod tests {
let lsn = Lsn(0x10);
let inserted = bulk_insert_compact_gc(&tenant, &tline, &ctx, lsn, 50, 10000).await?;
let guard = tline.layers.read(LayerManagerLockHolder::Testing).await;
let guard = tline.layers.read().await;
let lm = guard.layer_map()?;
lm.dump(true, &ctx).await?;
@@ -8240,23 +8234,12 @@ mod tests {
tline.freeze_and_flush().await?; // force create a delta layer
}
let before_num_l0_delta_files = tline
.layers
.read(LayerManagerLockHolder::Testing)
.await
.layer_map()?
.level0_deltas()
.len();
let before_num_l0_delta_files =
tline.layers.read().await.layer_map()?.level0_deltas().len();
tline.compact(&cancel, EnumSet::default(), &ctx).await?;
let after_num_l0_delta_files = tline
.layers
.read(LayerManagerLockHolder::Testing)
.await
.layer_map()?
.level0_deltas()
.len();
let after_num_l0_delta_files = tline.layers.read().await.layer_map()?.level0_deltas().len();
assert!(
after_num_l0_delta_files < before_num_l0_delta_files,

View File

@@ -61,8 +61,8 @@ pub(crate) struct LocationConf {
/// The detailed shard identity. This structure is already scoped within
/// a TenantShardId, but we need the full ShardIdentity to enable calculating
/// key->shard mappings.
// TODO(vlad): Remove this default once all configs have a shard identity on disk.
#[serde(default = "ShardIdentity::unsharded")]
#[serde(skip_serializing_if = "ShardIdentity::is_unsharded")]
pub(crate) shard: ShardIdentity,
/// The pan-cluster tenant configuration, the same on all locations
@@ -149,12 +149,7 @@ impl LocationConf {
/// For use when attaching/re-attaching: update the generation stored in this
/// structure. If we were in a secondary state, promote to attached (possession
/// of a fresh generation implies this).
pub(crate) fn attach_in_generation(
&mut self,
mode: AttachmentMode,
generation: Generation,
stripe_size: ShardStripeSize,
) {
pub(crate) fn attach_in_generation(&mut self, mode: AttachmentMode, generation: Generation) {
match &mut self.mode {
LocationMode::Attached(attach_conf) => {
attach_conf.generation = generation;
@@ -168,8 +163,6 @@ impl LocationConf {
})
}
}
self.shard.stripe_size = stripe_size;
}
pub(crate) fn try_from(conf: &'_ models::LocationConfig) -> anyhow::Result<Self> {

File diff suppressed because it is too large

View File

@@ -1,7 +1,6 @@
//! Helper functions to upload files to remote storage with a RemoteStorage
use std::io::{ErrorKind, SeekFrom};
use std::num::NonZeroU32;
use std::time::SystemTime;
use anyhow::{Context, bail};
@@ -229,25 +228,11 @@ pub(crate) async fn time_travel_recover_tenant(
let timelines_path = super::remote_timelines_path(tenant_shard_id);
prefixes.push(timelines_path);
}
// Limit the number of version deletions, mostly so that we don't
// keep requesting forever if the list is too long, as we'd hold the
// whole list in RAM.
// Building a list of 100k entries that reaches the limit takes roughly
// 40 seconds, and corresponds to tenants of about 2 TiB physical size.
const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000);
for prefix in &prefixes {
backoff::retry(
|| async {
storage
.time_travel_recover(
Some(prefix),
timestamp,
done_if_after,
cancel,
COMPLEXITY_LIMIT,
)
.time_travel_recover(Some(prefix), timestamp, done_if_after, cancel)
.await
},
|e| !matches!(e, TimeTravelError::Other(_)),

View File

@@ -1635,7 +1635,6 @@ pub(crate) mod test {
use crate::tenant::disk_btree::tests::TestDisk;
use crate::tenant::harness::{TIMELINE_ID, TenantHarness};
use crate::tenant::storage_layer::{Layer, ResidentLayer};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
use crate::tenant::{TenantShard, Timeline};
/// Construct an index for a fictional delta layer and then
@@ -2003,7 +2002,7 @@ pub(crate) mod test {
let initdb_layer = timeline
.layers
.read(crate::tenant::timeline::layer_manager::LayerManagerLockHolder::Testing)
.read()
.await
.likely_resident_layers()
.next()
@@ -2079,7 +2078,7 @@ pub(crate) mod test {
let new_layer = timeline
.layers
.read(LayerManagerLockHolder::Testing)
.read()
.await
.likely_resident_layers()
.find(|&x| x != &initdb_layer)

View File

@@ -10,7 +10,6 @@ use super::*;
use crate::context::DownloadBehavior;
use crate::tenant::harness::{TenantHarness, test_img};
use crate::tenant::storage_layer::{IoConcurrency, LayerVisibilityHint};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
/// Used in tests to advance a future to the wanted await point, and not further.
const ADVANCE: std::time::Duration = std::time::Duration::from_secs(3600);
@@ -60,7 +59,7 @@ async fn smoke_test() {
// there to avoid the timeline being illegally empty
let (layer, dummy_layer) = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -216,7 +215,7 @@ async fn smoke_test() {
// Simulate GC removing our test layer.
{
let mut g = timeline.layers.write(LayerManagerLockHolder::Testing).await;
let mut g = timeline.layers.write().await;
let layers = &[layer];
g.open_mut().unwrap().finish_gc_timeline(layers);
@@ -262,7 +261,7 @@ async fn evict_and_wait_on_wanted_deleted() {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -306,7 +305,7 @@ async fn evict_and_wait_on_wanted_deleted() {
// assert that once we remove the `layer` from the layer map and drop our reference,
// the deletion of the layer in remote_storage happens.
{
let mut layers = timeline.layers.write(LayerManagerLockHolder::Testing).await;
let mut layers = timeline.layers.write().await;
layers.open_mut().unwrap().finish_gc_timeline(&[layer]);
}
@@ -348,7 +347,7 @@ fn read_wins_pending_eviction() {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -481,7 +480,7 @@ fn multiple_pending_evictions_scenario(name: &'static str, in_order: bool) {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -656,7 +655,7 @@ async fn cancelled_get_or_maybe_download_does_not_cancel_eviction() {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -742,7 +741,7 @@ async fn evict_and_wait_does_not_wait_for_download() {
let layer = {
let mut layers = {
let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let layers = timeline.layers.read().await;
layers.likely_resident_layers().cloned().collect::<Vec<_>>()
};
@@ -863,7 +862,7 @@ async fn eviction_cancellation_on_drop() {
let (evicted_layer, not_evicted) = {
let mut layers = {
let mut guard = timeline.layers.write(LayerManagerLockHolder::Testing).await;
let mut guard = timeline.layers.write().await;
let layers = guard.likely_resident_layers().cloned().collect::<Vec<_>>();
// remove the layers from layermap
guard.open_mut().unwrap().finish_gc_timeline(&layers);

View File

@@ -35,11 +35,7 @@ use fail::fail_point;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use handle::ShardTimelineId;
use layer_manager::{
LayerManagerLockHolder, LayerManagerReadGuard, LayerManagerWriteGuard, LockedLayerManager,
Shutdown,
};
use layer_manager::Shutdown;
use offload::OffloadError;
use once_cell::sync::Lazy;
use pageserver_api::config::tenant_conf_defaults::DEFAULT_PITR_INTERVAL;
@@ -86,6 +82,7 @@ use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta};
use self::delete::DeleteTimelineFlow;
pub(super) use self::eviction_task::EvictionTaskTenantState;
use self::eviction_task::EvictionTaskTimelineState;
use self::layer_manager::LayerManager;
use self::logical_size::LogicalSize;
use self::walreceiver::{WalReceiver, WalReceiverConf};
use super::remote_timeline_client::RemoteTimelineClient;
@@ -184,13 +181,13 @@ impl std::fmt::Display for ImageLayerCreationMode {
/// Temporary function for the immutable storage state refactor; ensures we drop the mutex guard and not something else.
/// Can be removed after all refactors are done.
fn drop_layer_manager_rlock(rlock: LayerManagerReadGuard<'_>) {
fn drop_rlock<T>(rlock: tokio::sync::RwLockReadGuard<T>) {
drop(rlock)
}
/// Temporary function for the immutable storage state refactor; ensures we drop the mutex guard and not something else.
/// Can be removed after all refactors are done.
fn drop_layer_manager_wlock(rlock: LayerManagerWriteGuard<'_>) {
fn drop_wlock<T>(rlock: tokio::sync::RwLockWriteGuard<'_, T>) {
drop(rlock)
}
@@ -244,7 +241,7 @@ pub struct Timeline {
///
/// In the future, we'll be able to split up the tuple of LayerMap and `LayerFileManager`,
/// so that e.g. on-demand-download/eviction, and layer spreading, can operate just on `LayerFileManager`.
pub(crate) layers: LockedLayerManager,
pub(crate) layers: tokio::sync::RwLock<LayerManager>,
last_freeze_at: AtomicLsn,
// Atomic would be more appropriate here.
@@ -1538,10 +1535,7 @@ impl Timeline {
/// This method makes no distinction between local and remote layers.
/// Hence, the result **does not represent local filesystem usage**.
pub(crate) async fn layer_size_sum(&self) -> u64 {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
guard.layer_size_sum()
}
@@ -1851,7 +1845,7 @@ impl Timeline {
// time, and this was missed.
// if write_guard.is_none() { return; }
let Ok(layers_guard) = self.layers.try_read(LayerManagerLockHolder::TryFreezeLayer) else {
let Ok(layers_guard) = self.layers.try_read() else {
// Don't block if the layer lock is busy
return;
};
@@ -2164,7 +2158,7 @@ impl Timeline {
if let ShutdownMode::FreezeAndFlush = mode {
let do_flush = if let Some((open, frozen)) = self
.layers
.read(LayerManagerLockHolder::Shutdown)
.read()
.await
.layer_map()
.map(|lm| (lm.open_layer.is_some(), lm.frozen_layers.len()))
@@ -2268,10 +2262,7 @@ impl Timeline {
// Allow any remaining in-memory layers to do cleanup -- until that, they hold the gate
// open.
let mut write_guard = self.write_lock.lock().await;
self.layers
.write(LayerManagerLockHolder::Shutdown)
.await
.shutdown(&mut write_guard);
self.layers.write().await.shutdown(&mut write_guard);
}
// Finally wait until any gate-holders are complete.
@@ -2374,10 +2365,7 @@ impl Timeline {
&self,
reset: LayerAccessStatsReset,
) -> Result<LayerMapInfo, layer_manager::Shutdown> {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
let layer_map = guard.layer_map()?;
let mut in_memory_layers = Vec::with_capacity(layer_map.frozen_layers.len() + 1);
if let Some(open_layer) = &layer_map.open_layer {
@@ -3244,7 +3232,7 @@ impl Timeline {
/// Initialize with an empty layer map. Used when creating a new timeline.
pub(super) fn init_empty_layer_map(&self, start_lsn: Lsn) {
let mut layers = self.layers.try_write(LayerManagerLockHolder::Init).expect(
let mut layers = self.layers.try_write().expect(
"in the context where we call this function, no other task has access to the object",
);
layers
@@ -3264,10 +3252,7 @@ impl Timeline {
use init::Decision::*;
use init::{Discovered, DismissedLayer};
let mut guard = self
.layers
.write(LayerManagerLockHolder::LoadLayerMap)
.await;
let mut guard = self.layers.write().await;
let timer = self.metrics.load_layer_map_histo.start_timer();
@@ -3884,10 +3869,7 @@ impl Timeline {
&self,
layer_name: &LayerName,
) -> Result<Option<Layer>, layer_manager::Shutdown> {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
let layer = guard
.layer_map()?
.iter_historic_layers()
@@ -3920,10 +3902,7 @@ impl Timeline {
return None;
}
let guard = self
.layers
.read(LayerManagerLockHolder::GenerateHeatmap)
.await;
let guard = self.layers.read().await;
// Firstly, if there's any heatmap left over from when this location
// was a secondary, take that into account. Keep layers that are:
@@ -4021,10 +4000,7 @@ impl Timeline {
}
pub(super) async fn generate_unarchival_heatmap(&self, end_lsn: Lsn) -> PreviousHeatmap {
let guard = self
.layers
.read(LayerManagerLockHolder::GenerateHeatmap)
.await;
let guard = self.layers.read().await;
let now = SystemTime::now();
let mut heatmap_layers = Vec::default();
@@ -4366,7 +4342,7 @@ impl Timeline {
query: &VersionedKeySpaceQuery,
) -> Result<LayerFringe, GetVectoredError> {
let mut fringe = LayerFringe::new();
let guard = self.layers.read(LayerManagerLockHolder::GetPage).await;
let guard = self.layers.read().await;
match query {
VersionedKeySpaceQuery::Uniform { keyspace, lsn } => {
@@ -4469,7 +4445,7 @@ impl Timeline {
// required for correctness, but avoids visiting extra layers
// which turns out to be a perf bottleneck in some cases.
if !unmapped_keyspace.is_empty() {
let guard = timeline.layers.read(LayerManagerLockHolder::GetPage).await;
let guard = timeline.layers.read().await;
guard.update_search_fringe(&unmapped_keyspace, cont_lsn, &mut fringe)?;
// It's safe to drop the layer map lock after planning the next round of reads.
@@ -4579,10 +4555,7 @@ impl Timeline {
_guard: &tokio::sync::MutexGuard<'_, Option<TimelineWriterState>>,
ctx: &RequestContext,
) -> anyhow::Result<Arc<InMemoryLayer>> {
let mut guard = self
.layers
.write(LayerManagerLockHolder::GetLayerForWrite)
.await;
let mut guard = self.layers.write().await;
let last_record_lsn = self.get_last_record_lsn();
ensure!(
@@ -4624,10 +4597,7 @@ impl Timeline {
write_lock: &mut tokio::sync::MutexGuard<'_, Option<TimelineWriterState>>,
) -> Result<u64, FlushLayerError> {
let frozen = {
let mut guard = self
.layers
.write(LayerManagerLockHolder::TryFreezeLayer)
.await;
let mut guard = self.layers.write().await;
guard
.open_mut()?
.try_freeze_in_memory_layer(at, &self.last_freeze_at, write_lock, &self.metrics)
@@ -4668,12 +4638,7 @@ impl Timeline {
ctx: &RequestContext,
) {
// Subscribe to L0 delta layer updates, for compaction backpressure.
let mut watch_l0 = match self
.layers
.read(LayerManagerLockHolder::FlushLoop)
.await
.layer_map()
{
let mut watch_l0 = match self.layers.read().await.layer_map() {
Ok(lm) => lm.watch_level0_deltas(),
Err(Shutdown) => return,
};
@@ -4710,7 +4675,7 @@ impl Timeline {
// Fetch the next layer to flush, if any.
let (layer, l0_count, frozen_count, frozen_size) = {
let layers = self.layers.read(LayerManagerLockHolder::FlushLoop).await;
let layers = self.layers.read().await;
let Ok(lm) = layers.layer_map() else {
info!("dropping out of flush loop for timeline shutdown");
return;
@@ -5006,10 +4971,7 @@ impl Timeline {
// in-memory layer from the map now. The flushed layer is stored in
// the mapping in `create_delta_layer`.
{
let mut guard = self
.layers
.write(LayerManagerLockHolder::FlushFrozenLayer)
.await;
let mut guard = self.layers.write().await;
guard.open_mut()?.finish_flush_l0_layer(
delta_layer_to_add.as_ref(),
@@ -5224,7 +5186,7 @@ impl Timeline {
async fn time_for_new_image_layer(&self, partition: &KeySpace, lsn: Lsn) -> bool {
let threshold = self.get_image_creation_threshold();
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
let guard = self.layers.read().await;
let Ok(layers) = guard.layer_map() else {
return false;
};
@@ -5642,7 +5604,7 @@ impl Timeline {
if let ImageLayerCreationMode::Force = mode {
// When forced to create image layers, we might try and create them where they already
// exist. This mode is only used in tests/debug.
let layers = self.layers.read(LayerManagerLockHolder::Compaction).await;
let layers = self.layers.read().await;
if layers.contains_key(&PersistentLayerKey {
key_range: img_range.clone(),
lsn_range: PersistentLayerDesc::image_layer_lsn_range(lsn),
@@ -5767,7 +5729,7 @@ impl Timeline {
let image_layers = batch_image_writer.finish(self, ctx).await?;
let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await;
let mut guard = self.layers.write().await;
// FIXME: we could add the images to be uploaded *before* returning from here, but right
// now they are being scheduled outside of write lock; current way is inconsistent with
@@ -5775,7 +5737,7 @@ impl Timeline {
guard
.open_mut()?
.track_new_image_layers(&image_layers, &self.metrics);
drop_layer_manager_wlock(guard);
drop_wlock(guard);
let duration = timer.stop_and_record();
// Creating image layers may have caused some previously visible layers to be covered
@@ -6145,7 +6107,7 @@ impl Timeline {
layers_to_remove: &[Layer],
) -> Result<(), CompactionError> {
let mut guard = tokio::select! {
guard = self.layers.write(LayerManagerLockHolder::Compaction) => guard,
guard = self.layers.write() => guard,
_ = self.cancel.cancelled() => {
return Err(CompactionError::ShuttingDown);
}
@@ -6194,7 +6156,7 @@ impl Timeline {
self.remote_client
.schedule_compaction_update(&remove_layers, new_deltas)?;
drop_layer_manager_wlock(guard);
drop_wlock(guard);
Ok(())
}
@@ -6204,7 +6166,7 @@ impl Timeline {
mut replace_layers: Vec<(Layer, ResidentLayer)>,
mut drop_layers: Vec<Layer>,
) -> Result<(), CompactionError> {
let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await;
let mut guard = self.layers.write().await;
// Trim our lists in case our caller (compaction) raced with someone else (GC) removing layers: we want
// to avoid double-removing, and avoid rewriting something that was removed.
@@ -6555,10 +6517,7 @@ impl Timeline {
// 5. newer on-disk image layers cover the layer's whole key range
//
// TODO holding a write lock is too aggressive and avoidable
let mut guard = self
.layers
.write(LayerManagerLockHolder::GarbageCollection)
.await;
let mut guard = self.layers.write().await;
let layers = guard.layer_map()?;
'outer: for l in layers.iter_historic_layers() {
result.layers_total += 1;
@@ -6860,10 +6819,7 @@ impl Timeline {
use pageserver_api::models::DownloadRemoteLayersTaskState;
let remaining = {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
let Ok(lm) = guard.layer_map() else {
// technically here we could look into iterating accessible layers, but downloading
// all layers of a shutdown timeline makes no sense regardless.
@@ -6969,7 +6925,7 @@ impl Timeline {
impl Timeline {
/// Returns non-remote layers for eviction.
pub(crate) async fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo {
let guard = self.layers.read(LayerManagerLockHolder::Eviction).await;
let guard = self.layers.read().await;
let mut max_layer_size: Option<u64> = None;
let resident_layers = guard
@@ -7070,7 +7026,7 @@ impl Timeline {
let image_layer = Layer::finish_creating(self.conf, self, desc, &path)?;
info!("force created image layer {}", image_layer.local_path());
{
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
let mut guard = self.layers.write().await;
guard
.open_mut()
.unwrap()
@@ -7133,7 +7089,7 @@ impl Timeline {
let delta_layer = Layer::finish_creating(self.conf, self, desc, &path)?;
info!("force created delta layer {}", delta_layer.local_path());
{
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
let mut guard = self.layers.write().await;
guard
.open_mut()
.unwrap()
@@ -7228,7 +7184,7 @@ impl Timeline {
// Link the layer to the layer map
{
let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await;
let mut guard = self.layers.write().await;
let layer_map = guard.open_mut().unwrap();
layer_map.force_insert_in_memory_layer(Arc::new(layer));
}
@@ -7245,7 +7201,7 @@ impl Timeline {
io_concurrency: IoConcurrency,
) -> anyhow::Result<Vec<(Key, Bytes)>> {
let mut all_data = Vec::new();
let guard = self.layers.read(LayerManagerLockHolder::Testing).await;
let guard = self.layers.read().await;
for layer in guard.layer_map()?.iter_historic_layers() {
if !layer.is_delta() && layer.image_layer_lsn() == lsn {
let layer = guard.get_from_desc(&layer);
@@ -7274,7 +7230,7 @@ impl Timeline {
self: &Arc<Timeline>,
) -> anyhow::Result<Vec<super::storage_layer::PersistentLayerKey>> {
let mut layers = Vec::new();
let guard = self.layers.read(LayerManagerLockHolder::Testing).await;
let guard = self.layers.read().await;
for layer in guard.layer_map()?.iter_historic_layers() {
layers.push(layer.key());
}
@@ -7386,7 +7342,7 @@ impl TimelineWriter<'_> {
let l0_count = self
.tl
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.read()
.await
.layer_map()?
.level0_deltas()
@@ -7605,7 +7561,6 @@ mod tests {
use crate::tenant::harness::{TenantHarness, test_img};
use crate::tenant::layer_map::LayerMap;
use crate::tenant::storage_layer::{Layer, LayerName, LayerVisibilityHint};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
use crate::tenant::timeline::{DeltaLayerTestDesc, EvictionError};
use crate::tenant::{PreviousHeatmap, Timeline};
@@ -7713,7 +7668,7 @@ mod tests {
// Evict all the layers and stash the old heatmap in the timeline.
// This simulates a migration to a cold secondary location.
let guard = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let guard = timeline.layers.read().await;
let mut all_layers = Vec::new();
let forever = std::time::Duration::from_secs(120);
for layer in guard.likely_resident_layers() {
@@ -7835,7 +7790,7 @@ mod tests {
})));
// Evict all the layers in the previous heatmap
let guard = timeline.layers.read(LayerManagerLockHolder::Testing).await;
let guard = timeline.layers.read().await;
let forever = std::time::Duration::from_secs(120);
for layer in guard.likely_resident_layers() {
layer.evict_and_wait(forever).await.unwrap();
@@ -7898,10 +7853,7 @@ mod tests {
}
async fn find_some_layer(timeline: &Timeline) -> Layer {
let layers = timeline
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let layers = timeline.layers.read().await;
let desc = layers
.layer_map()
.unwrap()

View File

@@ -4,7 +4,6 @@ use std::ops::Range;
use utils::lsn::Lsn;
use super::Timeline;
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
#[derive(serde::Serialize)]
pub(crate) struct RangeAnalysis {
@@ -25,10 +24,7 @@ impl Timeline {
let num_of_l0;
let all_layer_files = {
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
num_of_l0 = guard.layer_map().unwrap().level0_deltas().len();
guard.all_persistent_layers()
};

View File

@@ -9,7 +9,7 @@ use std::ops::{Deref, Range};
use std::sync::Arc;
use std::time::{Duration, Instant};
use super::layer_manager::{LayerManagerLockHolder, LayerManagerReadGuard};
use super::layer_manager::LayerManager;
use super::{
CompactFlags, CompactOptions, CompactionError, CreateImageLayersError, DurationRecorder,
GetVectoredError, ImageLayerCreationMode, LastImageLayerCreationStatus, RecordedDuration,
@@ -62,7 +62,7 @@ use crate::tenant::storage_layer::{
use crate::tenant::tasks::log_compaction_error;
use crate::tenant::timeline::{
DeltaLayerWriter, ImageLayerCreationOutcome, ImageLayerWriter, IoConcurrency, Layer,
ResidentLayer, drop_layer_manager_rlock,
ResidentLayer, drop_rlock,
};
use crate::tenant::{DeltaLayer, MaybeOffloaded};
use crate::virtual_file::{MaybeFatalIo, VirtualFile};
@@ -314,10 +314,7 @@ impl GcCompactionQueue {
.unwrap_or(Lsn::INVALID);
let layers = {
let guard = timeline
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = timeline.layers.read().await;
let layer_map = guard.layer_map()?;
layer_map.iter_historic_layers().collect_vec()
};
@@ -411,10 +408,7 @@ impl GcCompactionQueue {
timeline: &Arc<Timeline>,
lsn: Lsn,
) -> Result<u64, CompactionError> {
let guard = timeline
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = timeline.layers.read().await;
let layer_map = guard.layer_map()?;
let layers = layer_map.iter_historic_layers().collect_vec();
let mut size = 0;
@@ -857,7 +851,7 @@ impl KeyHistoryRetention {
}
let layer_generation;
{
let guard = tline.layers.read(LayerManagerLockHolder::Compaction).await;
let guard = tline.layers.read().await;
if !guard.contains_key(key) {
return false;
}
@@ -1288,10 +1282,7 @@ impl Timeline {
// We do the repartition on the L0-L1 boundary. All data below the boundary
// are compacted by L0 with low read amplification, thus making the `repartition`
// function run fast.
let guard = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let guard = self.layers.read().await;
guard
.all_persistent_layers()
.iter()
@@ -1470,7 +1461,7 @@ impl Timeline {
let latest_gc_cutoff = self.get_applied_gc_cutoff_lsn();
let pitr_cutoff = self.gc_info.read().unwrap().cutoffs.time;
let layers = self.layers.read(LayerManagerLockHolder::Compaction).await;
let layers = self.layers.read().await;
let layers_iter = layers.layer_map()?.iter_historic_layers();
let (layers_total, mut layers_checked) = (layers_iter.len(), 0);
for layer_desc in layers_iter {
@@ -1731,10 +1722,7 @@ impl Timeline {
// are implicitly left visible, because LayerVisibilityHint's default is Visible, and we never modify it here.
// Note that L0 deltas _can_ be covered by image layers, but we consider them 'visible' because we anticipate that
// they will be subject to L0->L1 compaction in the near future.
let layer_manager = self
.layers
.read(LayerManagerLockHolder::GetLayerMapInfo)
.await;
let layer_manager = self.layers.read().await;
let layer_map = layer_manager.layer_map()?;
let readable_points = {
@@ -1787,7 +1775,7 @@ impl Timeline {
};
let begin = tokio::time::Instant::now();
let phase1_layers_locked = self.layers.read(LayerManagerLockHolder::Compaction).await;
let phase1_layers_locked = self.layers.read().await;
let now = tokio::time::Instant::now();
stats.read_lock_acquisition_micros =
DurationRecorder::Recorded(RecordedDuration(now - begin), now);
@@ -1815,7 +1803,7 @@ impl Timeline {
/// Level0 files first phase of compaction, explained in the [`Self::compact_legacy`] comment.
async fn compact_level0_phase1<'a>(
self: &'a Arc<Self>,
guard: LayerManagerReadGuard<'a>,
guard: tokio::sync::RwLockReadGuard<'a, LayerManager>,
mut stats: CompactLevel0Phase1StatsBuilder,
target_file_size: u64,
force_compaction_ignore_threshold: bool,
@@ -2041,7 +2029,7 @@ impl Timeline {
holes
};
stats.read_lock_held_compute_holes_micros = stats.read_lock_held_key_sort_micros.till_now();
drop_layer_manager_rlock(guard);
drop_rlock(guard);
if self.cancel.is_cancelled() {
return Err(CompactionError::ShuttingDown);
@@ -2481,7 +2469,7 @@ impl Timeline {
// Find the top of the historical layers
let end_lsn = {
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
let guard = self.layers.read().await;
let layers = guard.layer_map()?;
let l0_deltas = layers.level0_deltas();
@@ -3020,7 +3008,7 @@ impl Timeline {
}
split_key_ranges.sort();
let all_layers = {
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
let guard = self.layers.read().await;
let layer_map = guard.layer_map()?;
layer_map.iter_historic_layers().collect_vec()
};
@@ -3124,12 +3112,12 @@ impl Timeline {
.await?;
let jobs_len = jobs.len();
for (idx, job) in jobs.into_iter().enumerate() {
let sub_compaction_progress = format!("{}/{}", idx + 1, jobs_len);
info!(
"running enhanced gc bottom-most compaction, sub-compaction {}/{}",
idx + 1,
jobs_len
);
self.compact_with_gc_inner(cancel, job, ctx, yield_for_l0)
.instrument(info_span!(
"sub_compaction",
sub_compaction_progress = sub_compaction_progress
))
.await?;
}
if jobs_len == 0 {
@@ -3197,10 +3185,7 @@ impl Timeline {
// 1. If a layer is in the selection, all layers below it are in the selection.
// 2. Inferred from (1), for each key in the layer selection, the value can be reconstructed only with the layers in the layer selection.
let job_desc = {
let guard = self
.layers
.read(LayerManagerLockHolder::GarbageCollection)
.await;
let guard = self.layers.read().await;
let layers = guard.layer_map()?;
let gc_info = self.gc_info.read().unwrap();
let mut retain_lsns_below_horizon = Vec::new();
@@ -3971,10 +3956,7 @@ impl Timeline {
// First, do a sanity check to ensure the newly-created layer map does not contain overlaps.
let all_layers = {
let guard = self
.layers
.read(LayerManagerLockHolder::GarbageCollection)
.await;
let guard = self.layers.read().await;
let layer_map = guard.layer_map()?;
layer_map.iter_historic_layers().collect_vec()
};
@@ -4038,10 +4020,7 @@ impl Timeline {
let update_guard = self.gc_compaction_layer_update_lock.write().await;
// Acquiring the update guard ensures current read operations end and new read operations are blocked.
// TODO: can we use `latest_gc_cutoff` Rcu to achieve the same effect?
let mut guard = self
.layers
.write(LayerManagerLockHolder::GarbageCollection)
.await;
let mut guard = self.layers.write().await;
guard
.open_mut()?
.finish_gc_compaction(&layer_selection, &compact_to, &self.metrics);
@@ -4109,11 +4088,7 @@ impl TimelineAdaptor {
pub async fn flush_updates(&mut self) -> Result<(), CompactionError> {
let layers_to_delete = {
let guard = self
.timeline
.layers
.read(LayerManagerLockHolder::Compaction)
.await;
let guard = self.timeline.layers.read().await;
self.layers_to_delete
.iter()
.map(|x| guard.get_from_desc(x))
@@ -4158,11 +4133,7 @@ impl CompactionJobExecutor for TimelineAdaptor {
) -> anyhow::Result<Vec<OwnArc<PersistentLayerDesc>>> {
self.flush_updates().await?;
let guard = self
.timeline
.layers
.read(LayerManagerLockHolder::Compaction)
.await;
let guard = self.timeline.layers.read().await;
let layer_map = guard.layer_map()?;
let result = layer_map
@@ -4201,11 +4172,7 @@ impl CompactionJobExecutor for TimelineAdaptor {
// this is a lot more complex than a simple downcast...
if layer.is_delta() {
let l = {
let guard = self
.timeline
.layers
.read(LayerManagerLockHolder::Compaction)
.await;
let guard = self.timeline.layers.read().await;
guard.get_from_desc(layer)
};
let result = l.download_and_keep_resident(ctx).await?;

View File

@@ -19,7 +19,7 @@ use utils::id::TimelineId;
use utils::lsn::Lsn;
use utils::sync::gate::GateError;
use super::layer_manager::{LayerManager, LayerManagerLockHolder};
use super::layer_manager::LayerManager;
use super::{FlushLayerError, Timeline};
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::TaskKind;
@@ -199,10 +199,7 @@ pub(crate) async fn generate_tombstone_image_layer(
let image_lsn = ancestor_lsn;
{
let layers = detached
.layers
.read(LayerManagerLockHolder::DetachAncestor)
.await;
let layers = detached.layers.read().await;
for layer in layers.all_persistent_layers() {
if !layer.is_delta
&& layer.lsn_range.start == image_lsn
@@ -426,7 +423,7 @@ pub(super) async fn prepare(
// we do not need to start from our layers, because they can only be layers that come
// *after* ancestor_lsn
let layers = tokio::select! {
guard = ancestor.layers.read(LayerManagerLockHolder::DetachAncestor) => guard,
guard = ancestor.layers.read() => guard,
_ = detached.cancel.cancelled() => {
return Err(ShuttingDown);
}
@@ -872,12 +869,7 @@ async fn remote_copy(
// Double check that the file is orphan (probably from an earlier attempt), then delete it
let key = file_name.clone().into();
if adoptee
.layers
.read(LayerManagerLockHolder::DetachAncestor)
.await
.contains_key(&key)
{
if adoptee.layers.read().await.contains_key(&key) {
// We are supposed to filter out such cases before coming to this function
return Err(Error::Prepare(anyhow::anyhow!(
"layer file {file_name} already present and inside layer map"

View File

@@ -33,7 +33,6 @@ use crate::tenant::size::CalculateSyntheticSizeError;
use crate::tenant::storage_layer::LayerVisibilityHint;
use crate::tenant::tasks::{BackgroundLoopKind, BackgroundLoopSemaphorePermit, sleep_random};
use crate::tenant::timeline::EvictionError;
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
use crate::tenant::{LogicalSizeCalculationCause, TenantShard};
#[derive(Default)]
@@ -209,7 +208,7 @@ impl Timeline {
let mut js = tokio::task::JoinSet::new();
{
let guard = self.layers.read(LayerManagerLockHolder::Eviction).await;
let guard = self.layers.read().await;
guard
.likely_resident_layers()

View File

@@ -15,7 +15,6 @@ use super::{Timeline, TimelineDeleteProgress};
use crate::context::RequestContext;
use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient};
use crate::tenant::metadata::TimelineMetadata;
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
mod flow;
mod importbucket_client;
@@ -164,10 +163,7 @@ async fn prepare_import(
info!("wipe the slate clean");
{
// TODO: do we need to hold GC lock for this?
let mut guard = timeline
.layers
.write(LayerManagerLockHolder::ImportPgData)
.await;
let mut guard = timeline.layers.write().await;
assert!(
guard.layer_map()?.open_layer.is_none(),
"while importing, there should be no in-memory layer" // this just seems like a good place to assert it

View File

@@ -56,7 +56,6 @@ use crate::pgdatadir_mapping::{
};
use crate::task_mgr::TaskKind;
use crate::tenant::storage_layer::{AsLayerDesc, ImageLayerWriter, Layer};
use crate::tenant::timeline::layer_manager::LayerManagerLockHolder;
pub async fn run(
timeline: Arc<Timeline>,
@@ -985,10 +984,7 @@ impl ChunkProcessingJob {
let (desc, path) = writer.finish(ctx).await?;
{
let guard = timeline
.layers
.read(LayerManagerLockHolder::ImportPgData)
.await;
let guard = timeline.layers.read().await;
let existing_layer = guard.try_get_from_key(&desc.key());
if let Some(layer) = existing_layer {
if layer.metadata().generation == timeline.generation {
@@ -1011,10 +1007,7 @@ impl ChunkProcessingJob {
// certain that the existing layer is identical to the new one, so in that case
// we replace the old layer with the one we just generated.
let mut guard = timeline
.layers
.write(LayerManagerLockHolder::ImportPgData)
.await;
let mut guard = timeline.layers.write().await;
let existing_layer = guard
.try_get_from_key(&resident_layer.layer_desc().key())
@@ -1043,7 +1036,7 @@ impl ChunkProcessingJob {
}
}
crate::tenant::timeline::drop_layer_manager_wlock(guard);
crate::tenant::timeline::drop_wlock(guard);
timeline
.remote_client

View File

@@ -1,8 +1,5 @@
use std::collections::HashMap;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail, ensure};
use itertools::Itertools;
@@ -23,155 +20,6 @@ use crate::tenant::storage_layer::{
PersistentLayerKey, ReadableLayerWeak, ResidentLayer,
};
/// Warn if the lock was held for longer than this threshold.
/// It's very generous and we should bring this value down over time.
const LAYER_MANAGER_LOCK_WARN_THRESHOLD: Duration = Duration::from_secs(5);
const LAYER_MANAGER_LOCK_READ_WARN_THRESHOLD: Duration = Duration::from_secs(30);
/// Describes the operation that is holding the layer manager lock
#[derive(Debug, Clone, Copy, strum_macros::Display)]
#[strum(serialize_all = "kebab_case")]
pub(crate) enum LayerManagerLockHolder {
GetLayerMapInfo,
GenerateHeatmap,
GetPage,
Init,
LoadLayerMap,
GetLayerForWrite,
TryFreezeLayer,
FlushFrozenLayer,
FlushLoop,
Compaction,
GarbageCollection,
Shutdown,
ImportPgData,
DetachAncestor,
Eviction,
#[cfg(test)]
Testing,
}
/// Wrapper for the layer manager that tracks the amount of time during which
/// it was held under read or write lock
#[derive(Default)]
pub(crate) struct LockedLayerManager {
locked: tokio::sync::RwLock<LayerManager>,
}
pub(crate) struct LayerManagerReadGuard<'a> {
guard: ManuallyDrop<tokio::sync::RwLockReadGuard<'a, LayerManager>>,
acquired_at: std::time::Instant,
holder: LayerManagerLockHolder,
}
pub(crate) struct LayerManagerWriteGuard<'a> {
guard: ManuallyDrop<tokio::sync::RwLockWriteGuard<'a, LayerManager>>,
acquired_at: std::time::Instant,
holder: LayerManagerLockHolder,
}
impl Drop for LayerManagerReadGuard<'_> {
fn drop(&mut self) {
// Drop the lock first, before potentially warning if it was held for too long.
// SAFETY: ManuallyDrop in Drop implementation
unsafe { ManuallyDrop::drop(&mut self.guard) };
let held_for = self.acquired_at.elapsed();
if held_for >= LAYER_MANAGER_LOCK_READ_WARN_THRESHOLD {
tracing::warn!(
holder=%self.holder,
"Layer manager read lock held for {}s",
held_for.as_secs_f64(),
);
}
}
}
impl Drop for LayerManagerWriteGuard<'_> {
fn drop(&mut self) {
// Drop the lock first, before potentially warning if it was held for too long.
// SAFETY: ManuallyDrop in Drop implementation
unsafe { ManuallyDrop::drop(&mut self.guard) };
let held_for = self.acquired_at.elapsed();
if held_for >= LAYER_MANAGER_LOCK_WARN_THRESHOLD {
tracing::warn!(
holder=%self.holder,
"Layer manager write lock held for {}s",
held_for.as_secs_f64(),
);
}
}
}
impl Deref for LayerManagerReadGuard<'_> {
type Target = LayerManager;
fn deref(&self) -> &Self::Target {
self.guard.deref()
}
}
impl Deref for LayerManagerWriteGuard<'_> {
type Target = LayerManager;
fn deref(&self) -> &Self::Target {
self.guard.deref()
}
}
impl DerefMut for LayerManagerWriteGuard<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.guard.deref_mut()
}
}
impl LockedLayerManager {
pub(crate) async fn read(&self, holder: LayerManagerLockHolder) -> LayerManagerReadGuard {
let guard = ManuallyDrop::new(self.locked.read().await);
LayerManagerReadGuard {
guard,
acquired_at: std::time::Instant::now(),
holder,
}
}
pub(crate) fn try_read(
&self,
holder: LayerManagerLockHolder,
) -> Result<LayerManagerReadGuard, tokio::sync::TryLockError> {
let guard = ManuallyDrop::new(self.locked.try_read()?);
Ok(LayerManagerReadGuard {
guard,
acquired_at: std::time::Instant::now(),
holder,
})
}
pub(crate) async fn write(&self, holder: LayerManagerLockHolder) -> LayerManagerWriteGuard {
let guard = ManuallyDrop::new(self.locked.write().await);
LayerManagerWriteGuard {
guard,
acquired_at: std::time::Instant::now(),
holder,
}
}
pub(crate) fn try_write(
&self,
holder: LayerManagerLockHolder,
) -> Result<LayerManagerWriteGuard, tokio::sync::TryLockError> {
let guard = ManuallyDrop::new(self.locked.try_write()?);
Ok(LayerManagerWriteGuard {
guard,
acquired_at: std::time::Instant::now(),
holder,
})
}
}
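As a usage sketch only (not part of this compare): call sites take the guard with a holder tag, matching the read(LayerManagerLockHolder::…) hunks earlier in this compare. It builds on the types defined above; Timeline is pared down to a single field and layer_count() is a hypothetical accessor.

struct Timeline {
    layers: LockedLayerManager,
}

impl Timeline {
    async fn layer_count_for_compaction(&self) -> usize {
        // The holder tag travels with the guard; if the guard is held past
        // the warn threshold, the Drop impls above log which operation held it.
        let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
        guard.layer_count() // hypothetical accessor on LayerManager
    }
}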
/// Provides semantic APIs to manipulate the layer map.
pub(crate) enum LayerManager {
/// Open as in not shutdown layer manager; we still have in-memory layers and we can manipulate

View File

@@ -1092,15 +1092,13 @@ communicator_prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
MyPState->ring_last <= ring_index);
}
/* Internal version. Returns the ring index of the last block (result of this function is used only
* when nblocks==1)
*/
/* internal version. Returns the ring index */
static uint64
prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
BlockNumber nblocks, const bits8 *mask,
bool is_prefetch)
{
uint64 last_ring_index;
uint64 min_ring_index;
PrefetchRequest hashkey;
#ifdef USE_ASSERT_CHECKING
bool any_hits = false;
@@ -1124,12 +1122,13 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
MyNeonCounters->getpage_prefetches_buffered =
MyPState->n_responses_buffered;
last_ring_index = UINT64_MAX;
min_ring_index = UINT64_MAX;
for (int i = 0; i < nblocks; i++)
{
PrefetchRequest *slot = NULL;
PrfHashEntry *entry = NULL;
uint64 ring_index;
neon_request_lsns *lsns;
if (PointerIsValid(mask) && BITMAP_ISSET(mask, i))
@@ -1153,12 +1152,12 @@ Retry:
if (entry != NULL)
{
slot = entry->slot;
last_ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(last_ring_index));
ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(ring_index));
Assert(slot->status != PRFS_UNUSED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(MyPState->ring_last <= ring_index &&
ring_index < MyPState->ring_unused);
Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag));
/*
@@ -1170,9 +1169,9 @@ Retry:
if (!neon_prefetch_response_usable(lsns, slot))
{
/* Wait for the old request to finish and discard it */
if (!prefetch_wait_for(last_ring_index))
if (!prefetch_wait_for(ring_index))
goto Retry;
prefetch_set_unused(last_ring_index);
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
pgBufferUsage.prefetch.expired += 1;
@@ -1189,12 +1188,13 @@ Retry:
*/
if (slot->status == PRFS_TAG_REMAINS)
{
prefetch_set_unused(last_ring_index);
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
}
else
{
min_ring_index = Min(min_ring_index, ring_index);
/* The buffered request is good enough, return that index */
if (is_prefetch)
pgBufferUsage.prefetch.duplicates++;
@@ -1283,12 +1283,12 @@ Retry:
* The next buffer pointed to by `ring_unused` is now definitely empty, so
* we can insert the new request to it.
*/
last_ring_index = MyPState->ring_unused;
ring_index = MyPState->ring_unused;
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index <= MyPState->ring_unused);
Assert(MyPState->ring_last <= ring_index &&
ring_index <= MyPState->ring_unused);
slot = GetPrfSlotNoCheck(last_ring_index);
slot = GetPrfSlotNoCheck(ring_index);
Assert(slot->status == PRFS_UNUSED);
@@ -1298,9 +1298,11 @@ Retry:
*/
slot->buftag = hashkey.buftag;
slot->shard_no = get_shard_number(&tag);
slot->my_ring_index = last_ring_index;
slot->my_ring_index = ring_index;
slot->flags = 0;
min_ring_index = Min(min_ring_index, ring_index);
if (is_prefetch)
MyNeonCounters->getpage_prefetch_requests_total++;
else
@@ -1313,12 +1315,11 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
Assert(any_hits);
Assert(last_ring_index != UINT64_MAX);
Assert(GetPrfSlot(last_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(last_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(GetPrfSlot(min_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(min_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= min_ring_index &&
min_ring_index < MyPState->ring_unused);
if (flush_every_n_requests > 0 &&
MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests)
@@ -1334,7 +1335,7 @@ Retry:
MyPState->ring_flush = MyPState->ring_unused;
}
return last_ring_index;
return min_ring_index;
}
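A Rust sketch (the code above is C) of the return-value semantics in this hunk: the function reports the minimum ring index touched by the batch rather than the last one, so a caller that waits from that index onward covers every block it registered. Slot reuse and hash lookups are elided; in the real code a block may adopt a buffered slot with a smaller index.

fn register_batch(ring_unused: &mut u64, nblocks: u64) -> u64 {
    let mut min_ring_index = u64::MAX;
    for _ in 0..nblocks {
        // Simplification: every block takes the next unused slot.
        let ring_index = *ring_unused;
        *ring_unused += 1;
        min_ring_index = min_ring_index.min(ring_index);
    }
    min_ring_index
}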
static bool

View File

@@ -2,6 +2,6 @@ DROP FUNCTION IF EXISTS get_prewarm_info(out total_pages integer, out prewarmed_
DROP FUNCTION IF EXISTS get_local_cache_state(max_chunks integer);
DROP FUNCTION IF EXISTS prewarm_local_cache(state bytea, n_workers integer);
DROP FUNCTION IF EXISTS prewarm_local_cache(state bytea, n_workers integer default 1);

View File

@@ -1135,7 +1135,7 @@ VotesCollectedMset(WalProposer *wp, MemberSet *mset, Safekeeper **msk, StringInf
wp->propTermStartLsn = sk->voteResponse.flushLsn;
wp->donor = sk;
}
wp->truncateLsn = Max(sk->voteResponse.truncateLsn, wp->truncateLsn);
wp->truncateLsn = Max(wp->safekeeper[i].voteResponse.truncateLsn, wp->truncateLsn);
if (n_votes > 0)
appendStringInfoString(s, ", ");

10
poetry.lock generated
View File

@@ -3051,19 +3051,19 @@ files = [
[[package]]
name = "requests"
version = "2.32.4"
version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"},
{file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"},
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]
[package.dependencies]
certifi = ">=2017.4.17"
charset_normalizer = ">=2,<4"
charset-normalizer = ">=2,<4"
idna = ">=2.5,<4"
urllib3 = ">=1.21.1,<3"
@@ -3846,4 +3846,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = "^3.11"
content-hash = "bd93313f110110aa53b24a3ed47ba2d7f60e2c658a79cdff7320fed1bb1b57b5"
content-hash = "7ab1e7b975af34b3271b7c6018fa22a261d3f73c7c0a0403b6b2bb86b5fbd36e"

View File

@@ -14,9 +14,9 @@ use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::stream::PqStream;
use crate::types::RoleName;
use crate::{auth, compute, waiters};
@@ -109,7 +109,7 @@ impl ConsoleRedirectBackend {
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
#[async_trait]
impl WakeComputeBackend for ConsoleRedirectNodeInfo {
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
async fn wake_compute(
&self,
_ctx: &RequestContext,

View File

@@ -4,8 +4,6 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime};
use arc_swap::ArcSwapOption;
use base64::Engine as _;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use clashmap::ClashMap;
use jose_jwk::crypto::KeyInfo;
use reqwest::{Client, redirect};
@@ -349,17 +347,17 @@ impl JwkCacheEntryLock {
.split_once('.')
.ok_or(JwtEncodingError::InvalidCompactForm)?;
let header = BASE64_URL_SAFE_NO_PAD.decode(header)?;
let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?;
let header = serde_json::from_slice::<JwtHeader<'_>>(&header)?;
let payloadb = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?;
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)?;
if let Some(iss) = &payload.issuer {
ctx.set_jwt_issuer(iss.as_ref().to_owned());
}
let sig = BASE64_URL_SAFE_NO_PAD.decode(signature)?;
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?;
let kid = header.key_id.ok_or(JwtError::MissingKeyId)?;
@@ -798,6 +796,7 @@ mod tests {
use std::net::SocketAddr;
use std::time::SystemTime;
use base64::URL_SAFE_NO_PAD;
use bytes::Bytes;
use http::Response;
use http_body_util::Full;
@@ -872,8 +871,9 @@ mod tests {
key_id: Some(Cow::Owned(kid)),
};
let header = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap());
let body = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&body).unwrap());
let header =
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
format!("{header}.{body}")
}
@@ -883,7 +883,7 @@ mod tests {
let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
format!("{payload}.{sig}")
}
@@ -893,7 +893,7 @@ mod tests {
let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
format!("{payload}.{sig}")
}
@@ -904,7 +904,7 @@ mod tests {
let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256);
let sig = SigningKey::<sha2::Sha256>::new(key).sign(payload.as_bytes());
let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes());
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
format!("{payload}.{sig}")
}
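The hunks in this file move between the two generations of the base64 crate's API (versions per the Cargo.lock changes in this compare). A minimal side-by-side sketch, decoding one segment of a compact JWT:

use base64::Engine as _;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;

// Engine-style API (base64 0.21+), one side of these hunks:
fn decode_part(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
    BASE64_URL_SAFE_NO_PAD.decode(input)
}

// Free-function API (base64 0.13), the other side of the same hunks:
// base64::decode_config(input, base64::URL_SAFE_NO_PAD)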

View File

@@ -14,21 +14,20 @@ use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info};
use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
@@ -231,8 +230,11 @@ async fn auth_quirks(
config.is_vpc_acccess_proxy,
)?;
access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?;
let endpoint = EndpointIdInt::from(&info.endpoint);
let rate_limit_config = None;
if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let role_access = api
.get_role_access_control(ctx, &info.endpoint, &info.user)
.await?;
@@ -399,20 +401,19 @@ impl Backend<'_, ComputeUserInfo> {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
}),
}
}
}
#[async_trait::async_trait]
impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
match self {
Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
@@ -438,7 +439,6 @@ mod tests {
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
};
@@ -477,7 +477,6 @@ mod tests {
allowed_ips: Arc::new(self.ips.clone()),
allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
flags: self.access_blocker_flags,
rate_limits: EndpointRateLimitConfig::default(),
})
}

View File

@@ -28,9 +28,10 @@ use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
use crate::proxy::{
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
};
use crate::stream::{PqStream, Stream};
use crate::util::run_until_cancelled;
project_git_version!(GIT_VERSION);

View File

@@ -11,13 +11,11 @@ use anyhow::Context;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
use futures::future::Either;
use itertools::{Itertools, Position};
use rand::{Rng, thread_rng};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
use tracing::{Instrument, info, warn};
use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
@@ -316,7 +314,7 @@ pub async fn run() -> anyhow::Result<()> {
let jemalloc = match crate::jemalloc::MetricRecorder::new() {
Ok(t) => Some(t),
Err(e) => {
error!(error = ?e, "could not start jemalloc metrics loop");
tracing::error!(error = ?e, "could not start jemalloc metrics loop");
None
}
};
@@ -522,44 +520,23 @@ pub async fn run() -> anyhow::Result<()> {
}
}
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
if let Some(mut redis_kv_client) = redis_kv_client {
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
maintenance_tasks.spawn(async move {
handle_cancel_messages(
&mut redis_kv_client,
rx_cancel,
args.cancellation_batch_size,
)
.await?;
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(
&mut redis_kv_client,
rx_cancel,
args.cancellation_batch_size,
)
.await?;
drop(redis_kv_client);
drop(redis_kv_client);
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,
// so let's wait forever instead.
std::future::pending().await
});
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
}
}
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,
// so let's wait forever instead.
std::future::pending().await
});
}
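For reference, a generic sketch of the bounded retry-with-jitter pattern in this hunk; connect_with_retry and its closure argument are placeholders, not the proxy's real API.

use std::time::Duration;
use rand::Rng;

async fn connect_with_retry<F, Fut, T, E>(attempts: u32, mut connect: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let mut last_err = None;
    for attempt in 0..attempts {
        match connect().await {
            Ok(v) => return Ok(v),
            Err(e) => {
                last_err = Some(e);
                // Sleep ~1s plus jitter so restarting pods don't hammer the server.
                if attempt + 1 < attempts {
                    let jitter = rand::thread_rng().gen_range(0..100);
                    tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
                }
            }
        }
    }
    Err(last_err.expect("attempts must be > 0"))
}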
if let Some(regional_redis_client) = regional_redis_client {

View File

@@ -364,7 +364,6 @@ mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use crate::types::ProjectId;
@@ -400,7 +399,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret1.clone(),
@@ -416,7 +414,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret2.clone(),
@@ -442,7 +439,6 @@ mod tests {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret3.clone(),

View File

@@ -136,11 +136,11 @@ impl AuthInfo {
}
}
pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
Self {
auth: match keys {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
Some(Auth::Scram(Box::new(*auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},

View File

@@ -11,12 +11,13 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
use crate::util::run_until_cancelled;
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
pub async fn task_main(
config: &'static ProxyConfig,

View File

@@ -146,7 +146,6 @@ impl NeonControlPlaneClient {
public_access_blocked: block_public_connections,
vpc_access_blocked: block_vpc_connections,
},
rate_limits: body.rate_limits,
})
}
.inspect_err(|e| tracing::debug!(error = ?e))
@@ -313,7 +312,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
@@ -359,7 +357,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,

View File

@@ -20,7 +20,7 @@ use crate::context::RequestContext;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
@@ -130,7 +130,6 @@ impl MockControlPlane {
project_id: None,
account_id: None,
access_blocker_flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
})
}
@@ -234,7 +233,6 @@ impl super::ControlPlaneApi for MockControlPlane {
allowed_ips: Arc::new(info.allowed_ips),
allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
flags: info.access_blocker_flags,
rate_limits: info.rate_limits,
})
}

View File

@@ -10,7 +10,6 @@ use clashmap::ClashMap;
use tokio::time::Instant;
use tracing::{debug, info};
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::cache::endpoints::EndpointsCache;
@@ -23,6 +22,8 @@ use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
use crate::types::EndpointId;
use super::{EndpointAccessControl, RoleAccessControl};
#[non_exhaustive]
#[derive(Clone)]
pub enum ControlPlaneClient {

View File

@@ -227,35 +227,12 @@ pub(crate) struct UserFacingMessage {
#[derive(Deserialize)]
pub(crate) struct GetEndpointAccessControl {
pub(crate) role_secret: Box<str>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) account_id: Option<AccountIdInt>,
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) account_id: Option<AccountIdInt>,
pub(crate) block_public_connections: Option<bool>,
pub(crate) block_vpc_connections: Option<bool>,
#[serde(default)]
pub(crate) rate_limits: EndpointRateLimitConfig,
}
#[derive(Copy, Clone, Deserialize, Default)]
pub struct EndpointRateLimitConfig {
pub connection_attempts: ConnectionAttemptsLimit,
}
#[derive(Copy, Clone, Deserialize, Default)]
pub struct ConnectionAttemptsLimit {
pub tcp: Option<LeakyBucketSetting>,
pub ws: Option<LeakyBucketSetting>,
pub http: Option<LeakyBucketSetting>,
}
#[derive(Copy, Clone, Deserialize)]
pub struct LeakyBucketSetting {
pub rps: f64,
pub burst: f64,
}
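For reference, a hedged sketch of deserializing the per-protocol setting above; field names follow the derives, while the example JSON values are made up (the real control-plane payload is not part of this compare).

use serde::Deserialize;

#[derive(Copy, Clone, Deserialize)]
pub struct LeakyBucketSetting {
    pub rps: f64,
    pub burst: f64,
}

fn parse_setting() -> LeakyBucketSetting {
    // 10 requests/sec sustained, bursts of up to 20
    serde_json::from_str(r#"{ "rps": 10.0, "burst": 20.0 }"#).unwrap()
}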
/// Response which holds compute node's `host:port` pair.

View File

@@ -11,8 +11,6 @@ pub(crate) mod errors;
use std::sync::Arc;
use messages::EndpointRateLimitConfig;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
@@ -20,9 +18,8 @@ use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
use crate::intern::{AccountIdInt, ProjectIdInt};
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, scram};
@@ -59,8 +56,6 @@ pub(crate) struct AuthInfo {
pub(crate) account_id: Option<AccountIdInt>,
/// Are public connections or VPC connections blocked?
pub(crate) access_blocker_flags: AccessBlockerFlags,
/// The rate limits for this endpoint.
pub(crate) rate_limits: EndpointRateLimitConfig,
}
/// Info for establishing a connection to a compute node.
@@ -106,8 +101,6 @@ pub struct EndpointAccessControl {
pub allowed_ips: Arc<Vec<IpPattern>>,
pub allowed_vpce: Arc<Vec<String>>,
pub flags: AccessBlockerFlags,
pub rate_limits: EndpointRateLimitConfig,
}
impl EndpointAccessControl {
@@ -146,36 +139,6 @@ impl EndpointAccessControl {
Ok(())
}
pub fn connection_attempt_rate_limit(
&self,
ctx: &RequestContext,
endpoint: &EndpointId,
rate_limiter: &EndpointRateLimiter,
) -> Result<(), AuthError> {
let endpoint = EndpointIdInt::from(endpoint);
let limits = &self.rate_limits.connection_attempts;
let config = match ctx.protocol() {
crate::metrics::Protocol::Http => limits.http,
crate::metrics::Protocol::Ws => limits.ws,
crate::metrics::Protocol::Tcp => limits.tcp,
crate::metrics::Protocol::SniRouter => return Ok(()),
};
let config = config.and_then(|config| {
if config.rps <= 0.0 || config.burst <= 0.0 {
return None;
}
Some(LeakyBucketConfig::new(config.rps, config.burst))
});
if !rate_limiter.check(endpoint, config, 1) {
return Err(AuthError::too_many_connections());
}
Ok(())
}
}
/// This will allocate per each call, but the http requests alone

View File

@@ -106,5 +106,4 @@ mod tls;
mod types;
mod url;
mod usage_metrics;
mod util;
mod waiters;

View File

@@ -8,19 +8,19 @@ use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(skip_all)]
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
let is_cached = node_info.cached();
if is_cached {
@@ -49,6 +49,14 @@ pub(crate) trait ConnectMechanism {
) -> Result<Self::Connection, Self::ConnectError>;
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
}
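As an illustration of the trait above, a minimal sketch of the simplest backend: one that wraps an already-known NodeInfo and never talks to the control plane, mirroring the Backend::Local arm elsewhere in this compare. NodeInfo and CachedNodeInfo come from the imports in this hunk; Cached::new_uncached is used the same way in that Local arm; StaticNode is hypothetical.

struct StaticNode(NodeInfo);

#[async_trait]
impl ComputeConnectBackend for StaticNode {
    async fn wake_compute(
        &self,
        _ctx: &RequestContext,
    ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
        // Nothing to wake; hand back an uncached copy of the node info.
        Ok(Cached::new_uncached(self.0.clone()))
    }
}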
pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
/// connect_to_compute concurrency lock
@@ -83,7 +91,7 @@ impl ConnectMechanism for TcpMechanism {
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBackend>(
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &RequestContext,
mechanism: &M,
user_info: &B,

View File

@@ -1,3 +1,4 @@
pub mod connect_compute;
pub mod copy_bidirectional;
pub mod handshake;
pub mod inprocess;

View File

@@ -1,10 +1,8 @@
#[cfg(test)]
mod tests;
pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use std::sync::Arc;
use futures::FutureExt;
@@ -23,16 +21,15 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
use crate::util::run_until_cancelled;
use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
@@ -49,6 +46,21 @@ impl ReportableError for TlsRequired {
impl UserFacingError for TlsRequired {}
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match futures::future::select(
std::pin::pin!(f),
std::pin::pin!(cancellation_token.cancelled()),
)
.await
{
futures::future::Either::Left((f, _)) => Some(f),
futures::future::Either::Right(((), _)) => None,
}
}
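A small usage sketch for run_until_cancelled as defined above; the accept loop and listener are placeholders.

use tokio_util::sync::CancellationToken;

// Accept connections until the token fires; None means we were cancelled.
async fn accept_loop(listener: tokio::net::TcpListener, cancel: CancellationToken) {
    while let Some(res) = run_until_cancelled(listener.accept(), &cancel).await {
        match res {
            Ok((_socket, peer)) => tracing::info!(%peer, "accepted"),
            Err(e) => tracing::warn!("accept failed: {e}"),
        }
    }
}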
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
@@ -346,12 +358,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
};
let (cplane, creds) = match user_info {
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
let creds = match &user_info {
auth::Backend::ControlPlane(_, creds) => creds,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
auth_info.set_startup_params(&params, params_compat);
let res = connect_to_compute(
@@ -361,7 +373,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
auth: auth_info,
locks: &config.connect_compute_locks,
},
&auth::Backend::ControlPlane(cplane, creds.info),
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)

View File

@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::{Context, bail};
use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
@@ -19,13 +19,15 @@ use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
@@ -573,13 +575,19 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeUserInfo> {
) -> auth::Backend<'static, ComputeCredentials> {
auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
client_key: [0; 32],
server_key: [0; 32],
})),
},
)
}

View File

@@ -1,4 +1,3 @@
use async_trait::async_trait;
use tracing::{error, info};
use crate::config::RetryConfig;
@@ -9,6 +8,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::proxy::retry::{retry_after, should_retry};
// Use macro to retain original callsite.
@@ -23,12 +23,7 @@ macro_rules! log_wake_compute_error {
};
}
#[async_trait]
pub(crate) trait WakeComputeBackend {
async fn wake_compute(&self, ctx: &RequestContext) -> Result<CachedNodeInfo, WakeComputeError>;
}
pub(crate) async fn wake_compute<B: WakeComputeBackend>(
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,
ctx: &RequestContext,
api: &B,

View File

@@ -69,8 +69,9 @@ pub struct LeakyBucketConfig {
pub max: f64,
}
#[cfg(test)]
impl LeakyBucketConfig {
pub fn new(rps: f64, max: f64) -> Self {
pub(crate) fn new(rps: f64, max: f64) -> Self {
assert!(rps > 0.0, "rps must be positive");
assert!(max > 0.0, "max must be positive");
Self { rps, max }
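A sketch of the admission math behind this config, assuming rps is the drain rate in tokens per second and max the bucket capacity; the proxy's real limiter is more involved.

use std::time::Instant;

struct Bucket {
    level: f64,
    last: Instant,
}

impl Bucket {
    fn check(&mut self, rps: f64, max: f64) -> bool {
        // Drain what leaked since the last check, then try to add one token.
        let now = Instant::now();
        self.level = (self.level - now.duration_since(self.last).as_secs_f64() * rps).max(0.0);
        self.last = now;
        if self.level + 1.0 <= max {
            self.level += 1.0;
            true // admitted
        } else {
            false // rate limited
        }
    }
}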

View File

@@ -12,10 +12,11 @@ use rand::{Rng, SeedableRng};
use tokio::time::{Duration, Instant};
use tracing::info;
use super::LeakyBucketConfig;
use crate::ext::LockExt;
use crate::intern::EndpointIdInt;
use super::LeakyBucketConfig;
pub struct GlobalRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,

View File

@@ -1,8 +1,5 @@
//! Definition and parser for channel binding flag (a part of the `GS2` header).
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
/// Channel binding flag (possibly with params).
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum ChannelBinding<T> {
@@ -58,7 +55,7 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
BASE64_STANDARD.encode(&cbind_input).into()
base64::encode(&cbind_input).into()
}
})
}
@@ -73,9 +70,9 @@ mod tests {
use ChannelBinding::*;
let cases = [
(NotSupportedClient, BASE64_STANDARD.encode("n,,")),
(NotSupportedServer, BASE64_STANDARD.encode("y,,")),
(Required("foo"), BASE64_STANDARD.encode("p=foo,,bar")),
(NotSupportedClient, base64::encode("n,,")),
(NotSupportedServer, base64::encode("y,,")),
(Required("foo"), base64::encode("p=foo,,bar")),
];
for (cb, input) in cases {

View File

@@ -2,8 +2,6 @@
use std::convert::Infallible;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use sha2::Sha256;
@@ -107,7 +105,7 @@ pub(crate) async fn exchange(
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&secret.salt_base64)?;
let salt = base64::decode(&secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {

View File

@@ -3,9 +3,6 @@
use std::fmt;
use std::ops::Range;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use super::base64_decode_array;
use super::key::{SCRAM_KEY_LEN, ScramKey};
use super::signature::SignatureBuilder;
@@ -91,7 +88,7 @@ impl<'a> ClientFirstMessage<'a> {
let mut message = String::new();
write!(&mut message, "r={}", self.nonce).unwrap();
BASE64_STANDARD.encode_string(nonce, &mut message);
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
@@ -145,7 +142,11 @@ impl<'a> ClientFinalMessage<'a> {
server_key: &ScramKey,
) -> String {
let mut buf = String::from("v=");
BASE64_STANDARD.encode_string(signature_builder.build(server_key), &mut buf);
base64::encode_config_buf(
signature_builder.build(server_key),
base64::STANDARD,
&mut buf,
);
buf
}
@@ -250,7 +251,7 @@ mod tests {
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
);
assert_eq!(
BASE64_STANDARD.encode(msg.proof),
base64::encode(msg.proof),
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
);
}

View File

@@ -15,8 +15,6 @@ mod secret;
mod signature;
pub mod threadpool;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
pub(crate) use exchange::{Exchange, exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
@@ -34,7 +32,7 @@ pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
let mut bytes = [0u8; N];
let size = BASE64_STANDARD.decode_slice(input, &mut bytes).ok()?;
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
if size != N {
return None;
}

View File

@@ -1,7 +1,5 @@
//! Tools for SCRAM server secret management.
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use subtle::{Choice, ConstantTimeEq};
use super::base64_decode_array;
@@ -58,7 +56,7 @@ impl ServerSecret {
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.
iterations: 1,
salt_base64: BASE64_STANDARD.encode(nonce),
salt_base64: base64::encode(nonce),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
@@ -90,7 +88,7 @@ mod tests {
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key);
assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key);
assert_eq!(base64::encode(parsed.stored_key), stored_key);
assert_eq!(base64::encode(parsed.server_key), server_key);
}
}

View File

@@ -21,7 +21,7 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -34,7 +34,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
@@ -68,20 +68,17 @@ impl PoolingBackend {
self.config.authentication_config.is_vpc_acccess_proxy,
)?;
access_control.connection_attempt_rate_limit(
ctx,
&user_info.endpoint,
&self.endpoint_rate_limiter,
)?;
let ep = EndpointIdInt::from(&user_info.endpoint);
let rate_limit_config = None;
if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let role_access = backend.get_role_secret(ctx).await?;
let Some(secret) = role_access.secret else {
// If we don't have an authentication secret, for the http flow we can just return an error.
info!("authentication info not found");
return Err(AuthError::password_failed(&*user_info.user));
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,
@@ -183,15 +180,14 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::pglb::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
keys: keys.keys,
},
&backend,
self.config.wake_compute_retry_config,
@@ -218,15 +214,18 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
crate::pglb::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
@@ -496,7 +495,6 @@ struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
@@ -522,10 +520,6 @@ impl ConnectMechanism for TokioMechanism {
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);

View File

@@ -16,8 +16,6 @@ use std::sync::atomic::AtomicUsize;
use std::task::{Poll, ready};
use std::time::Duration;
use base64::Engine as _;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use ed25519_dalek::{Signature, Signer, SigningKey};
use futures::Future;
use futures::future::poll_fn;
@@ -348,7 +346,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
jwt.push_str("eyJhbGciOiJFZERTQSJ9.");
// encode the jwt payload in-place
BASE64_URL_SAFE_NO_PAD.encode_string(payload, &mut jwt);
base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt);
// create the signature from the encoded header || payload
let sig: Signature = sk.sign(jwt.as_bytes());
@@ -356,7 +354,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
jwt.push('.');
// encode the jwt signature in-place
BASE64_URL_SAFE_NO_PAD.encode_string(sig.to_bytes(), &mut jwt);
base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt);
debug_assert_eq!(
jwt.len(),

View File

@@ -50,10 +50,10 @@ use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use crate::util::run_until_cancelled;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";

View File

@@ -41,11 +41,10 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::proxy::{NeonOptions, run_until_cancelled};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]

View File

@@ -3,8 +3,6 @@ pub mod postgres_rustls;
pub mod server_config;
use anyhow::Context;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use rustls::pki_types::CertificateDer;
use sha2::{Digest, Sha256};
use tracing::{error, info};
@@ -60,7 +58,7 @@ impl TlsServerEndPoint {
let oid = certificate.signature_algorithm.oid;
if SHA256_OIDS.contains(&oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(%subject, tls_server_end_point = %BASE64_STANDARD.encode(tls_server_end_point), "determined channel binding");
info!(%subject, tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(%subject, "unknown channel binding");

View File

@@ -1,14 +0,0 @@
use std::pin::pin;
use futures::future::{Either, select};
use tokio_util::sync::CancellationToken;
pub async fn run_until_cancelled<F: Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match select(pin!(f), pin!(cancellation_token.cancelled())).await {
Either::Left((f, _)) => Some(f),
Either::Right(((), _)) => None,
}
}

View File

@@ -9,7 +9,7 @@ pytest = "^7.4.4"
psycopg2-binary = "^2.9.10"
typing-extensions = "^4.12.2"
PyJWT = {version = "^2.1.0", extras = ["crypto"]}
requests = "^2.32.4"
requests = "^2.32.3"
pytest-xdist = "^3.3.1"
asyncpg = "^0.30.0"
aiopg = "^1.4.0"

View File

@@ -65,8 +65,9 @@ diesel-async = { version = "0.5.2", features = ["postgres", "bb8", "async-connec
diesel_migrations = { version = "2.2.0" }
scoped-futures = "0.1.4"
compute_api = { path = "../libs/compute_api/" }
http-utils = { path = "../libs/http-utils/" }
utils = { path = "../libs/utils/" }
metrics = { path = "../libs/metrics/" }
control_plane = { path = "../control_plane" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -428,7 +428,10 @@ impl ComputeHook {
.expect("Unknown pageserver");
let (pg_host, pg_port) = parse_host_port(&ps_conf.listen_pg_addr)
.expect("Unable to parse listen_pg_addr");
(pg_host, pg_port.unwrap_or(5432))
compute_api::spec::Pageserver {
host: pg_host,
port: pg_port.unwrap_or(5432),
}
})
.collect::<Vec<_>>();

View File

@@ -1398,31 +1398,6 @@ async fn handle_timeline_import(req: Request<Body>) -> Result<Response<Body>, Ap
)
}
async fn handle_tenant_timeline_locate(
service: Arc<Service>,
req: Request<Body>,
) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let timeline_id: TimelineId = parse_request_param(&req, "timeline_id")?;
check_permissions(&req, Scope::Admin)?;
maybe_rate_limit(&req, tenant_id).await;
match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(_req) => {}
};
json_response(
StatusCode::OK,
service
.tenant_timeline_locate(tenant_id, timeline_id)
.await?,
)
}
async fn handle_tenants_dump(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
@@ -2164,16 +2139,6 @@ pub fn make_router(
)
},
)
.get(
"/debug/v1/tenant/:tenant_id/timeline/:timeline_id/locate",
|r| {
tenant_service_handler(
r,
handle_tenant_timeline_locate,
RequestName("v1_tenant_timeline_locate"),
)
},
)
.get("/debug/v1/scheduler", |r| {
named_request_span(r, handle_scheduler_dump, RequestName("debug_v1_scheduler"))
})

View File

@@ -2267,7 +2267,6 @@ impl Service {
// fail, and start from scratch, so it doesn't make sense for us to try and preserve
// the stale/multi states at this point.
mode: LocationConfigMode::AttachedSingle,
stripe_size: shard.shard.stripe_size,
});
shard.generation = std::cmp::max(shard.generation, Some(new_gen));
@@ -2301,7 +2300,6 @@ impl Service {
id: *tenant_shard_id,
r#gen: None,
mode: LocationConfigMode::Secondary,
stripe_size: shard.shard.stripe_size,
});
// We must not update observed, because we have no guarantee that our
@@ -6651,8 +6649,6 @@ impl Service {
/// This is for debug/support only: assuming tenant data is already present in S3, we "create" a
/// tenant with a very high generation number so that it will see the existing data.
/// It does not create timelines on safekeepers, because they might already exist on some
/// safekeeper set. So, the timelines are not storcon-managed after the import.
pub(crate) async fn tenant_import(
&self,
tenant_id: TenantId,

View File

@@ -17,7 +17,7 @@ use pageserver_api::controller_api::{
SafekeeperDescribeResponse, SkSchedulingPolicy, TimelineImportRequest,
};
use pageserver_api::models::{SafekeeperInfo, SafekeepersInfo, TimelineInfo};
use safekeeper_api::membership::{MemberSet, SafekeeperGeneration, SafekeeperId};
use safekeeper_api::membership::{MemberSet, SafekeeperId};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use utils::id::{NodeId, TenantId, TimelineId};
@@ -26,13 +26,6 @@ use utils::lsn::Lsn;
use super::Service;
#[derive(serde::Serialize, serde::Deserialize, Clone)]
pub struct TimelineLocateResponse {
pub generation: SafekeeperGeneration,
pub sk_set: Vec<NodeId>,
pub new_sk_set: Option<Vec<NodeId>>,
}
impl Service {
/// Timeline creation on safekeepers
///
@@ -403,38 +396,6 @@ impl Service {
Ok(())
}
/// Locate safekeepers for a timeline.
/// Return the generation, sk_set and new_sk_set if present.
/// If the timeline is not storcon-managed, return NotFound.
pub(crate) async fn tenant_timeline_locate(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<TimelineLocateResponse, ApiError> {
let timeline = self
.persistence
.get_timeline(tenant_id, timeline_id)
.await?;
let Some(timeline) = timeline else {
return Err(ApiError::NotFound(
anyhow::anyhow!("Timeline {}/{} not found", tenant_id, timeline_id).into(),
));
};
Ok(TimelineLocateResponse {
generation: SafekeeperGeneration::new(timeline.generation as u32),
sk_set: timeline
.sk_set
.iter()
.map(|id| NodeId(*id as u64))
.collect(),
new_sk_set: timeline
.new_sk_set
.map(|sk_set| sk_set.iter().map(|id| NodeId(*id as u64)).collect()),
})
}
/// Perform timeline deletion on safekeepers. Will return success: we persist the deletion into the reconciler.
pub(super) async fn tenant_timeline_delete_safekeepers(
self: &Arc<Self>,

View File

@@ -69,17 +69,15 @@ class EndpointHttpClient(requests.Session):
json: dict[str, str] = res.json()
return json
def prewarm_lfc(self, from_endpoint_id: str | None = None):
url: str = f"http://localhost:{self.external_port}/lfc/prewarm"
params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict()
self.post(url, params=params).raise_for_status()
def prewarm_lfc(self):
self.post(f"http://localhost:{self.external_port}/lfc/prewarm").raise_for_status()
def prewarmed():
json = self.prewarm_lfc_status()
status, err = json["status"], json.get("error")
assert status == "completed", f"{status}, error {err}"
wait_until(prewarmed, timeout=60)
wait_until(prewarmed)
def offload_lfc(self):
url = f"http://localhost:{self.external_port}/lfc/offload"

View File

@@ -129,18 +129,6 @@ class NeonAPI:
return cast("dict[str, Any]", resp.json())
def get_project_limits(self, project_id: str) -> dict[str, Any]:
resp = self.__request(
"GET",
f"/projects/{project_id}/limits",
headers={
"Accept": "application/json",
"Content-Type": "application/json",
},
)
return cast("dict[str, Any]", resp.json())
def delete_project(
self,
project_id: str,

View File

@@ -2223,17 +2223,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
shards: list[dict[str, Any]] = body["shards"]
return shards
def timeline_locate(self, tenant_id: TenantId, timeline_id: TimelineId):
"""
:return: dict {"generation": int, "sk_set": [int], "new_sk_set": [int]}
"""
response = self.request(
"GET",
f"{self.api}/debug/v1/tenant/{tenant_id}/timeline/{timeline_id}/locate",
headers=self.headers(TokenScope.ADMIN),
)
return response.json()
def tenant_describe(self, tenant_id: TenantId):
"""
:return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr: str, "listen_http_port: int, preferred_az_id: str}
@@ -4057,16 +4046,6 @@ def static_proxy(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
)
vanilla_pg.stop()
vanilla_pg.edit_hba(
[
"local all all trust",
"host all all 127.0.0.1/32 scram-sha-256",
"host all all ::1/128 scram-sha-256",
]
)
vanilla_pg.start()
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()
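The removed edit_hba block is what forced the proxy tests below to authenticate explicitly: with host rules set to scram-sha-256, TCP connections to the vanilla Postgres must present a password, while the "local all all trust" line keeps unix-socket access passwordless. A hypothetical connection illustrating the TCP side (driver choice and port are assumptions):

import psycopg2

# Matched by a "host ... scram-sha-256" rule, so credentials are required,
# mirroring the user=/password= arguments the tests below passed.
conn = psycopg2.connect(
    host="127.0.0.1", port=55432, user="proxy", password="password", dbname="postgres"
)
conn.close()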


@@ -45,8 +45,6 @@ class NeonEndpoint:
if self.branch.connect_env:
self.connect_env = self.branch.connect_env.copy()
self.connect_env["PGHOST"] = self.host
if self.type == "read_only":
self.project.read_only_endpoints_total += 1
def delete(self):
self.project.delete_endpoint(self.id)
@@ -230,13 +228,8 @@ class NeonProject:
self.benchmarks: dict[str, subprocess.Popen[Any]] = {}
self.restore_num: int = 0
self.restart_pgbench_on_console_errors: bool = False
self.limits: dict[str, Any] = self.get_limits()["limits"]
self.read_only_endpoints_total: int = 0
def get_limits(self) -> dict[str, Any]:
return self.neon_api.get_project_limits(self.id)
def delete(self) -> None:
def delete(self):
self.neon_api.delete_project(self.id)
def create_branch(self, parent_id: str | None = None) -> NeonBranch | None:
@@ -289,7 +282,6 @@ class NeonProject:
self.neon_api.delete_endpoint(self.id, endpoint_id)
self.endpoints[endpoint_id].branch.endpoints.pop(endpoint_id)
self.endpoints.pop(endpoint_id)
self.read_only_endpoints_total -= 1
self.wait()
def start_benchmark(self, target: str, clients: int = 10) -> subprocess.Popen[Any]:
@@ -377,64 +369,49 @@ def setup_class(
print(f"::warning::Retried on 524 error {neon_api.retries524} times")
if neon_api.retries4xx > 0:
print(f"::warning::Retried on 4xx error {neon_api.retries4xx} times")
log.info("Removing the project %s", project.id)
log.info("Removing the project")
project.delete()
def do_action(project: NeonProject, action: str) -> bool:
def do_action(project: NeonProject, action: str) -> None:
"""
Runs the action
"""
log.info("Action: %s", action)
if action == "new_branch":
log.info("Trying to create a new branch")
if 0 <= project.limits["max_branches"] <= len(project.branches):
log.info(
"Maximum branch limit exceeded (%s of %s)",
len(project.branches),
project.limits["max_branches"],
)
return False
parent = project.branches[
random.choice(list(set(project.branches.keys()) - project.reset_branches))
]
log.info("Parent: %s", parent)
child = parent.create_child_branch()
if child is None:
return False
return
log.info("Created branch %s", child)
child.start_benchmark()
elif action == "delete_branch":
if project.leaf_branches:
target: NeonBranch = random.choice(list(project.leaf_branches.values()))
target = random.choice(list(project.leaf_branches.values()))
log.info("Trying to delete branch %s", target)
target.delete()
else:
log.info("Leaf branches not found, skipping")
return False
elif action == "new_ro_endpoint":
if 0 <= project.limits["max_read_only_endpoints"] <= project.read_only_endpoints_total:
log.info(
"Maximum read only endpoint limit exceeded (%s of %s)",
project.read_only_endpoints_total,
project.limits["max_read_only_endpoints"],
)
return False
ep = random.choice(
[br for br in project.branches.values() if br.id not in project.reset_branches]
).create_ro_endpoint()
log.info("Created the RO endpoint with id %s branch: %s", ep.id, ep.branch.id)
ep.start_benchmark()
elif action == "delete_ro_endpoint":
if project.read_only_endpoints_total == 0:
log.info("no read_only endpoints present, skipping")
return False
ro_endpoints: list[NeonEndpoint] = [
endpoint for endpoint in project.endpoints.values() if endpoint.type == "read_only"
]
target_ep: NeonEndpoint = random.choice(ro_endpoints)
target_ep.delete()
log.info("endpoint %s deleted", target_ep.id)
if ro_endpoints:
target_ep: NeonEndpoint = random.choice(ro_endpoints)
target_ep.delete()
log.info("endpoint %s deleted", target_ep.id)
else:
log.info("no read_only endpoints present, skipping")
elif action == "restore_random_time":
if project.leaf_branches:
br: NeonBranch = random.choice(list(project.leaf_branches.values()))
@@ -442,10 +419,8 @@ def do_action(project: NeonProject, action: str) -> bool:
br.restore_random_time()
else:
log.info("No leaf branches found")
return False
else:
raise ValueError(f"The action {action} is unknown")
return True
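A note on the limit checks deleted above: 0 <= limit <= count is a Python chained comparison, true only when the limit is non-negative and already reached, so a negative limit (presumably encoding "unlimited") never blocks an action. A self-contained illustration with made-up numbers:

def limit_reached(limit: int, current: int) -> bool:
    # True only when `limit` is non-negative and `current` has reached it;
    # any negative limit is treated as "no limit".
    return 0 <= limit <= current

assert limit_reached(5, 5)            # at the cap
assert not limit_reached(5, 4)        # below the cap
assert not limit_reached(-1, 10_000)  # negative cap: unlimited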
@pytest.mark.timeout(7200)
@@ -482,9 +457,8 @@ def test_api_random(
pg_bin.run(["pgbench", "-i", "-I", "dtGvp", "-s100"], env=project.main_branch.connect_env)
for _ in range(num_operations):
log.info("Starting action #%s", _ + 1)
while not do_action(
do_action(
project, random.choices([a[0] for a in ACTIONS], weights=[w[1] for w in ACTIONS])[0]
):
log.info("Retrying...")
)
project.check_all_benchmarks()
assert True
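The retry loop on the removed side leans on do_action() returning False when an action has to be skipped (limit reached, nothing left to delete), re-rolling the weighted choice until one succeeds. A self-contained sketch with stubbed names and hypothetical weights:

import random

ACTIONS = [("new_branch", 4), ("delete_branch", 1), ("new_ro_endpoint", 2)]

def do_action(action: str) -> bool:
    # Stub standing in for the test's do_action(); pretend half of all attempts skip.
    return random.random() < 0.5

while not do_action(random.choices([a[0] for a in ACTIONS], weights=[a[1] for a in ACTIONS])[0]):
    pass  # re-roll until an action reports success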


@@ -18,7 +18,6 @@ from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
PgBin,
Safekeeper,
flush_ep_to_pageserver,
)
from fixtures.pageserver.http import PageserverApiException
@@ -27,7 +26,6 @@ from fixtures.pageserver.utils import (
)
from fixtures.pg_version import PgVersion
from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage
from fixtures.safekeeper.http import MembershipConfiguration
from fixtures.workload import Workload
if TYPE_CHECKING:
@@ -544,24 +542,6 @@ def test_historic_storage_formats(
# All our artifacts should contain at least one timeline
assert len(timelines) > 0
# Import tenant does not create the timeline on safekeepers,
# because it is a debug handler and the timeline may have already been
# created on some set of safekeepers.
# Create the timeline on safekeepers manually.
# TODO(diko): when we have the script/storcon handler to migrate
# the timeline to storcon, we can replace this code with it.
mconf = MembershipConfiguration(
generation=1,
members=Safekeeper.sks_to_safekeeper_ids([env.safekeepers[0]]),
new_members=None,
)
members_sks = Safekeeper.mconf_sks(env, mconf)
for timeline in timelines:
Safekeeper.create_timeline(
dataset.tenant_id, timeline["timeline_id"], env.pageserver, mconf, members_sks
)
# TODO: ensure that the snapshots we're importing contain a sensible variety of content, at the very
# least they should include a mixture of deltas and image layers. Preferably they should also
# contain some "exotic" stuff like aux files from logical replication.


@@ -188,8 +188,7 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMet
pg_cur.execute("select pg_reload_conf()")
if query is LfcQueryMethod.COMPUTE_CTL:
# Same thing as prewarm_lfc(), testing other method
http_client.prewarm_lfc(endpoint.endpoint_id)
http_client.prewarm_lfc()
else:
pg_cur.execute("select prewarm_local_cache(%s)", (lfc_state,))


@@ -16,7 +16,7 @@ if TYPE_CHECKING:
# Test restarting page server, while safekeeper and compute node keep
# running.
def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: PgBin):
def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgBin):
env = neon_simple_env
env.create_branch("test_pageserver_restarts")
endpoint = env.endpoints.create_start("test_pageserver_restarts")
@@ -28,11 +28,7 @@ def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: Pg
pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr])
pg_bin.run_capture(["pgbench", f"-T{n_restarts}", connstr])
thread = threading.Thread(
target=run_pgbench,
args=(endpoint.connstr(options="-cstatement_timeout=360s"),),
daemon=True,
)
thread = threading.Thread(target=run_pgbench, args=(endpoint.connstr(),), daemon=True)
thread.start()
for _ in range(n_restarts):
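The removed options="-cstatement_timeout=360s" bounded every statement so the background pgbench fails fast instead of hanging if a restart stalls a query. Since libpq honors PGOPTIONS, roughly the same effect can be sketched outside the fixtures (the connection string is a placeholder):

import os
import subprocess
import threading

def run_pgbench(connstr: str) -> None:
    env = dict(os.environ, PGOPTIONS="-cstatement_timeout=360s")  # bound each statement
    subprocess.run(["pgbench", "-T60", connstr], env=env, check=True)

thread = threading.Thread(target=run_pgbench, args=("postgres://localhost/postgres",), daemon=True)
thread.start()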


@@ -19,15 +19,11 @@ TABLE_NAME = "neon_control_plane.endpoints"
async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
)
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
)
def check_cannot_connect(**kwargs):
@@ -64,9 +60,7 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
)
def query(status: int, query: str, *args):
@@ -81,8 +75,6 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
query(400, "select 1;") # ip address is not allowed
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'",
user="proxy",
password="password",
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
)
query(200, "select 1;") # should work now


@@ -430,7 +430,6 @@ def test_tenant_delete_stale_shards(neon_env_builder: NeonEnvBuilder, pg_bin: Pg
workload.init()
workload.write_rows(256)
workload.validate()
workload.stop()
assert_prefix_not_empty(
neon_env_builder.pageserver_remote_storage,

Some files were not shown because too many files have changed in this diff.