Merge commit '296c9190b' into problame/standby-horizon-leases

Christian Schwarz
2025-08-06 17:49:50 +02:00
48 changed files with 946 additions and 407 deletions

.gitmodules

@@ -1,16 +1,16 @@
 [submodule "vendor/postgres-v14"]
 path = vendor/postgres-v14
-url = https://github.com/neondatabase/postgres.git
+url = ../postgres.git
 branch = REL_14_STABLE_neon
 [submodule "vendor/postgres-v15"]
 path = vendor/postgres-v15
-url = https://github.com/neondatabase/postgres.git
+url = ../postgres.git
 branch = REL_15_STABLE_neon
 [submodule "vendor/postgres-v16"]
 path = vendor/postgres-v16
-url = https://github.com/neondatabase/postgres.git
+url = ../postgres.git
 branch = REL_16_STABLE_neon
 [submodule "vendor/postgres-v17"]
 path = vendor/postgres-v17
-url = https://github.com/neondatabase/postgres.git
+url = ../postgres.git
 branch = REL_17_STABLE_neon

Cargo.lock

@@ -5290,6 +5290,7 @@ dependencies = [
 "async-trait",
 "atomic-take",
 "aws-config",
+"aws-credential-types",
 "aws-sdk-iam",
 "aws-sigv4",
 "base64 0.22.1",
@@ -5329,6 +5330,7 @@ dependencies = [
 "itoa",
 "jose-jwa",
 "jose-jwk",
+"json",
 "lasso",
 "measured",
 "metrics",

@@ -1045,6 +1045,8 @@ impl ComputeNode {
 PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
 };
+self.fix_zenith_signal_neon_signal()?;
 let mut state = self.state.lock().unwrap();
 state.metrics.pageserver_connect_micros =
 connected.duration_since(started).as_micros() as u64;
@@ -1054,6 +1056,27 @@ impl ComputeNode {
 Ok(())
 }
+/// Move the Zenith signal file to Neon signal file location.
+/// This makes Compute compatible with older PageServers that don't yet
+/// know about the Zenith->Neon rename.
+fn fix_zenith_signal_neon_signal(&self) -> Result<()> {
+let datadir = Path::new(&self.params.pgdata);
+let neonsig = datadir.join("neon.signal");
+if neonsig.is_file() {
+return Ok(());
+}
+let zenithsig = datadir.join("zenith.signal");
+if zenithsig.is_file() {
+fs::copy(zenithsig, neonsig)?;
+}
+Ok(())
+}
 /// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when
 /// the connection was established, and the (compressed) size of the basebackup.
 fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {

@@ -32,7 +32,8 @@
 //! config.json - passed to `compute_ctl`
 //! pgdata/
 //! postgresql.conf - copy of postgresql.conf created by `compute_ctl`
-//! zenith.signal
+//! neon.signal
+//! zenith.signal - copy of neon.signal, for backward compatibility
 //! <other PostgreSQL files>
 //! ```
 //!

@@ -217,6 +217,9 @@ pub struct NeonStorageControllerConf {
 pub posthog_config: Option<PostHogConfig>,
 pub kick_secondary_downloads: Option<bool>,
+#[serde(with = "humantime_serde")]
+pub shard_split_request_timeout: Option<Duration>,
 }
 impl NeonStorageControllerConf {
@@ -250,6 +253,7 @@ impl Default for NeonStorageControllerConf {
 timeline_safekeeper_count: None,
 posthog_config: None,
 kick_secondary_downloads: None,
+shard_split_request_timeout: None,
 }
 }
 }
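The new field goes through `humantime_serde`, so the local config can spell the timeout as a human-readable duration string. A minimal self-contained sketch of that pattern, with an illustrative struct and value rather than the repo's real config, assuming the usual `serde` (derive), `toml`, and `humantime-serde` crates:

```rust
use std::time::Duration;

use serde::Deserialize;

// Illustrative stand-in for the config struct above: humantime_serde parses
// strings like "30s" or "2m" into a Duration, and `default` keeps the field
// optional when it is absent from the file.
#[derive(Deserialize, Debug)]
struct StorconConf {
    #[serde(with = "humantime_serde", default)]
    shard_split_request_timeout: Option<Duration>,
}

fn main() {
    let conf: StorconConf = toml::from_str(r#"shard_split_request_timeout = "30s""#).unwrap();
    assert_eq!(conf.shard_split_request_timeout, Some(Duration::from_secs(30)));
}
```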

@@ -648,6 +648,13 @@ impl StorageController {
 args.push(format!("--timeline-safekeeper-count={sk_cnt}"));
 }
+if let Some(duration) = self.config.shard_split_request_timeout {
+args.push(format!(
+"--shard-split-request-timeout={}",
+humantime::Duration::from(duration)
+));
+}
 let mut envs = vec![
 ("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
 ("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),

@@ -129,9 +129,10 @@ segment to bootstrap the WAL writing, but it doesn't contain the checkpoint reco
 changes in xlog.c, to allow starting the compute node without reading the last checkpoint record
 from WAL.
-This includes code to read the `zenith.signal` file, which tells the startup code the LSN to start
-at. When the `zenith.signal` file is present, the startup uses that LSN instead of the last
-checkpoint's LSN. The system is known to be consistent at that LSN, without any WAL redo.
+This includes code to read the `neon.signal` (also `zenith.signal`) file, which tells the startup
+code the LSN to start at. When the `neon.signal` file is present, the startup uses that LSN
+instead of the last checkpoint's LSN. The system is known to be consistent at that LSN, without
+any WAL redo.
 ### How to get rid of the patch
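For reference, the signal file is a single `PREV LSN: ...` line, as written by the basebackup change later in this diff; `none` and `invalid` both stand for an unknown previous record. A hedged, self-contained sketch of what parsing it amounts to (this is not the pageserver's actual code, which lives in `import_file`; the helper name is hypothetical):

```rust
// Parses the contents of neon.signal / zenith.signal into a raw LSN value.
// "none" and "invalid" both map to 0; otherwise the line carries an LSN in
// the usual "hi/lo" hex notation.
fn parse_signal_file(contents: &str) -> Result<u64, String> {
    match contents.trim() {
        "PREV LSN: none" | "PREV LSN: invalid" => Ok(0),
        other => {
            let hex = other
                .strip_prefix("PREV LSN: ")
                .ok_or_else(|| format!("unexpected signal file: {other}"))?;
            // Real code parses this into an Lsn type; keep it simple here.
            let (hi, lo) = hex.split_once('/').ok_or("bad LSN format")?;
            let hi = u64::from_str_radix(hi, 16).map_err(|e| e.to_string())?;
            let lo = u64::from_str_radix(lo, 16).map_err(|e| e.to_string())?;
            Ok((hi << 32) | lo)
        }
    }
}

fn main() {
    assert_eq!(parse_signal_file("PREV LSN: none"), Ok(0));
    assert_eq!(parse_signal_file("PREV LSN: 0/16B3748"), Ok(0x16B3748));
}
```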

@@ -31,6 +31,7 @@ pub struct UnreliableWrapper {
 /* BEGIN_HADRON */
 // This the probability of failure for each operation, ranged from [0, 100].
 // The probability is default to 100, which means that all operations will fail.
+// Storage will fail by probability up to attempts_to_fail times.
 attempt_failure_probability: u64,
 /* END_HADRON */
 }

@@ -47,6 +47,7 @@ where
 /* BEGIN_HADRON */
 pub enum DeploymentMode {
+Local,
 Dev,
 Staging,
 Prod,
@@ -64,7 +65,7 @@ pub fn get_deployment_mode() -> Option<DeploymentMode> {
 }
 },
 Err(_) => {
-tracing::error!("DEPLOYMENT_MODE not set");
+// tracing::error!("DEPLOYMENT_MODE not set");
 None
 }
 }

@@ -114,7 +114,7 @@ where
 // Compute postgres doesn't have any previous WAL files, but the first
 // record that it's going to write needs to include the LSN of the
 // previous record (xl_prev). We include prev_record_lsn in the
-// "zenith.signal" file, so that postgres can read it during startup.
+// "neon.signal" file, so that postgres can read it during startup.
 //
 // We don't keep full history of record boundaries in the page server,
 // however, only the predecessor of the latest record on each
@@ -751,34 +751,39 @@ where
 //
 // Add generated pg_control file and bootstrap WAL segment.
-// Also send zenith.signal file with extra bootstrap data.
+// Also send neon.signal and zenith.signal file with extra bootstrap data.
 //
 async fn add_pgcontrol_file(
 &mut self,
 pg_control_bytes: Bytes,
 system_identifier: u64,
 ) -> Result<(), BasebackupError> {
-// add zenith.signal file
-let mut zenith_signal = String::new();
+// add neon.signal file
+let mut neon_signal = String::new();
 if self.prev_record_lsn == Lsn(0) {
 if self.timeline.is_ancestor_lsn(self.lsn) {
-write!(zenith_signal, "PREV LSN: none")
+write!(neon_signal, "PREV LSN: none")
 .map_err(|e| BasebackupError::Server(e.into()))?;
 } else {
-write!(zenith_signal, "PREV LSN: invalid")
+write!(neon_signal, "PREV LSN: invalid")
 .map_err(|e| BasebackupError::Server(e.into()))?;
 }
 } else {
-write!(zenith_signal, "PREV LSN: {}", self.prev_record_lsn)
+write!(neon_signal, "PREV LSN: {}", self.prev_record_lsn)
 .map_err(|e| BasebackupError::Server(e.into()))?;
 }
-self.ar
-.append(
-&new_tar_header("zenith.signal", zenith_signal.len() as u64)?,
-zenith_signal.as_bytes(),
-)
-.await
-.map_err(|e| BasebackupError::Client(e, "add_pgcontrol_file,zenith.signal"))?;
+// TODO: Remove zenith.signal once all historical computes have been replaced
+// ... and thus support the neon.signal file.
+for signalfilename in ["neon.signal", "zenith.signal"] {
+self.ar
+.append(
+&new_tar_header(signalfilename, neon_signal.len() as u64)?,
+neon_signal.as_bytes(),
+)
+.await
+.map_err(|e| BasebackupError::Client(e, "add_pgcontrol_file,neon.signal"))?;
+}
 //send pg_control
 let header = new_tar_header("global/pg_control", pg_control_bytes.len() as u64)?;

@@ -917,11 +917,6 @@ async fn create_remote_storage_client(
 // If `test_remote_failures` is non-zero, wrap the client with a
 // wrapper that simulates failures.
 if conf.test_remote_failures > 0 {
-if !cfg!(feature = "testing") {
-anyhow::bail!(
-"test_remote_failures option is not available because pageserver was compiled without the 'testing' feature"
-);
-}
 info!(
 "Simulating remote failures for first {} attempts of each op",
 conf.test_remote_failures

@@ -4187,7 +4187,7 @@ pub fn make_router(
 })
 .get(
 "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/getpage",
 |r| testing_api_handler("getpage@lsn", r, getpage_at_lsn_handler),
 )
 .get(
 "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/touchpage",

@@ -610,13 +610,13 @@ async fn import_file(
 debug!("imported twophase file");
 } else if file_path.starts_with("pg_wal") {
 debug!("found wal file in base section. ignore it");
-} else if file_path.starts_with("zenith.signal") {
+} else if file_path.starts_with("zenith.signal") || file_path.starts_with("neon.signal") {
 // Parse zenith signal file to set correct previous LSN
 let bytes = read_all_bytes(reader).await?;
-// zenith.signal format is "PREV LSN: prev_lsn"
+// neon.signal format is "PREV LSN: prev_lsn"
 // TODO write serialization and deserialization in the same place.
-let zenith_signal = std::str::from_utf8(&bytes)?.trim();
-let prev_lsn = match zenith_signal {
+let neon_signal = std::str::from_utf8(&bytes)?.trim();
+let prev_lsn = match neon_signal {
 "PREV LSN: none" => Lsn(0),
 "PREV LSN: invalid" => Lsn(0),
 other => {
@@ -624,17 +624,17 @@ async fn import_file(
 split[1]
 .trim()
 .parse::<Lsn>()
-.context("can't parse zenith.signal")?
+.context("can't parse neon.signal")?
 }
 };
-// zenith.signal is not necessarily the last file, that we handle
+// neon.signal is not necessarily the last file, that we handle
 // but it is ok to call `finish_write()`, because final `modification.commit()`
 // will update lsn once more to the final one.
 let writer = modification.tline.writer().await;
 writer.finish_write(prev_lsn);
-debug!("imported zenith signal {}", prev_lsn);
+debug!("imported neon signal {}", prev_lsn);
 } else if file_path.starts_with("pg_tblspc") {
 // TODO Backups exported from neon won't have pg_tblspc, but we will need
 // this to import arbitrary postgres databases.

@@ -1678,6 +1678,8 @@ impl TenantManager {
 // Phase 6: Release the InProgress on the parent shard
 drop(parent_slot_guard);
+utils::pausable_failpoint!("shard-split-post-finish-pause");
 Ok(child_shards)
 }

@@ -5651,10 +5651,11 @@ impl Timeline {
 /// Predicate function which indicates whether we should check if new image layers
 /// are required. Since checking if new image layers are required is expensive in
 /// terms of CPU, we only do it in the following cases:
-/// 1. If the timeline has ingested sufficient WAL to justify the cost
+/// 1. If the timeline has ingested sufficient WAL to justify the cost or ...
 /// 2. If enough time has passed since the last check:
 /// 1. For large tenants, we wish to perform the check more often since they
-/// suffer from the lack of image layers
+/// suffer from the lack of image layers. Note that we assume sharded tenants
+/// to be large since non-zero shards do not track the logical size.
 /// 2. For small tenants (that can mostly fit in RAM), we use a much longer interval
 fn should_check_if_image_layers_required(self: &Arc<Timeline>, lsn: Lsn) -> bool {
 let large_timeline_threshold = self.conf.image_layer_generation_large_timeline_threshold;
@@ -5668,30 +5669,39 @@ impl Timeline {
 let distance_based_decision = distance.0 >= min_distance;
-let mut time_based_decision = false;
 let mut last_check_instant = self.last_image_layer_creation_check_instant.lock().unwrap();
-if let CurrentLogicalSize::Exact(logical_size) = self.current_logical_size.current_size() {
-let check_required_after =
-if Some(Into::<u64>::into(&logical_size)) >= large_timeline_threshold {
-self.get_checkpoint_timeout()
-} else {
-Duration::from_secs(3600 * 48)
-};
-time_based_decision = match *last_check_instant {
-Some(last_check) => {
-let elapsed = last_check.elapsed();
-elapsed >= check_required_after
-}
-None => true,
-};
-}
+let check_required_after = (|| {
+if self.shard_identity.is_unsharded() {
+if let CurrentLogicalSize::Exact(logical_size) =
+self.current_logical_size.current_size()
+{
+if Some(Into::<u64>::into(&logical_size)) < large_timeline_threshold {
+return Duration::from_secs(3600 * 48);
+}
+}
+}
+self.get_checkpoint_timeout()
+})();
+let time_based_decision = match *last_check_instant {
+Some(last_check) => {
+let elapsed = last_check.elapsed();
+elapsed >= check_required_after
+}
+None => true,
+};
 // Do the expensive delta layer counting only if this timeline has ingested sufficient
 // WAL since the last check or a checkpoint timeout interval has elapsed since the last
 // check.
 let decision = distance_based_decision || time_based_decision;
+tracing::info!(
+"Decided to check image layers: {}. Distance-based decision: {}, time-based decision: {}",
+decision,
+distance_based_decision,
+time_based_decision
+);
 if decision {
 self.last_image_layer_creation_check_at.store(lsn);
 *last_check_instant = Some(Instant::now());

@@ -359,14 +359,14 @@ impl<T: Types> Cache<T> {
 Err(e) => {
 // Retry on tenant manager error to handle tenant split more gracefully
 if attempt < GET_MAX_RETRIES {
-tracing::warn!(
-"Fail to resolve tenant shard in attempt {}: {:?}. Retrying...",
-attempt,
-e
-);
 tokio::time::sleep(RETRY_BACKOFF).await;
 continue;
 } else {
+tracing::warn!(
+"Failed to resolve tenant shard after {} attempts: {:?}",
+GET_MAX_RETRIES,
+e
+);
 return Err(e);
 }
 }

@@ -147,6 +147,16 @@ pub enum RedoAttemptType {
 GcCompaction,
 }
+impl std::fmt::Display for RedoAttemptType {
+fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+match self {
+RedoAttemptType::ReadPage => write!(f, "read page"),
+RedoAttemptType::LegacyCompaction => write!(f, "legacy compaction"),
+RedoAttemptType::GcCompaction => write!(f, "gc compaction"),
+}
+}
+}
 ///
 /// Public interface of WAL redo manager
 ///
@@ -199,6 +209,7 @@ impl PostgresRedoManager {
 self.conf.wal_redo_timeout,
 pg_version,
 max_retry_attempts,
+redo_attempt_type,
 )
 .await
 };
@@ -221,6 +232,7 @@ impl PostgresRedoManager {
 self.conf.wal_redo_timeout,
 pg_version,
 max_retry_attempts,
+redo_attempt_type,
 )
 .await
 }
@@ -445,6 +457,7 @@ impl PostgresRedoManager {
 wal_redo_timeout: Duration,
 pg_version: PgMajorVersion,
 max_retry_attempts: u32,
+redo_attempt_type: RedoAttemptType,
 ) -> Result<Bytes, Error> {
 *(self.last_redo_at.lock().unwrap()) = Some(Instant::now());
@@ -485,17 +498,28 @@ impl PostgresRedoManager {
 );
 if let Err(e) = result.as_ref() {
-error!(
-"error applying {} WAL records {}..{} ({} bytes) to key {key}, from base image with LSN {} to reconstruct page image at LSN {} n_attempts={}: {:?}",
-records.len(),
-records.first().map(|p| p.0).unwrap_or(Lsn(0)),
-records.last().map(|p| p.0).unwrap_or(Lsn(0)),
-nbytes,
-base_img_lsn,
-lsn,
-n_attempts,
-e,
-);
+macro_rules! message {
+($level:tt) => {
+$level!(
+"error applying {} WAL records {}..{} ({} bytes) to key {} during {}, from base image with LSN {} to reconstruct page image at LSN {} n_attempts={}: {:?}",
+records.len(),
+records.first().map(|p| p.0).unwrap_or(Lsn(0)),
+records.last().map(|p| p.0).unwrap_or(Lsn(0)),
+nbytes,
+key,
+redo_attempt_type,
+base_img_lsn,
+lsn,
+n_attempts,
+e,
+)
+}
+}
+match redo_attempt_type {
+RedoAttemptType::ReadPage => message!(error),
+RedoAttemptType::LegacyCompaction => message!(error),
+RedoAttemptType::GcCompaction => message!(warn),
+}
 }
 result.map_err(Error::Other)
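The `message!` macro above parameterizes only the log level, so gc-compaction failures are demoted to warnings while page reads keep erroring. A hedged, self-contained sketch of that pattern in isolation (the `tracing` and `tracing-subscriber` crates are assumed dependencies and the message text is illustrative):

```rust
use tracing::{error, warn};

// The macro expands `$level` into the name of the log macro to invoke, so one
// message body can be emitted at different severities depending on the caller.
macro_rules! message {
    ($level:tt) => {
        $level!("error applying WAL records (attempt type: {})", "gc compaction")
    };
}

fn main() {
    tracing_subscriber::fmt::init();
    message!(warn); // what the GcCompaction arm does
    message!(error); // what the ReadPage / LegacyCompaction arms do
}
```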

@@ -162,8 +162,34 @@ typedef struct FileCacheControl
 dlist_head lru; /* double linked list for LRU replacement
 * algorithm */
 dlist_head holes; /* double linked list of punched holes */
-HyperLogLogState wss_estimation; /* estimation of working set size */
 ConditionVariable cv[N_COND_VARS]; /* turnstile of condition variables */
+/*
+ * Estimation of working set size.
+ *
+ * This is not guarded by the lock. No locking is needed because all the
+ * writes to the "registers" are simple 64-bit stores, to update a
+ * timestamp. We assume that:
+ *
+ * - 64-bit stores are atomic. We could enforce that by using
+ * pg_atomic_uint64 instead of TimestampTz as the datatype in hll.h, but
+ * for now we just rely on it implicitly.
+ *
+ * - Even if they're not, and there is a race between two stores, it
+ * doesn't matter much which one wins because they're both updating the
+ * register with the current timestamp. Or you have a race between
+ * resetting the register and updating it, in which case it also doesn't
+ * matter much which one wins.
+ *
+ * - If they're not atomic, you might get an occasional "torn write" if
+ * you're really unlucky, but we tolerate that too. It just means that
+ * the estimate will be a little off, until the register is updated
+ * again.
+ */
+HyperLogLogState wss_estimation;
+/* Prewarmer state */
 PrewarmWorkerState prewarm_workers[MAX_PREWARM_WORKERS];
 size_t n_prewarm_workers;
 size_t n_prewarm_entries;
@@ -205,6 +231,8 @@ bool AmPrewarmWorker;
 #define LFC_ENABLED() (lfc_ctl->limit != 0)
+PGDLLEXPORT void lfc_prewarm_main(Datum main_arg);
 /*
 * Close LFC file if opened.
 * All backends should close their LFC files once LFC is disabled.
@@ -1142,6 +1170,13 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
 CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
+/* Update working set size estimate for the blocks */
+for (int i = 0; i < nblocks; i++)
+{
+tag.blockNum = blkno + i;
+addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
+}
 /*
 * For every chunk that has blocks we're interested in, we
 * 1. get the chunk header
@@ -1220,14 +1255,6 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
 }
 entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
-/* Approximate working set for the blocks assumed in this entry */
-for (int i = 0; i < blocks_in_chunk; i++)
-{
-tag.blockNum = blkno + i;
-addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
-}
 if (entry == NULL)
 {
 /* Pages are not cached */
@@ -1504,9 +1531,15 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
 return false;
 CopyNRelFileInfoToBufTag(tag, rinfo);
+CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
 tag.forkNum = forknum;
-CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
+/* Update working set size estimate for the blocks */
+if (lfc_prewarm_update_ws_estimation)
+{
+tag.blockNum = blkno;
+addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
+}
 tag.blockNum = blkno - chunk_offs;
 hash = get_hash_value(lfc_hash, &tag);
@@ -1524,19 +1557,13 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
 if (lwlsn > lsn)
 {
-elog(DEBUG1, "Skip LFC write for %d because LwLSN=%X/%X is greater than not_nodified_since LSN %X/%X",
+elog(DEBUG1, "Skip LFC write for %u because LwLSN=%X/%X is greater than not_nodified_since LSN %X/%X",
 blkno, LSN_FORMAT_ARGS(lwlsn), LSN_FORMAT_ARGS(lsn));
 LWLockRelease(lfc_lock);
 return false;
 }
 entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
-if (lfc_prewarm_update_ws_estimation)
-{
-tag.blockNum = blkno;
-addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
-}
 if (found)
 {
 state = GET_STATE(entry, chunk_offs);
@@ -1649,9 +1676,15 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
 return;
 CopyNRelFileInfoToBufTag(tag, rinfo);
+CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
 tag.forkNum = forkNum;
-CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
+/* Update working set size estimate for the blocks */
+for (int i = 0; i < nblocks; i++)
+{
+tag.blockNum = blkno + i;
+addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
+}
 LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
@@ -1692,14 +1725,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
 cv = &lfc_ctl->cv[hash % N_COND_VARS];
 entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
-/* Approximate working set for the blocks assumed in this entry */
-for (int i = 0; i < blocks_in_chunk; i++)
-{
-tag.blockNum = blkno + i;
-addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
-}
 if (found)
 {
 /*
@@ -2135,40 +2160,23 @@ local_cache_pages(PG_FUNCTION_ARGS)
 SRF_RETURN_DONE(funcctx);
 }
-PG_FUNCTION_INFO_V1(approximate_working_set_size_seconds);
-Datum
-approximate_working_set_size_seconds(PG_FUNCTION_ARGS)
+/*
+ * Internal implementation of the approximate_working_set_size_seconds()
+ * function.
+ */
+int32
+lfc_approximate_working_set_size_seconds(time_t duration, bool reset)
 {
-if (lfc_size_limit != 0)
-{
-int32 dc;
-time_t duration = PG_ARGISNULL(0) ? (time_t)-1 : PG_GETARG_INT32(0);
-LWLockAcquire(lfc_lock, LW_SHARED);
-dc = (int32) estimateSHLL(&lfc_ctl->wss_estimation, duration);
-LWLockRelease(lfc_lock);
-PG_RETURN_INT32(dc);
-}
-PG_RETURN_NULL();
-}
-PG_FUNCTION_INFO_V1(approximate_working_set_size);
-Datum
-approximate_working_set_size(PG_FUNCTION_ARGS)
-{
-if (lfc_size_limit != 0)
-{
-int32 dc;
-bool reset = PG_GETARG_BOOL(0);
-LWLockAcquire(lfc_lock, reset ? LW_EXCLUSIVE : LW_SHARED);
-dc = (int32) estimateSHLL(&lfc_ctl->wss_estimation, (time_t)-1);
-if (reset)
-memset(lfc_ctl->wss_estimation.regs, 0, sizeof lfc_ctl->wss_estimation.regs);
-LWLockRelease(lfc_lock);
-PG_RETURN_INT32(dc);
-}
-PG_RETURN_NULL();
-}
+int32 dc;
+if (lfc_size_limit == 0)
+return -1;
+dc = (int32) estimateSHLL(&lfc_ctl->wss_estimation, duration);
+if (reset)
+memset(lfc_ctl->wss_estimation.regs, 0, sizeof lfc_ctl->wss_estimation.regs);
+return dc;
+}
 PG_FUNCTION_INFO_V1(get_local_cache_state);

@@ -47,7 +47,8 @@ extern bool lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blk
 extern FileCacheState* lfc_get_state(size_t max_entries);
 extern void lfc_prewarm(FileCacheState* fcs, uint32 n_workers);
-PGDLLEXPORT void lfc_prewarm_main(Datum main_arg);
+extern int32 lfc_approximate_working_set_size_seconds(time_t duration, bool reset);
 static inline bool
 lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,

@@ -561,6 +561,8 @@ _PG_init(void)
 PG_FUNCTION_INFO_V1(pg_cluster_size);
 PG_FUNCTION_INFO_V1(backpressure_lsns);
 PG_FUNCTION_INFO_V1(backpressure_throttling_time);
+PG_FUNCTION_INFO_V1(approximate_working_set_size_seconds);
+PG_FUNCTION_INFO_V1(approximate_working_set_size);
 Datum
 pg_cluster_size(PG_FUNCTION_ARGS)
@@ -607,6 +609,34 @@ backpressure_throttling_time(PG_FUNCTION_ARGS)
 PG_RETURN_UINT64(BackpressureThrottlingTime());
 }
+Datum
+approximate_working_set_size_seconds(PG_FUNCTION_ARGS)
+{
+time_t duration;
+int32 dc;
+duration = PG_ARGISNULL(0) ? (time_t) -1 : PG_GETARG_INT32(0);
+dc = lfc_approximate_working_set_size_seconds(duration, false);
+if (dc < 0)
+PG_RETURN_NULL();
+else
+PG_RETURN_INT32(dc);
+}
+Datum
+approximate_working_set_size(PG_FUNCTION_ARGS)
+{
+bool reset = PG_GETARG_BOOL(0);
+int32 dc;
+dc = lfc_approximate_working_set_size_seconds(-1, reset);
+if (dc < 0)
+PG_RETURN_NULL();
+else
+PG_RETURN_INT32(dc);
+}
 #if PG_MAJORVERSION_NUM >= 16
 static void
 neon_shmem_startup_hook(void)
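These SQL-callable wrappers around `lfc_approximate_working_set_size_seconds()` can be exercised from any Postgres client; they return NULL when the local file cache is disabled. A hedged usage sketch from Rust (the connection string, runtime setup, and the `tokio`/`tokio-postgres` dependencies are assumptions of this example, not part of the diff):

```rust
use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    // Hypothetical connection parameters for a compute running the neon extension.
    let (client, conn) =
        tokio_postgres::connect("host=localhost user=postgres dbname=postgres", NoTls).await?;
    tokio::spawn(conn);

    // Estimate of distinct pages touched over the last 5 minutes.
    let row = client
        .query_one("SELECT approximate_working_set_size_seconds(300)", &[])
        .await?;
    let pages: Option<i32> = row.get(0);
    println!("working set over 5 min: {pages:?} pages");

    // All-time estimate; passing true would reset the HyperLogLog registers.
    let row = client
        .query_one("SELECT approximate_working_set_size(false)", &[])
        .await?;
    let pages: Option<i32> = row.get(0);
    println!("working set since reset: {pages:?} pages");
    Ok(())
}
```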

@@ -236,13 +236,13 @@ clear_buffer_cache(PG_FUNCTION_ARGS)
 bool save_neon_test_evict;
 /*
-* Temporarily set the zenith_test_evict GUC, so that when we pin and
+* Temporarily set the neon_test_evict GUC, so that when we pin and
 * unpin a buffer, the buffer is evicted. We use that hack to evict all
 * buffers, as there is no explicit "evict this buffer" function in the
 * buffer manager.
 */
-save_neon_test_evict = zenith_test_evict;
-zenith_test_evict = true;
+save_neon_test_evict = neon_test_evict;
+neon_test_evict = true;
 PG_TRY();
 {
 /* Scan through all the buffers */
@@ -273,7 +273,7 @@ clear_buffer_cache(PG_FUNCTION_ARGS)
 /*
 * Pin the buffer, and release it again. Because we have
-* zenith_test_evict==true, this will evict the page from the
+* neon_test_evict==true, this will evict the page from the
 * buffer cache if no one else is holding a pin on it.
 */
 if (isvalid)
@@ -286,7 +286,7 @@ clear_buffer_cache(PG_FUNCTION_ARGS)
 PG_FINALLY();
 {
 /* restore the GUC */
-zenith_test_evict = save_neon_test_evict;
+neon_test_evict = save_neon_test_evict;
 }
 PG_END_TRY();

@@ -2953,17 +2953,17 @@ XmlTableBuilderData
 YYLTYPE
 YYSTYPE
 YY_BUFFER_STATE
-ZenithErrorResponse
-ZenithExistsRequest
-ZenithExistsResponse
-ZenithGetPageRequest
-ZenithGetPageResponse
-ZenithMessage
-ZenithMessageTag
-ZenithNblocksRequest
-ZenithNblocksResponse
-ZenithRequest
-ZenithResponse
+NeonErrorResponse
+NeonExistsRequest
+NeonExistsResponse
+NeonGetPageRequest
+NeonGetPageResponse
+NeonMessage
+NeonMessageTag
+NeonNblocksRequest
+NeonNblocksResponse
+NeonRequest
+NeonResponse
 _SPI_connection
 _SPI_plan
 __AssignProcessToJobObject

@@ -16,6 +16,7 @@ async-compression.workspace = true
 async-trait.workspace = true
 atomic-take.workspace = true
 aws-config.workspace = true
+aws-credential-types.workspace = true
 aws-sdk-iam.workspace = true
 aws-sigv4.workspace = true
 base64.workspace = true
@@ -48,6 +49,7 @@ indexmap = { workspace = true, features = ["serde"] }
 ipnet.workspace = true
 itertools.workspace = true
 itoa.workspace = true
+json = { path = "../libs/proxy/json" }
 lasso = { workspace = true, features = ["multi-threaded"] }
 measured = { workspace = true, features = ["lasso"] }
 metrics.workspace = true
@@ -127,4 +129,4 @@ rstest.workspace = true
 walkdir.workspace = true
 rand_distr = "0.4"
 tokio-postgres.workspace = true
 tracing-test = "0.2"

@@ -123,6 +123,11 @@ docker exec -it proxy-postgres psql -U postgres -c "CREATE TABLE neon_control_pl
 docker exec -it proxy-postgres psql -U postgres -c "CREATE ROLE proxy WITH SUPERUSER LOGIN PASSWORD 'password';"
 ```
+If you want to test query cancellation, redis is also required:
+```sh
+docker run --detach --name proxy-redis --publish 6379:6379 redis:7.0
+```
 Let's create self-signed certificate by running:
 ```sh
 openssl req -new -x509 -days 365 -nodes -text -out server.crt -keyout server.key -subj "/CN=*.local.neon.build"
@@ -130,7 +135,10 @@ openssl req -new -x509 -days 365 -nodes -text -out server.crt -keyout server.key
 Then we need to build proxy with 'testing' feature and run, e.g.:
 ```sh
-RUST_LOG=proxy LOGFMT=text cargo run -p proxy --bin proxy --features testing -- --auth-backend postgres --auth-endpoint 'postgresql://postgres:proxy-postgres@127.0.0.1:5432/postgres' -c server.crt -k server.key
+RUST_LOG=proxy LOGFMT=text cargo run -p proxy --bin proxy --features testing -- \
+--auth-backend postgres --auth-endpoint 'postgresql://postgres:proxy-postgres@127.0.0.1:5432/postgres' \
+--redis-auth-type="plain" --redis-plain="redis://127.0.0.1:6379" \
+-c server.crt -k server.key
 ```
 Now from client you can start a new session:

@@ -7,13 +7,17 @@ use std::pin::pin;
 use std::sync::Mutex;
 use scopeguard::ScopeGuard;
+use tokio::sync::oneshot;
 use tokio::sync::oneshot::error::TryRecvError;
 use crate::ext::LockExt;
+type ProcResult<P> = Result<<P as QueueProcessing>::Res, <P as QueueProcessing>::Err>;
 pub trait QueueProcessing: Send + 'static {
 type Req: Send + 'static;
 type Res: Send;
+type Err: Send + Clone;
 /// Get the desired batch size.
 fn batch_size(&self, queue_size: usize) -> usize;
@@ -24,7 +28,18 @@ pub trait QueueProcessing: Send + 'static {
 /// If this apply can error, it's expected that errors be forwarded to each Self::Res.
 ///
 /// Batching does not need to happen atomically.
-fn apply(&mut self, req: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
+fn apply(
+&mut self,
+req: Vec<Self::Req>,
+) -> impl Future<Output = Result<Vec<Self::Res>, Self::Err>> + Send;
 }
+#[derive(thiserror::Error)]
+pub enum BatchQueueError<E: Clone, C> {
+#[error(transparent)]
+Result(E),
+#[error(transparent)]
+Cancelled(C),
+}
 pub struct BatchQueue<P: QueueProcessing> {
@@ -34,7 +49,7 @@ pub struct BatchQueue<P: QueueProcessing> {
 struct BatchJob<P: QueueProcessing> {
 req: P::Req,
-res: tokio::sync::oneshot::Sender<P::Res>,
+res: tokio::sync::oneshot::Sender<Result<P::Res, P::Err>>,
 }
 impl<P: QueueProcessing> BatchQueue<P> {
@@ -55,11 +70,11 @@ impl<P: QueueProcessing> BatchQueue<P> {
 &self,
 req: P::Req,
 cancelled: impl Future<Output = R>,
-) -> Result<P::Res, R> {
+) -> Result<P::Res, BatchQueueError<P::Err, R>> {
 let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
 let mut cancelled = pin!(cancelled);
-let resp = loop {
+let resp: Option<Result<P::Res, P::Err>> = loop {
 // try become the leader, or try wait for success.
 let mut processor = tokio::select! {
 // try become leader.
@@ -72,7 +87,7 @@ impl<P: QueueProcessing> BatchQueue<P> {
 if inner.queue.remove(&id).is_some() {
 tracing::warn!("batched task cancelled before completion");
 }
-return Err(cancel);
+return Err(BatchQueueError::Cancelled(cancel));
 },
 };
@@ -96,18 +111,30 @@ impl<P: QueueProcessing> BatchQueue<P> {
 // good: we didn't get cancelled.
 ScopeGuard::into_inner(cancel_safety);
-if values.len() != resps.len() {
-tracing::error!(
-"batch: invalid response size, expected={}, got={}",
-resps.len(),
-values.len()
-);
-}
-// send response values.
-for (tx, value) in std::iter::zip(resps, values) {
-if tx.send(value).is_err() {
-// receiver hung up but that's fine.
-}
-}
+match values {
+Ok(values) => {
+if values.len() != resps.len() {
+tracing::error!(
+"batch: invalid response size, expected={}, got={}",
+resps.len(),
+values.len()
+);
+}
+// send response values.
+for (tx, value) in std::iter::zip(resps, values) {
+if tx.send(Ok(value)).is_err() {
+// receiver hung up but that's fine.
+}
+}
+}
+Err(err) => {
+for tx in resps {
+if tx.send(Err(err.clone())).is_err() {
+// receiver hung up but that's fine.
+}
+}
+}
+}
 }
@@ -129,7 +156,8 @@ impl<P: QueueProcessing> BatchQueue<P> {
 tracing::debug!(id, "batch: job completed");
-Ok(resp.expect("no response found. batch processer should not panic"))
+resp.expect("no response found. batch processer should not panic")
+.map_err(BatchQueueError::Result)
 }
 }
@@ -139,8 +167,8 @@ struct BatchQueueInner<P: QueueProcessing> {
 }
 impl<P: QueueProcessing> BatchQueueInner<P> {
-fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver<P::Res>) {
-let (tx, rx) = tokio::sync::oneshot::channel();
+fn register_job(&mut self, req: P::Req) -> (u64, oneshot::Receiver<ProcResult<P>>) {
+let (tx, rx) = oneshot::channel();
 let id = self.version;
@@ -158,7 +186,7 @@ impl<P: QueueProcessing> BatchQueueInner<P> {
 (id, rx)
 }
-fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<tokio::sync::oneshot::Sender<P::Res>>) {
+fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<oneshot::Sender<ProcResult<P>>>) {
 let batch_size = p.batch_size(self.queue.len());
 let mut reqs = Vec::with_capacity(batch_size);
 let mut resps = Vec::with_capacity(batch_size);
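For orientation, a self-contained sketch of what an implementor of the revised trait looks like after this change. The trait is re-declared here in simplified form so the example compiles on its own, the `Doubler` processor is purely illustrative, and the `futures` crate is assumed for the blocking executor:

```rust
use std::future::Future;

// Simplified re-declaration of the trait as it reads after this diff: `apply`
// now returns one Result for the whole batch, and the queue clones the single
// error to every waiter (hence `Err: Clone`).
pub trait QueueProcessing: Send + 'static {
    type Req: Send + 'static;
    type Res: Send;
    type Err: Send + Clone;

    fn batch_size(&self, queue_size: usize) -> usize;
    fn apply(
        &mut self,
        req: Vec<Self::Req>,
    ) -> impl Future<Output = Result<Vec<Self::Res>, Self::Err>> + Send;
}

#[derive(Clone, Debug)]
struct BatchError(&'static str);

// Illustrative processor: doubles every request, fails the whole batch on 0.
struct Doubler;

impl QueueProcessing for Doubler {
    type Req = u64;
    type Res = u64;
    type Err = BatchError;

    fn batch_size(&self, queue_size: usize) -> usize {
        queue_size.min(16)
    }

    fn apply(
        &mut self,
        req: Vec<u64>,
    ) -> impl Future<Output = Result<Vec<u64>, BatchError>> + Send {
        async move {
            if req.contains(&0) {
                // One bad request poisons the batch; all callers see a clone of this.
                return Err(BatchError("zero not allowed"));
            }
            Ok(req.into_iter().map(|x| x * 2).collect())
        }
    }
}

fn main() {
    let mut p = Doubler;
    let res = futures::executor::block_on(p.apply(vec![1, 2, 3]));
    assert_eq!(res.unwrap(), vec![2, 4, 6]);
}
```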

@@ -522,15 +522,7 @@ pub async fn run() -> anyhow::Result<()> {
 maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
 }
-if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend
-&& let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api
-&& let Some(client) = redis_client
-{
-// project info cache and invalidation of that cache.
-let cache = api.caches.project_info.clone();
-maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
-maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
+if let Some(client) = redis_client {
 // Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
 // This prevents immediate exit and pod restart,
 // which can cause hammering of the redis in case of connection issues.
@@ -560,6 +552,16 @@ pub async fn run() -> anyhow::Result<()> {
 }
 }
 }
+#[allow(irrefutable_let_patterns)]
+if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend
+&& let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api
+{
+// project info cache and invalidation of that cache.
+let cache = api.caches.project_info.clone();
+maintenance_tasks.spawn(notifications::task_main(client, cache.clone()));
+maintenance_tasks.spawn(async move { cache.gc_worker().await });
+}
 }
 let maintenance = loop {

@@ -4,12 +4,11 @@ use std::pin::pin;
 use std::sync::Mutex;
-use anyhow::anyhow;
 use futures::FutureExt;
 use ipnet::{IpNet, Ipv4Net, Ipv6Net};
 use postgres_client::RawCancelToken;
 use postgres_client::tls::MakeTlsConnect;
-use redis::{Cmd, FromRedisValue, Value};
+use redis::{Cmd, FromRedisValue, SetExpiry, SetOptions, Value};
 use serde::{Deserialize, Serialize};
 use thiserror::Error;
 use tokio::net::TcpStream;
@@ -18,7 +17,7 @@ use tracing::{debug, error, info};
 use crate::auth::AuthError;
 use crate::auth::backend::ComputeUserInfo;
-use crate::batch::{BatchQueue, QueueProcessing};
+use crate::batch::{BatchQueue, BatchQueueError, QueueProcessing};
 use crate::config::ComputeConfig;
 use crate::context::RequestContext;
 use crate::control_plane::ControlPlaneApi;
@@ -28,23 +27,39 @@ use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, Redis
 use crate::pqproto::CancelKeyData;
 use crate::rate_limiter::LeakyBucketRateLimiter;
 use crate::redis::keys::KeyPrefix;
-use crate::redis::kv_ops::RedisKVClient;
+use crate::redis::kv_ops::{RedisKVClient, RedisKVClientError};
+use crate::util::run_until;
 type IpSubnetKey = IpNet;
-const CANCEL_KEY_TTL: std::time::Duration = std::time::Duration::from_secs(600);
-const CANCEL_KEY_REFRESH: std::time::Duration = std::time::Duration::from_secs(570);
+const CANCEL_KEY_TTL: Duration = Duration::from_secs(600);
+const CANCEL_KEY_REFRESH: Duration = Duration::from_secs(570);
 // Message types for sending through mpsc channel
 pub enum CancelKeyOp {
-StoreCancelKey {
+Store {
 key: CancelKeyData,
 value: Box<str>,
-expire: std::time::Duration,
+expire: Duration,
 },
-GetCancelData {
+Refresh {
+key: CancelKeyData,
+expire: Duration,
+},
+Get {
 key: CancelKeyData,
 },
+GetOld {
+key: CancelKeyData,
+},
 }
+#[derive(thiserror::Error, Debug, Clone)]
+pub enum PipelineError {
+#[error("could not send cmd to redis: {0}")]
+RedisKVClient(Arc<RedisKVClientError>),
+#[error("incorrect number of responses from redis")]
+IncorrectNumberOfResponses,
+}
 pub struct Pipeline {
@@ -60,7 +75,7 @@ impl Pipeline {
 }
 }
-async fn execute(self, client: &mut RedisKVClient) -> Vec<anyhow::Result<Value>> {
+async fn execute(self, client: &mut RedisKVClient) -> Result<Vec<Value>, PipelineError> {
 let responses = self.replies;
 let batch_size = self.inner.len();
@@ -78,43 +93,44 @@ impl Pipeline {
 batch_size,
 responses, "successfully completed cancellation jobs",
 );
-values.into_iter().map(Ok).collect()
+Ok(values.into_iter().collect())
 }
 Ok(value) => {
 error!(batch_size, ?value, "unexpected redis return value");
-std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
-.take(responses)
-.collect()
-}
-Err(err) => {
-std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
-.take(responses)
-.collect()
+Err(PipelineError::IncorrectNumberOfResponses)
 }
+Err(err) => Err(PipelineError::RedisKVClient(Arc::new(err))),
 }
 }
-fn add_command_with_reply(&mut self, cmd: Cmd) {
+fn add_command(&mut self, cmd: Cmd) {
 self.inner.add_command(cmd);
 self.replies += 1;
 }
-fn add_command_no_reply(&mut self, cmd: Cmd) {
-self.inner.add_command(cmd).ignore();
-}
 }
 impl CancelKeyOp {
 fn register(&self, pipe: &mut Pipeline) {
 match self {
-CancelKeyOp::StoreCancelKey { key, value, expire } => {
+CancelKeyOp::Store { key, value, expire } => {
 let key = KeyPrefix::Cancel(*key).build_redis_key();
-pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
-pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
+pipe.add_command(Cmd::set_options(
+&key,
+&**value,
+SetOptions::default().with_expiration(SetExpiry::EX(expire.as_secs())),
+));
 }
-CancelKeyOp::GetCancelData { key } => {
+CancelKeyOp::Refresh { key, expire } => {
 let key = KeyPrefix::Cancel(*key).build_redis_key();
-pipe.add_command_with_reply(Cmd::hget(key, "data"));
+pipe.add_command(Cmd::expire(&key, expire.as_secs() as i64));
+}
+CancelKeyOp::GetOld { key } => {
+let key = KeyPrefix::Cancel(*key).build_redis_key();
+pipe.add_command(Cmd::hget(key, "data"));
+}
+CancelKeyOp::Get { key } => {
+let key = KeyPrefix::Cancel(*key).build_redis_key();
+pipe.add_command(Cmd::get(key));
 }
 }
 }
@@ -127,13 +143,14 @@ pub struct CancellationProcessor {
 impl QueueProcessing for CancellationProcessor {
 type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
-type Res = anyhow::Result<redis::Value>;
+type Res = redis::Value;
+type Err = PipelineError;
 fn batch_size(&self, _queue_size: usize) -> usize {
 self.batch_size
 }
-async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
+async fn apply(&mut self, batch: Vec<Self::Req>) -> Result<Vec<Self::Res>, Self::Err> {
 if !self.client.credentials_refreshed() {
 // this will cause a timeout for cancellation operations
 tracing::debug!(
@@ -244,18 +261,18 @@ impl CancellationHandler {
 &self,
 key: CancelKeyData,
 ) -> Result<Option<CancelClosure>, CancelError> {
-let guard = Metrics::get()
-.proxy
-.cancel_channel_size
-.guard(RedisMsgKind::HGet);
-let op = CancelKeyOp::GetCancelData { key };
+const TIMEOUT: Duration = Duration::from_secs(5);
 let Some(tx) = self.tx.get() else {
 tracing::warn!("cancellation handler is not available");
 return Err(CancelError::InternalError);
 };
-const TIMEOUT: Duration = Duration::from_secs(5);
+let guard = Metrics::get()
+.proxy
+.cancel_channel_size
+.guard(RedisMsgKind::Get);
+let op = CancelKeyOp::Get { key };
 let result = timeout(
 TIMEOUT,
 tx.call((guard, op), std::future::pending::<Infallible>()),
@@ -264,10 +281,37 @@ impl CancellationHandler {
 .map_err(|_| {
 tracing::warn!("timed out waiting to receive GetCancelData response");
 CancelError::RateLimit
-})?
-// cannot be cancelled
-.unwrap_or_else(|x| match x {})
-.map_err(|e| {
+})?;
+// We may still have cancel keys set with HSET <key> "data".
+// Check error type and retry with HGET.
+// TODO: remove code after HSET is not used anymore.
+let result = if let Err(err) = result.as_ref()
+&& let BatchQueueError::Result(err) = err
+&& let PipelineError::RedisKVClient(err) = err
+&& let RedisKVClientError::Redis(err) = &**err
+&& let Some(errcode) = err.code()
+&& errcode == "WRONGTYPE"
+{
+let guard = Metrics::get()
+.proxy
+.cancel_channel_size
+.guard(RedisMsgKind::HGet);
+let op = CancelKeyOp::GetOld { key };
+timeout(
+TIMEOUT,
+tx.call((guard, op), std::future::pending::<Infallible>()),
+)
+.await
+.map_err(|_| {
+tracing::warn!("timed out waiting to receive GetCancelData response");
+CancelError::RateLimit
+})?
+} else {
+result
+};
+let result = result.map_err(|e| {
 tracing::warn!("failed to receive GetCancelData response: {e}");
 CancelError::InternalError
 })?;
@@ -438,39 +482,94 @@ impl Session {
 let mut cancel = pin!(cancel);
+enum State {
+Set,
+Refresh,
+}
+let mut state = State::Set;
 loop {
-let guard = Metrics::get()
-.proxy
-.cancel_channel_size
-.guard(RedisMsgKind::HSet);
-let op = CancelKeyOp::StoreCancelKey {
-key: self.key,
-value: closure_json.clone(),
-expire: CANCEL_KEY_TTL,
-};
-tracing::debug!(
-src=%self.key,
-dest=?cancel_closure.cancel_token,
-"registering cancellation key"
-);
-match tx.call((guard, op), cancel.as_mut()).await {
-Ok(Ok(_)) => {
-tracing::debug!(
-src=%self.key,
-dest=?cancel_closure.cancel_token,
-"registered cancellation key"
-);
-// wait before continuing.
-tokio::time::sleep(CANCEL_KEY_REFRESH).await;
-}
-// retry immediately.
-Ok(Err(error)) => {
-tracing::warn!(?error, "error registering cancellation key");
-}
-Err(Err(_cancelled)) => break,
-}
+let guard_op = match state {
+State::Set => {
+let guard = Metrics::get()
+.proxy
+.cancel_channel_size
+.guard(RedisMsgKind::Set);
+let op = CancelKeyOp::Store {
+key: self.key,
+value: closure_json.clone(),
+expire: CANCEL_KEY_TTL,
+};
+tracing::debug!(
+src=%self.key,
+dest=?cancel_closure.cancel_token,
+"registering cancellation key"
+);
+(guard, op)
+}
+State::Refresh => {
+let guard = Metrics::get()
+.proxy
+.cancel_channel_size
+.guard(RedisMsgKind::Expire);
+let op = CancelKeyOp::Refresh {
+key: self.key,
+expire: CANCEL_KEY_TTL,
+};
+tracing::debug!(
+src=%self.key,
+dest=?cancel_closure.cancel_token,
+"refreshing cancellation key"
+);
+(guard, op)
+}
+};
+match tx.call(guard_op, cancel.as_mut()).await {
+// SET returns OK
+Ok(Value::Okay) => {
+tracing::debug!(
+src=%self.key,
+dest=?cancel_closure.cancel_token,
+"registered cancellation key"
+);
+state = State::Refresh;
+}
+// EXPIRE returns 1
+Ok(Value::Int(1)) => {
+tracing::debug!(
+src=%self.key,
+dest=?cancel_closure.cancel_token,
+"refreshed cancellation key"
+);
+}
+Ok(_) => {
+// Any other response likely means the key expired.
+tracing::warn!(src=%self.key, "refreshing cancellation key failed");
+// Re-enter the SET loop to repush full data.
+state = State::Set;
+}
+// retry immediately.
+Err(BatchQueueError::Result(error)) => {
+tracing::warn!(?error, "error refreshing cancellation key");
+// Small delay to prevent busy loop with high cpu and logging.
+tokio::time::sleep(Duration::from_millis(10)).await;
+continue;
+}
+Err(BatchQueueError::Cancelled(Err(_cancelled))) => break,
+}
+// wait before continuing. break immediately if cancelled.
+if run_until(tokio::time::sleep(CANCEL_KEY_REFRESH), cancel.as_mut())
+.await
+.is_err()
+{
+break;
+}
 }

@@ -374,11 +374,10 @@ pub enum Waiting {
 #[label(singleton = "kind")]
 #[allow(clippy::enum_variant_names)]
 pub enum RedisMsgKind {
-HSet,
-HSetMultiple,
+Set,
+Get,
+Expire,
 HGet,
-HGetAll,
-HDel,
 }
 #[derive(Default, Clone)]

@@ -4,11 +4,12 @@ use std::time::Duration;
 use futures::FutureExt;
 use redis::aio::{ConnectionLike, MultiplexedConnection};
-use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
+use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisError, RedisResult};
 use tokio::task::AbortHandle;
 use tracing::{error, info, warn};
 use super::elasticache::CredentialsProvider;
+use crate::redis::elasticache::CredentialsProviderError;
 enum Credentials {
 Static(ConnectionInfo),
@@ -26,6 +27,14 @@ impl Clone for Credentials {
 }
 }
+#[derive(thiserror::Error, Debug)]
+pub enum ConnectionProviderError {
+#[error(transparent)]
+Redis(#[from] RedisError),
+#[error(transparent)]
+CredentialsProvider(#[from] CredentialsProviderError),
+}
 /// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token.
 /// Provides PubSub connection without credentials refresh.
 pub struct ConnectionWithCredentialsProvider {
@@ -86,15 +95,18 @@ impl ConnectionWithCredentialsProvider {
 }
 }
-async fn ping(con: &mut MultiplexedConnection) -> RedisResult<()> {
-redis::cmd("PING").query_async(con).await
+async fn ping(con: &mut MultiplexedConnection) -> Result<(), ConnectionProviderError> {
+redis::cmd("PING")
+.query_async(con)
+.await
+.map_err(Into::into)
 }
 pub(crate) fn credentials_refreshed(&self) -> bool {
 self.credentials_refreshed.load(Ordering::Relaxed)
 }
-pub(crate) async fn connect(&mut self) -> anyhow::Result<()> {
+pub(crate) async fn connect(&mut self) -> Result<(), ConnectionProviderError> {
 let _guard = self.mutex.lock().await;
 if let Some(con) = self.con.as_mut() {
 match Self::ping(con).await {
@@ -141,7 +153,7 @@ impl ConnectionWithCredentialsProvider {
 Ok(())
 }
-async fn get_connection_info(&self) -> anyhow::Result<ConnectionInfo> {
+async fn get_connection_info(&self) -> Result<ConnectionInfo, ConnectionProviderError> {
 match &self.credentials {
 Credentials::Static(info) => Ok(info.clone()),
 Credentials::Dynamic(provider, addr) => {
@@ -160,7 +172,7 @@ impl ConnectionWithCredentialsProvider {
 }
 }
-async fn get_client(&self) -> anyhow::Result<redis::Client> {
+async fn get_client(&self) -> Result<redis::Client, ConnectionProviderError> {
 let client = redis::Client::open(self.get_connection_info().await?)?;
 self.credentials_refreshed.store(true, Ordering::Relaxed);
 Ok(client)

View File

@@ -9,10 +9,12 @@ use aws_config::meta::region::RegionProviderChain;
use aws_config::profile::ProfileFileCredentialsProvider; use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::provider_config::ProviderConfig; use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_credential_types::provider::error::CredentialsError;
use aws_sdk_iam::config::ProvideCredentials; use aws_sdk_iam::config::ProvideCredentials;
use aws_sigv4::http_request::{ use aws_sigv4::http_request::{
self, SignableBody, SignableRequest, SignatureLocation, SigningSettings, self, SignableBody, SignableRequest, SignatureLocation, SigningError, SigningSettings,
}; };
use aws_sigv4::sign::v4::signing_params::BuildError;
use tracing::info; use tracing::info;
#[derive(Debug)] #[derive(Debug)]
@@ -40,6 +42,18 @@ impl AWSIRSAConfig {
} }
} }
#[derive(thiserror::Error, Debug)]
pub enum CredentialsProviderError {
#[error(transparent)]
AwsCredentials(#[from] CredentialsError),
#[error(transparent)]
AwsSigv4Build(#[from] BuildError),
#[error(transparent)]
AwsSigv4Signing(#[from] SigningError),
#[error(transparent)]
Http(#[from] http::Error),
}
/// Credentials provider for AWS elasticache authentication. /// Credentials provider for AWS elasticache authentication.
/// ///
/// Official documentation: /// Official documentation:
@@ -92,7 +106,9 @@ impl CredentialsProvider {
}) })
} }
pub(crate) async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { pub(crate) async fn provide_credentials(
&self,
) -> Result<(String, String), CredentialsProviderError> {
let aws_credentials = self let aws_credentials = self
.credentials_provider .credentials_provider
.provide_credentials() .provide_credentials()

View File

@@ -2,9 +2,18 @@ use std::time::Duration;
use futures::FutureExt; use futures::FutureExt;
use redis::aio::ConnectionLike; use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult}; use redis::{Cmd, FromRedisValue, Pipeline, RedisError, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::redis::connection_with_credentials_provider::ConnectionProviderError;
#[derive(thiserror::Error, Debug)]
pub enum RedisKVClientError {
#[error(transparent)]
Redis(#[from] RedisError),
#[error(transparent)]
ConnectionProvider(#[from] ConnectionProviderError),
}
pub struct RedisKVClient { pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider, client: ConnectionWithCredentialsProvider,
@@ -32,12 +41,13 @@ impl RedisKVClient {
Self { client } Self { client }
} }
pub async fn try_connect(&mut self) -> anyhow::Result<()> { pub async fn try_connect(&mut self) -> Result<(), RedisKVClientError> {
self.client self.client
.connect() .connect()
.boxed() .boxed()
.await .await
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}")) .inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
.map_err(Into::into)
} }
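A hedged usage sketch of the typed client, assuming `Queryable` is implemented for `redis::Cmd` as the imports above suggest (the helper and key name are illustrative only):

    async fn read_key(
        client: &mut RedisKVClient,
    ) -> Result<Option<String>, RedisKVClientError> {
        client.try_connect().await?;
        // Both RedisError and ConnectionProviderError convert via #[from].
        client.query(&redis::Cmd::get("proxy:example_key")).await
    }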
pub(crate) fn credentials_refreshed(&self) -> bool { pub(crate) fn credentials_refreshed(&self) -> bool {
@@ -47,7 +57,7 @@ impl RedisKVClient {
pub(crate) async fn query<T: FromRedisValue>( pub(crate) async fn query<T: FromRedisValue>(
&mut self, &mut self,
q: &impl Queryable, q: &impl Queryable,
) -> anyhow::Result<T> { ) -> Result<T, RedisKVClientError> {
let e = match q.query(&mut self.client).await { let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t), Ok(t) => return Ok(t),
Err(e) => e, Err(e) => e,

View File

@@ -1,6 +1,7 @@
use json::{ListSer, ObjectSer, ValueSer};
use postgres_client::Row; use postgres_client::Row;
use postgres_client::types::{Kind, Type}; use postgres_client::types::{Kind, Type};
use serde_json::{Map, Value}; use serde_json::Value;
// //
// Convert json non-string types to strings, so that they can be passed to Postgres // Convert json non-string types to strings, so that they can be passed to Postgres
@@ -74,44 +75,40 @@ pub(crate) enum JsonConversionError {
UnbalancedString, UnbalancedString,
} }
enum OutputMode { enum OutputMode<'a> {
Array(Vec<Value>), Array(ListSer<'a>),
Object(Map<String, Value>), Object(ObjectSer<'a>),
} }
impl OutputMode { impl OutputMode<'_> {
fn key(&mut self, key: &str) -> &mut Value { fn key(&mut self, key: &str) -> ValueSer<'_> {
match self { match self {
OutputMode::Array(values) => push_entry(values, Value::Null), OutputMode::Array(values) => values.entry(),
OutputMode::Object(map) => map.entry(key.to_string()).or_insert(Value::Null), OutputMode::Object(map) => map.key(key),
} }
} }
fn finish(self) -> Value { fn finish(self) {
match self { match self {
OutputMode::Array(values) => Value::Array(values), OutputMode::Array(values) => values.finish(),
OutputMode::Object(map) => Value::Object(map), OutputMode::Object(map) => map.finish(),
} }
} }
} }
fn push_entry<T>(arr: &mut Vec<T>, t: T) -> &mut T {
arr.push(t);
arr.last_mut().expect("a value was just inserted")
}
// //
// Convert postgres row with text-encoded values to JSON object // Convert postgres row with text-encoded values to JSON object
// //
pub(crate) fn pg_text_row_to_json( pub(crate) fn pg_text_row_to_json(
output: ValueSer,
row: &Row, row: &Row,
raw_output: bool, raw_output: bool,
array_mode: bool, array_mode: bool,
) -> Result<Value, JsonConversionError> { ) -> Result<(), JsonConversionError> {
let mut entries = if array_mode { let mut entries = if array_mode {
OutputMode::Array(Vec::with_capacity(row.columns().len())) OutputMode::Array(output.list())
} else { } else {
OutputMode::Object(Map::with_capacity(row.columns().len())) OutputMode::Object(output.object())
}; };
for (i, column) in row.columns().iter().enumerate() { for (i, column) in row.columns().iter().enumerate() {
@@ -120,53 +117,48 @@ pub(crate) fn pg_text_row_to_json(
let value = entries.key(column.name()); let value = entries.key(column.name());
match pg_value { match pg_value {
Some(v) if raw_output => *value = Value::String(v.to_string()), Some(v) if raw_output => value.value(v),
Some(v) => pg_text_to_json(value, v, column.type_())?, Some(v) => pg_text_to_json(value, v, column.type_())?,
None => *value = Value::Null, None => value.value(json::Null),
} }
} }
Ok(entries.finish()) entries.finish();
Ok(())
} }
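For orientation, a minimal sketch of the streaming `json` writer pattern adopted here, limited to names that appear in this diff (`value_to_string!`, `value_as_object!`, `value_as_list!`, `key`, `entry`, `value`); it writes `{"results":[1,2,3]}` without building an intermediate `serde_json::Value`:

    fn sketch() -> String {
        json::value_to_string!(|root| json::value_as_object!(|root| {
            let results = root.key("results");
            json::value_as_list!(|results| {
                for i in 1..=3 {
                    // each entry() yields a ValueSer for one list element
                    results.entry().value(i);
                }
            });
        }))
    }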
// //
// Convert postgres text-encoded value to JSON value // Convert postgres text-encoded value to JSON value
// //
fn pg_text_to_json( fn pg_text_to_json(output: ValueSer, val: &str, pg_type: &Type) -> Result<(), JsonConversionError> {
output: &mut Value,
val: &str,
pg_type: &Type,
) -> Result<(), JsonConversionError> {
if let Kind::Array(elem_type) = pg_type.kind() { if let Kind::Array(elem_type) = pg_type.kind() {
// todo: we should fetch this from postgres. // todo: we should fetch this from postgres.
let delimiter = ','; let delimiter = ',';
let mut array = vec![]; json::value_as_list!(|output| pg_array_parse(output, val, elem_type, delimiter)?);
pg_array_parse(&mut array, val, elem_type, delimiter)?;
*output = Value::Array(array);
return Ok(()); return Ok(());
} }
match *pg_type { match *pg_type {
Type::BOOL => *output = Value::Bool(val == "t"), Type::BOOL => output.value(val == "t"),
Type::INT2 | Type::INT4 => { Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?; let val = val.parse::<i32>()?;
*output = Value::Number(serde_json::Number::from(val)); output.value(val);
} }
Type::FLOAT4 | Type::FLOAT8 => { Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?; let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval); if fval.is_finite() {
if let Some(num) = num { output.value(fval);
*output = Value::Number(num);
} else { } else {
// Pass Nan, Inf, -Inf as strings // Pass Nan, Inf, -Inf as strings
// JS JSON.stringify() does converts them to null, but we // JS JSON.stringify() does converts them to null, but we
// want to preserve them, so we pass them as strings // want to preserve them, so we pass them as strings
*output = Value::String(val.to_string()); output.value(val);
} }
} }
Type::JSON | Type::JSONB => *output = serde_json::from_str(val)?, // we assume that the string value is valid json.
_ => *output = Value::String(val.to_string()), Type::JSON | Type::JSONB => output.write_raw_json(val.as_bytes()),
_ => output.value(val),
} }
Ok(()) Ok(())
@@ -192,7 +184,7 @@ fn pg_text_to_json(
/// gets its own level of curly braces, and delimiters must be written between adjacent /// gets its own level of curly braces, and delimiters must be written between adjacent
/// curly-braced entities of the same level. /// curly-braced entities of the same level.
fn pg_array_parse( fn pg_array_parse(
elements: &mut Vec<Value>, elements: &mut ListSer,
mut pg_array: &str, mut pg_array: &str,
elem: &Type, elem: &Type,
delim: char, delim: char,
@@ -221,7 +213,7 @@ fn pg_array_parse(
/// reads a single array from the `pg_array` string and pushes each values to `elements`. /// reads a single array from the `pg_array` string and pushes each values to `elements`.
/// returns the rest of the `pg_array` string that was not read. /// returns the rest of the `pg_array` string that was not read.
fn pg_array_parse_inner<'a>( fn pg_array_parse_inner<'a>(
elements: &mut Vec<Value>, elements: &mut ListSer,
mut pg_array: &'a str, mut pg_array: &'a str,
elem: &Type, elem: &Type,
delim: char, delim: char,
@@ -234,7 +226,7 @@ fn pg_array_parse_inner<'a>(
let mut q = String::new(); let mut q = String::new();
loop { loop {
let value = push_entry(elements, Value::Null); let value = elements.entry();
pg_array = pg_array_parse_item(value, &mut q, pg_array, elem, delim)?; pg_array = pg_array_parse_item(value, &mut q, pg_array, elem, delim)?;
// check for separator. // check for separator.
@@ -260,7 +252,7 @@ fn pg_array_parse_inner<'a>(
/// ///
/// `quoted` is a scratch allocation that has no defined output. /// `quoted` is a scratch allocation that has no defined output.
fn pg_array_parse_item<'a>( fn pg_array_parse_item<'a>(
output: &mut Value, output: ValueSer,
quoted: &mut String, quoted: &mut String,
mut pg_array: &'a str, mut pg_array: &'a str,
elem: &Type, elem: &Type,
@@ -276,9 +268,8 @@ fn pg_array_parse_item<'a>(
if pg_array.starts_with('{') { if pg_array.starts_with('{') {
// nested array. // nested array.
let mut nested = vec![]; pg_array =
pg_array = pg_array_parse_inner(&mut nested, pg_array, elem, delim)?; json::value_as_list!(|output| pg_array_parse_inner(output, pg_array, elem, delim))?;
*output = Value::Array(nested);
return Ok(pg_array); return Ok(pg_array);
} }
@@ -306,7 +297,7 @@ fn pg_array_parse_item<'a>(
// we might have an item string: // we might have an item string:
// check for null // check for null
if item == "NULL" { if item == "NULL" {
*output = Value::Null; output.value(json::Null);
} else { } else {
pg_text_to_json(output, item, elem)?; pg_text_to_json(output, item, elem)?;
} }
@@ -440,15 +431,15 @@ mod tests {
} }
fn pg_text_to_json(val: &str, pg_type: &Type) -> Value { fn pg_text_to_json(val: &str, pg_type: &Type) -> Value {
let mut v = Value::Null; let output = json::value_to_string!(|v| super::pg_text_to_json(v, val, pg_type).unwrap());
super::pg_text_to_json(&mut v, val, pg_type).unwrap(); serde_json::from_str(&output).unwrap()
v
} }
fn pg_array_parse(pg_array: &str, pg_type: &Type) -> Value { fn pg_array_parse(pg_array: &str, pg_type: &Type) -> Value {
let mut array = vec![]; let output = json::value_to_string!(|v| json::value_as_list!(|v| {
super::pg_array_parse(&mut array, pg_array, pg_type, ',').unwrap(); super::pg_array_parse(v, pg_array, pg_type, ',').unwrap();
Value::Array(array) }));
serde_json::from_str(&output).unwrap()
} }
#[test] #[test]

View File

@@ -14,10 +14,7 @@ use hyper::http::{HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode, header}; use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap; use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState}; use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{ use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use serde::Serialize;
use serde_json::Value; use serde_json::Value;
use serde_json::value::RawValue; use serde_json::value::RawValue;
use tokio::time::{self, Instant}; use tokio::time::{self, Instant};
@@ -687,32 +684,21 @@ impl QueryData {
let (inner, mut discard) = client.inner(); let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token(); let cancel_token = inner.cancel_token();
match select( let mut json_buf = vec![];
let batch_result = match select(
pin!(query_to_json( pin!(query_to_json(
config, config,
&mut *inner, &mut *inner,
self, self,
&mut 0, json::ValueSer::new(&mut json_buf),
parsed_headers parsed_headers
)), )),
pin!(cancel.cancelled()), pin!(cancel.cancelled()),
) )
.await .await
{ {
// The query successfully completed. Either::Left((res, __not_yet_cancelled)) => res,
Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
discard.check_idle(status);
let json_output =
serde_json::to_string(&results).expect("json serialization should not fail");
Ok(json_output)
}
// The query failed with an error
Either::Left((Err(e), __not_yet_cancelled)) => {
discard.discard();
Err(e)
}
// The query was cancelled.
Either::Right((_cancelled, query)) => { Either::Right((_cancelled, query)) => {
tracing::info!("cancelling query"); tracing::info!("cancelling query");
if let Err(err) = cancel_token.cancel_query(NoTls).await { if let Err(err) = cancel_token.cancel_query(NoTls).await {
@@ -721,13 +707,7 @@ impl QueryData {
// wait for the query cancellation // wait for the query cancellation
match time::timeout(time::Duration::from_millis(100), query).await { match time::timeout(time::Duration::from_millis(100), query).await {
// query succeeded before it was cancelled. // query succeeded before it was cancelled.
Ok(Ok((status, results))) => { Ok(Ok(status)) => Ok(status),
discard.check_idle(status);
let json_output = serde_json::to_string(&results)
.expect("json serialization should not fail");
Ok(json_output)
}
// query failed or was cancelled. // query failed or was cancelled.
Ok(Err(error)) => { Ok(Err(error)) => {
let db_error = match &error { let db_error = match &error {
@@ -743,14 +723,29 @@ impl QueryData {
discard.discard(); discard.discard();
} }
Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)) return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
} }
Err(_timeout) => { Err(_timeout) => {
discard.discard(); discard.discard();
Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)) return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
} }
} }
} }
};
match batch_result {
// The query successfully completed.
Ok(status) => {
discard.check_idle(status);
let json_output = String::from_utf8(json_buf).expect("json should be valid utf8");
Ok(json_output)
}
// The query failed with an error
Err(e) => {
discard.discard();
Err(e)
}
} }
} }
} }
@@ -787,7 +782,7 @@ impl BatchQueryData {
}) })
.map_err(SqlOverHttpError::Postgres)?; .map_err(SqlOverHttpError::Postgres)?;
let json_output = match query_batch( let json_output = match query_batch_to_json(
config, config,
cancel.child_token(), cancel.child_token(),
&mut transaction, &mut transaction,
@@ -845,24 +840,21 @@ async fn query_batch(
transaction: &mut Transaction<'_>, transaction: &mut Transaction<'_>,
queries: BatchQueryData, queries: BatchQueryData,
parsed_headers: HttpHeaders, parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> { results: &mut json::ListSer<'_>,
let mut results = Vec::with_capacity(queries.queries.len()); ) -> Result<(), SqlOverHttpError> {
let mut current_size = 0;
for stmt in queries.queries { for stmt in queries.queries {
let query = pin!(query_to_json( let query = pin!(query_to_json(
config, config,
transaction, transaction,
stmt, stmt,
&mut current_size, results.entry(),
parsed_headers, parsed_headers,
)); ));
let cancelled = pin!(cancel.cancelled()); let cancelled = pin!(cancel.cancelled());
let res = select(query, cancelled).await; let res = select(query, cancelled).await;
match res { match res {
// TODO: maybe we should check that the transaction bit is set here // TODO: maybe we should check that the transaction bit is set here
Either::Left((Ok((_, values)), _cancelled)) => { Either::Left((Ok(_), _cancelled)) => {}
results.push(values);
}
Either::Left((Err(e), _cancelled)) => { Either::Left((Err(e), _cancelled)) => {
return Err(e); return Err(e);
} }
@@ -872,8 +864,22 @@ async fn query_batch(
} }
} }
let results = json!({ "results": results }); Ok(())
let json_output = serde_json::to_string(&results).expect("json serialization should not fail"); }
async fn query_batch_to_json(
config: &'static HttpConfig,
cancel: CancellationToken,
tx: &mut Transaction<'_>,
queries: BatchQueryData,
headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let json_output = json::value_to_string!(|obj| json::value_as_object!(|obj| {
let results = obj.key("results");
json::value_as_list!(|results| {
query_batch(config, cancel, tx, queries, headers, results).await?;
});
}));
Ok(json_output) Ok(json_output)
} }
@@ -882,54 +888,54 @@ async fn query_to_json<T: GenericClient>(
config: &'static HttpConfig, config: &'static HttpConfig,
client: &mut T, client: &mut T,
data: QueryData, data: QueryData,
current_size: &mut usize, output: json::ValueSer<'_>,
parsed_headers: HttpHeaders, parsed_headers: HttpHeaders,
) -> Result<(ReadyForQueryStatus, impl Serialize + use<T>), SqlOverHttpError> { ) -> Result<ReadyForQueryStatus, SqlOverHttpError> {
let query_start = Instant::now(); let query_start = Instant::now();
let query_params = data.params; let mut output = json::ObjectSer::new(output);
let mut row_stream = client let mut row_stream = client
.query_raw_txt(&data.query, query_params) .query_raw_txt(&data.query, data.params)
.await .await
.map_err(SqlOverHttpError::Postgres)?; .map_err(SqlOverHttpError::Postgres)?;
let query_acknowledged = Instant::now(); let query_acknowledged = Instant::now();
let columns_len = row_stream.statement.columns().len(); let mut json_fields = output.key("fields").list();
let mut fields = Vec::with_capacity(columns_len);
for c in row_stream.statement.columns() { for c in row_stream.statement.columns() {
fields.push(json!({ let json_field = json_fields.entry();
"name": c.name().to_owned(), json::value_as_object!(|json_field| {
"dataTypeID": c.type_().oid(), json_field.entry("name", c.name());
"tableID": c.table_oid(), json_field.entry("dataTypeID", c.type_().oid());
"columnID": c.column_id(), json_field.entry("tableID", c.table_oid());
"dataTypeSize": c.type_size(), json_field.entry("columnID", c.column_id());
"dataTypeModifier": c.type_modifier(), json_field.entry("dataTypeSize", c.type_size());
"format": "text", json_field.entry("dataTypeModifier", c.type_modifier());
})); json_field.entry("format", "text");
});
} }
json_fields.finish();
let raw_output = parsed_headers.raw_output;
let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode); let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
let raw_output = parsed_headers.raw_output;
// Manually drain the stream into a vector to leave row_stream hanging // Manually drain the stream into a vector to leave row_stream hanging
// around to get a command tag. Also check that the response is not too // around to get a command tag. Also check that the response is not too
// big. // big.
let mut rows = Vec::new(); let mut rows = 0;
let mut json_rows = output.key("rows").list();
while let Some(row) = row_stream.next().await { while let Some(row) = row_stream.next().await {
let row = row.map_err(SqlOverHttpError::Postgres)?; let row = row.map_err(SqlOverHttpError::Postgres)?;
*current_size += row.body_len();
// we don't have a streaming response support yet so this is to prevent OOM // we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join) // from a malicious query (eg a cross join)
if *current_size > config.max_response_size_bytes { if json_rows.as_buffer().len() > config.max_response_size_bytes {
return Err(SqlOverHttpError::ResponseTooLarge( return Err(SqlOverHttpError::ResponseTooLarge(
config.max_response_size_bytes, config.max_response_size_bytes,
)); ));
} }
let row = pg_text_row_to_json(&row, raw_output, array_mode)?; pg_text_row_to_json(json_rows.entry(), &row, raw_output, array_mode)?;
rows.push(row); rows += 1;
// assumption: parsing pg text and converting to json takes CPU time. // assumption: parsing pg text and converting to json takes CPU time.
// let's assume it is slightly expensive, so we should consume some cooperative budget. // let's assume it is slightly expensive, so we should consume some cooperative budget.
@@ -937,16 +943,14 @@ async fn query_to_json<T: GenericClient>(
// of rows and never hit the tokio mpsc for a long time (although unlikely). // of rows and never hit the tokio mpsc for a long time (although unlikely).
tokio::task::consume_budget().await; tokio::task::consume_budget().await;
} }
json_rows.finish();
let query_resp_end = Instant::now(); let query_resp_end = Instant::now();
let RowStream {
command_tag, let ready = row_stream.status;
status: ready,
..
} = row_stream;
// grab the command tag and number of rows affected // grab the command tag and number of rows affected
let command_tag = command_tag.unwrap_or_default(); let command_tag = row_stream.command_tag.unwrap_or_default();
let mut command_tag_split = command_tag.split(' '); let mut command_tag_split = command_tag.split(' ');
let command_tag_name = command_tag_split.next().unwrap_or_default(); let command_tag_name = command_tag_split.next().unwrap_or_default();
let command_tag_count = if command_tag_name == "INSERT" { let command_tag_count = if command_tag_name == "INSERT" {
@@ -959,7 +963,7 @@ async fn query_to_json<T: GenericClient>(
.and_then(|s| s.parse::<i64>().ok()); .and_then(|s| s.parse::<i64>().ok());
info!( info!(
rows = rows.len(), rows,
?ready, ?ready,
command_tag, command_tag,
acknowledgement = ?(query_acknowledged - query_start), acknowledgement = ?(query_acknowledged - query_start),
@@ -967,16 +971,12 @@ async fn query_to_json<T: GenericClient>(
"finished executing query" "finished executing query"
); );
// Resulting JSON format is based on the format of node-postgres result. output.entry("command", command_tag_name);
let results = json!({ output.entry("rowCount", command_tag_count);
"command": command_tag_name.to_string(), output.entry("rowAsArray", array_mode);
"rowCount": command_tag_count,
"rows": rows,
"fields": fields,
"rowAsArray": array_mode,
});
Ok((ready, results)) output.finish();
Ok(ready)
} }
enum Client { enum Client {

View File

@@ -7,8 +7,16 @@ pub async fn run_until_cancelled<F: Future>(
f: F, f: F,
cancellation_token: &CancellationToken, cancellation_token: &CancellationToken,
) -> Option<F::Output> { ) -> Option<F::Output> {
match select(pin!(f), pin!(cancellation_token.cancelled())).await { run_until(f, cancellation_token.cancelled()).await.ok()
Either::Left((f, _)) => Some(f), }
Either::Right(((), _)) => None,
/// Runs the future `f` unless interrupted by future `condition`.
pub async fn run_until<F1: Future, F2: Future>(
f: F1,
condition: F2,
) -> Result<F1::Output, F2::Output> {
match select(pin!(f), pin!(condition)).await {
Either::Left((f1, _)) => Ok(f1),
Either::Right((f2, _)) => Err(f2),
} }
} }
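A short usage sketch of the new helper (hypothetical caller; `CancellationToken` comes from `tokio_util::sync` as used above): `Ok` carries the output of `f`, `Err` carries the output of `condition`, so `run_until_cancelled` is simply `.ok()` on the result.

    async fn compute_or_cancel(token: &tokio_util::sync::CancellationToken) -> Option<u64> {
        match run_until(async { 6 * 7u64 }, token.cancelled()).await {
            Ok(v) => Some(v), // `f` completed first
            Err(()) => None,  // the token fired first
        }
    }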

View File

@@ -222,6 +222,9 @@ struct Cli {
/// Primarily useful for testing to reduce test execution time. /// Primarily useful for testing to reduce test execution time.
#[arg(long, default_value = "false", action=ArgAction::Set)] #[arg(long, default_value = "false", action=ArgAction::Set)]
kick_secondary_downloads: bool, kick_secondary_downloads: bool,
#[arg(long)]
shard_split_request_timeout: Option<humantime::Duration>,
} }
enum StrictMode { enum StrictMode {
@@ -470,6 +473,10 @@ async fn async_main() -> anyhow::Result<()> {
timeline_safekeeper_count: args.timeline_safekeeper_count, timeline_safekeeper_count: args.timeline_safekeeper_count,
posthog_config: posthog_config.clone(), posthog_config: posthog_config.clone(),
kick_secondary_downloads: args.kick_secondary_downloads, kick_secondary_downloads: args.kick_secondary_downloads,
shard_split_request_timeout: args
.shard_split_request_timeout
.map(humantime::Duration::into)
.unwrap_or(Duration::MAX),
}; };
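The new knob is an optional humantime duration; `None` falls back to `Duration::MAX`, i.e. effectively no timeout. A hedged sketch of the accepted value format (the flag name assumes clap's default kebab-case renaming; `parse_split_timeout` is illustrative only):

    // e.g. storage_controller ... --shard-split-request-timeout 5s
    fn parse_split_timeout(s: &str) -> std::time::Duration {
        s.parse::<humantime::Duration>()
            .expect("humantime syntax, e.g. \"5s\" or \"2m 30s\"")
            .into()
    }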
// Validate that we can connect to the database // Validate that we can connect to the database

View File

@@ -60,6 +60,7 @@ use tokio::sync::mpsc::error::TrySendError;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, instrument, warn}; use tracing::{Instrument, debug, error, info, info_span, instrument, warn};
use utils::completion::Barrier; use utils::completion::Barrier;
use utils::env;
use utils::generation::Generation; use utils::generation::Generation;
use utils::id::{NodeId, TenantId, TimelineId}; use utils::id::{NodeId, TenantId, TimelineId};
use utils::lsn::Lsn; use utils::lsn::Lsn;
@@ -483,6 +484,9 @@ pub struct Config {
/// When set, actively checks and initiates heatmap downloads/uploads. /// When set, actively checks and initiates heatmap downloads/uploads.
pub kick_secondary_downloads: bool, pub kick_secondary_downloads: bool,
/// Timeout used for the HTTP client issuing shard split requests. Falls back to [`Duration::MAX`] (no timeout) if unset.
pub shard_split_request_timeout: Duration,
} }
impl From<DatabaseError> for ApiError { impl From<DatabaseError> for ApiError {
@@ -5206,6 +5210,9 @@ impl Service {
match res { match res {
Ok(ok) => Ok(ok), Ok(ok) => Ok(ok),
Err(mgmt_api::Error::ApiError(StatusCode::CONFLICT, _)) => Ok(StatusCode::CONFLICT), Err(mgmt_api::Error::ApiError(StatusCode::CONFLICT, _)) => Ok(StatusCode::CONFLICT),
Err(mgmt_api::Error::ApiError(StatusCode::PRECONDITION_FAILED, msg)) if msg.contains("Requested tenant is missing") => {
Err(ApiError::ResourceUnavailable("Tenant migration in progress".into()))
},
Err(mgmt_api::Error::ApiError(StatusCode::SERVICE_UNAVAILABLE, msg)) => Err(ApiError::ResourceUnavailable(msg.into())), Err(mgmt_api::Error::ApiError(StatusCode::SERVICE_UNAVAILABLE, msg)) => Err(ApiError::ResourceUnavailable(msg.into())),
Err(e) => { Err(e) => {
Err( Err(
@@ -6403,18 +6410,39 @@ impl Service {
// TODO: issue split calls concurrently (this only matters once we're splitting // TODO: issue split calls concurrently (this only matters once we're splitting
// N>1 shards into M shards -- initially we're usually splitting 1 shard into N). // N>1 shards into M shards -- initially we're usually splitting 1 shard into N).
// HADRON: set a timeout for splitting individual shards on page servers.
// Currently we do not perform any retry because it is not clear whether the pageserver can handle
// partially split shards correctly.
let shard_split_timeout =
if let Some(env::DeploymentMode::Local) = env::get_deployment_mode() {
Duration::from_secs(30)
} else {
self.config.shard_split_request_timeout
};
let mut http_client_builder = reqwest::ClientBuilder::new()
.pool_max_idle_per_host(0)
.timeout(shard_split_timeout);
for ssl_ca_cert in &self.config.ssl_ca_certs {
http_client_builder = http_client_builder.add_root_certificate(ssl_ca_cert.clone());
}
let http_client = http_client_builder
.build()
.expect("Failed to construct HTTP client");
for target in &targets { for target in &targets {
let ShardSplitTarget { let ShardSplitTarget {
parent_id, parent_id,
node, node,
child_ids, child_ids,
} = target; } = target;
let client = PageserverClient::new( let client = PageserverClient::new(
node.get_id(), node.get_id(),
self.http_client.clone(), http_client.clone(),
node.base_url(), node.base_url(),
self.config.pageserver_jwt_token.as_deref(), self.config.pageserver_jwt_token.as_deref(),
); );
let response = client let response = client
.tenant_shard_split( .tenant_shard_split(
*parent_id, *parent_id,

View File

@@ -39,13 +39,13 @@ use utils::lsn::Lsn;
use super::Service; use super::Service;
impl Service { impl Service {
fn make_member_set(safekeepers: &[Safekeeper]) -> Result<MemberSet, ApiError> { fn make_member_set(safekeepers: &[Safekeeper]) -> Result<MemberSet, anyhow::Error> {
let members = safekeepers let members = safekeepers
.iter() .iter()
.map(|sk| sk.get_safekeeper_id()) .map(|sk| sk.get_safekeeper_id())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
MemberSet::new(members).map_err(ApiError::InternalServerError) MemberSet::new(members)
} }
fn get_safekeepers(&self, ids: &[i64]) -> Result<Vec<Safekeeper>, ApiError> { fn get_safekeepers(&self, ids: &[i64]) -> Result<Vec<Safekeeper>, ApiError> {
@@ -80,7 +80,7 @@ impl Service {
) -> Result<Vec<NodeId>, ApiError> { ) -> Result<Vec<NodeId>, ApiError> {
let safekeepers = self.get_safekeepers(&timeline_persistence.sk_set)?; let safekeepers = self.get_safekeepers(&timeline_persistence.sk_set)?;
let mset = Self::make_member_set(&safekeepers)?; let mset = Self::make_member_set(&safekeepers).map_err(ApiError::InternalServerError)?;
let mconf = safekeeper_api::membership::Configuration::new(mset); let mconf = safekeeper_api::membership::Configuration::new(mset);
let req = safekeeper_api::models::TimelineCreateRequest { let req = safekeeper_api::models::TimelineCreateRequest {
@@ -1105,6 +1105,26 @@ impl Service {
} }
} }
if new_sk_set.is_empty() {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"new safekeeper set is empty"
)));
}
if new_sk_set.len() < self.config.timeline_safekeeper_count {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"new safekeeper set must have at least {} safekeepers",
self.config.timeline_safekeeper_count
)));
}
let new_sk_set_i64 = new_sk_set.iter().map(|id| id.0 as i64).collect::<Vec<_>>();
let new_safekeepers = self.get_safekeepers(&new_sk_set_i64)?;
// Construct new member set in advance to validate it.
// E.g. validates that there is no duplicate safekeepers.
let new_sk_member_set =
Self::make_member_set(&new_safekeepers).map_err(ApiError::BadRequest)?;
// TODO(diko): per-tenant lock is too wide. Consider introducing per-timeline locks. // TODO(diko): per-tenant lock is too wide. Consider introducing per-timeline locks.
let _tenant_lock = trace_shared_lock( let _tenant_lock = trace_shared_lock(
&self.tenant_op_locks, &self.tenant_op_locks,
@@ -1135,6 +1155,18 @@ impl Service {
.map(|&id| NodeId(id as u64)) .map(|&id| NodeId(id as u64))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// Validate that we are not migrating to a decommissioned safekeeper.
for sk in new_safekeepers.iter() {
if !cur_sk_set.contains(&sk.get_id())
&& sk.scheduling_policy() == SkSchedulingPolicy::Decomissioned
{
return Err(ApiError::BadRequest(anyhow::anyhow!(
"safekeeper {} is decomissioned",
sk.get_id()
)));
}
}
tracing::info!( tracing::info!(
?cur_sk_set, ?cur_sk_set,
?new_sk_set, ?new_sk_set,
@@ -1177,11 +1209,8 @@ impl Service {
} }
let cur_safekeepers = self.get_safekeepers(&timeline.sk_set)?; let cur_safekeepers = self.get_safekeepers(&timeline.sk_set)?;
let cur_sk_member_set = Self::make_member_set(&cur_safekeepers)?; let cur_sk_member_set =
Self::make_member_set(&cur_safekeepers).map_err(ApiError::InternalServerError)?;
let new_sk_set_i64 = new_sk_set.iter().map(|id| id.0 as i64).collect::<Vec<_>>();
let new_safekeepers = self.get_safekeepers(&new_sk_set_i64)?;
let new_sk_member_set = Self::make_member_set(&new_safekeepers)?;
let joint_config = membership::Configuration { let joint_config = membership::Configuration {
generation, generation,

View File

@@ -5421,6 +5421,7 @@ SKIP_FILES = frozenset(
( (
"pg_internal.init", "pg_internal.init",
"pg.log", "pg.log",
"neon.signal",
"zenith.signal", "zenith.signal",
"pg_hba.conf", "pg_hba.conf",
"postgresql.conf", "postgresql.conf",

View File

@@ -115,8 +115,7 @@ DEFAULT_PAGESERVER_ALLOWED_ERRORS = (
".*Local data loss suspected.*", ".*Local data loss suspected.*",
# Too many frozen layers error is normal during intensive benchmarks # Too many frozen layers error is normal during intensive benchmarks
".*too many frozen layers.*", ".*too many frozen layers.*",
# Transient errors when resolving tenant shards by page service ".*Failed to resolve tenant shard after.*",
".*Fail to resolve tenant shard in attempt.*",
# Expected warnings when pageserver has not refreshed GC info yet # Expected warnings when pageserver has not refreshed GC info yet
".*pitr LSN/interval not found, skipping force image creation LSN calculation.*", ".*pitr LSN/interval not found, skipping force image creation LSN calculation.*",
".*No broker updates received for a while.*", ".*No broker updates received for a while.*",

View File

@@ -7,6 +7,7 @@ import time
from enum import StrEnum from enum import StrEnum
import pytest import pytest
from fixtures.common_types import TenantShardId
from fixtures.log_helper import log from fixtures.log_helper import log
from fixtures.neon_fixtures import ( from fixtures.neon_fixtures import (
NeonEnvBuilder, NeonEnvBuilder,
@@ -960,6 +961,67 @@ def get_layer_map(env, tenant_shard_id, timeline_id, ps_id):
return image_layer_count, delta_layer_count return image_layer_count, delta_layer_count
def test_image_layer_creation_time_threshold(neon_env_builder: NeonEnvBuilder):
"""
Tests that image layers can be created when the time threshold is reached on non-0 shards.
"""
tenant_conf = {
"compaction_threshold": "100",
"image_creation_threshold": "100",
"image_layer_creation_check_threshold": "1",
# disable distance based image layer creation check
"checkpoint_distance": 10 * 1024 * 1024 * 1024,
"checkpoint_timeout": "100ms",
"image_layer_force_creation_period": "1s",
"pitr_interval": "10s",
"gc_period": "1s",
"compaction_period": "1s",
"lsn_lease_length": "1s",
}
# consider every tenant large to run the image layer generation check more eagerly
neon_env_builder.pageserver_config_override = (
"image_layer_generation_large_timeline_threshold=0"
)
neon_env_builder.num_pageservers = 1
neon_env_builder.num_safekeepers = 1
env = neon_env_builder.init_start(
initial_tenant_conf=tenant_conf,
initial_tenant_shard_count=2,
initial_tenant_shard_stripe_size=1,
)
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
endpoint = env.endpoints.create_start("main")
endpoint.safe_psql("CREATE TABLE foo (id INTEGER, val text)")
for v in range(10):
endpoint.safe_psql(f"INSERT INTO foo (id, val) VALUES ({v}, repeat('abcde{v:0>3}', 500))")
tenant_shard_id = TenantShardId(tenant_id, 1, 2)
# Generate some rows.
for v in range(20):
endpoint.safe_psql(f"INSERT INTO foo (id, val) VALUES ({v}, repeat('abcde{v:0>3}', 500))")
# restart page server so that logical size on non-0 shards is missing
env.pageserver.restart()
(old_images, old_deltas) = get_layer_map(env, tenant_shard_id, timeline_id, 0)
log.info(f"old images: {old_images}, old deltas: {old_deltas}")
def check_image_creation():
(new_images, new_deltas) = get_layer_map(env, tenant_shard_id, timeline_id, 0)
log.info(f"images: {new_images}, deltas: {new_deltas}")
assert new_images > old_images
wait_until(check_image_creation)
endpoint.stop_and_destroy()
def test_image_layer_force_creation_period(neon_env_builder: NeonEnvBuilder): def test_image_layer_force_creation_period(neon_env_builder: NeonEnvBuilder):
""" """
Tests that page server can force creating new images if image_layer_force_creation_period is enabled Tests that page server can force creating new images if image_layer_force_creation_period is enabled

View File

@@ -2,6 +2,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import pytest
from fixtures.neon_fixtures import StorageControllerApiException
if TYPE_CHECKING: if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnvBuilder from fixtures.neon_fixtures import NeonEnvBuilder
@@ -75,3 +78,38 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
ep.start(safekeeper_generation=1, safekeepers=[3]) ep.start(safekeeper_generation=1, safekeepers=[3])
assert ep.safe_psql("SELECT * FROM t") == [(i,) for i in range(1, 4)] assert ep.safe_psql("SELECT * FROM t") == [(i,) for i in range(1, 4)]
def test_new_sk_set_validation(neon_env_builder: NeonEnvBuilder):
"""
Test that safekeeper_migrate validates the new_sk_set before starting the migration.
"""
neon_env_builder.num_safekeepers = 3
neon_env_builder.storage_controller_config = {
"timelines_onto_safekeepers": True,
"timeline_safekeeper_count": 2,
}
env = neon_env_builder.init_start()
def expect_fail(sk_set: list[int], match: str):
with pytest.raises(StorageControllerApiException, match=match):
env.storage_controller.migrate_safekeepers(
env.initial_tenant, env.initial_timeline, sk_set
)
# Check that we failed before committing to the database.
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
assert mconf["generation"] == 1
expect_fail([], "safekeeper set is empty")
expect_fail([1], "must have at least 2 safekeepers")
expect_fail([1, 1], "duplicate safekeeper")
expect_fail([1, 100500], "does not exist")
mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
sk_set = mconf["sk_set"]
assert len(sk_set) == 2
decom_sk = [sk.id for sk in env.safekeepers if sk.id not in sk_set][0]
env.storage_controller.safekeeper_scheduling_policy(decom_sk, "Decomissioned")
expect_fail([sk_set[0], decom_sk], "decomissioned")

View File

@@ -1673,6 +1673,91 @@ def test_shard_resolve_during_split_abort(neon_env_builder: NeonEnvBuilder):
# END_HADRON # END_HADRON
# HADRON
@pytest.mark.skip(reason="The backpressure change has not been merged yet.")
def test_back_pressure_per_shard(neon_env_builder: NeonEnvBuilder):
"""
Tests that backpressure knobs are enforced on a per-shard basis instead of at the tenant level.
"""
init_shard_count = 4
neon_env_builder.num_pageservers = init_shard_count
stripe_size = 1
env = neon_env_builder.init_start(
initial_tenant_shard_count=init_shard_count,
initial_tenant_shard_stripe_size=stripe_size,
initial_tenant_conf={
# disable auto-flush of shards and set max_replication_flush_lag to 15MB.
# The backpressure parameters must be enforced at the shard level to avoid stalling PG.
"checkpoint_distance": 1 * 1024 * 1024 * 1024,
"checkpoint_timeout": "1h",
},
)
endpoint = env.endpoints.create(
"main",
config_lines=[
"max_replication_write_lag = 0",
"max_replication_apply_lag = 0",
"max_replication_flush_lag = 15MB",
"neon.max_cluster_size = 10GB",
],
)
endpoint.respec(skip_pg_catalog_updates=False) # Needed for databricks_system to get created.
endpoint.start()
# generate 20MB of data
endpoint.safe_psql(
"CREATE TABLE usertable AS SELECT s AS KEY, repeat('a', 1000) as VALUE from generate_series(1, 20000) s;"
)
res = endpoint.safe_psql(
"SELECT neon.backpressure_throttling_time() as throttling_time", dbname="databricks_system"
)[0]
assert res[0] == 0, f"throttling_time should be 0, but got {res[0]}"
endpoint.stop()
# HADRON
def test_shard_split_page_server_timeout(neon_env_builder: NeonEnvBuilder):
"""
Tests that shard split can correctly handle pageserver timeouts and abort the split.
"""
init_shard_count = 2
neon_env_builder.num_pageservers = 1
stripe_size = 1
if neon_env_builder.storage_controller_config is None:
neon_env_builder.storage_controller_config = {"shard_split_request_timeout": "5s"}
else:
neon_env_builder.storage_controller_config["shard_split_request_timeout"] = "5s"
env = neon_env_builder.init_start(
initial_tenant_shard_count=init_shard_count,
initial_tenant_shard_stripe_size=stripe_size,
)
env.storage_controller.allowed_errors.extend(
[
".*Enqueuing background abort.*",
".*failpoint.*",
".*Failed to abort.*",
".*Exclusive lock by ShardSplit was held.*",
]
)
env.pageserver.allowed_errors.extend([".*request was dropped before completing.*"])
endpoint1 = env.endpoints.create_start(branch_name="main")
env.pageserver.http_client().configure_failpoints(("shard-split-post-finish-pause", "pause"))
with pytest.raises(StorageControllerApiException):
env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=4)
env.pageserver.http_client().configure_failpoints(("shard-split-post-finish-pause", "off"))
endpoint1.stop_and_destroy()
def test_sharding_backpressure(neon_env_builder: NeonEnvBuilder): def test_sharding_backpressure(neon_env_builder: NeonEnvBuilder):
""" """
Check a scenario when one of the shards is much slower than others. Check a scenario when one of the shards is much slower than others.

View File

@@ -209,9 +209,9 @@ def test_ancestor_detach_branched_from(
client.timeline_delete(env.initial_tenant, env.initial_timeline) client.timeline_delete(env.initial_tenant, env.initial_timeline)
wait_timeline_detail_404(client, env.initial_tenant, env.initial_timeline) wait_timeline_detail_404(client, env.initial_tenant, env.initial_timeline)
# because we do the fullbackup from ancestor at the branch_lsn, the zenith.signal is always different # because we do the fullbackup from ancestor at the branch_lsn, the neon.signal and/or zenith.signal is always
# as there is always "PREV_LSN: invalid" for "before" # different as there is always "PREV_LSN: invalid" for "before"
skip_files = {"zenith.signal"} skip_files = {"zenith.signal", "neon.signal"}
assert_pageserver_backups_equal(fullbackup_before, fullbackup_after, skip_files) assert_pageserver_backups_equal(fullbackup_before, fullbackup_after, skip_files)
@@ -767,7 +767,7 @@ def test_compaction_induced_by_detaches_in_history(
env.pageserver, env.initial_tenant, branch_timeline_id, branch_lsn, fullbackup_after env.pageserver, env.initial_tenant, branch_timeline_id, branch_lsn, fullbackup_after
) )
# we don't need to skip any files, because zenith.signal will be identical # we don't need to skip any files, because neon.signal will be identical
assert_pageserver_backups_equal(fullbackup_before, fullbackup_after, set()) assert_pageserver_backups_equal(fullbackup_before, fullbackup_after, set())

View File

@@ -1,18 +1,18 @@
{ {
"v17": [ "v17": [
"17.5", "17.5",
"db424d42d748f8ad91ac00e28db2c7f2efa42f7f" "353c725b0c76cc82b15af21d8360d03391dc6814"
], ],
"v16": [ "v16": [
"16.9", "16.9",
"7a4c0eacaeb9b97416542fa19103061c166460b1" "e08c8d5f1576ca0487d14d154510499c5f12adfb"
], ],
"v15": [ "v15": [
"15.13", "15.13",
"8c3249f36c7df6ac0efb8ee9f1baf4aa1b83e5c9" "afd46987f3da50c9146a8aa59380052df0862c06"
], ],
"v14": [ "v14": [
"14.18", "14.18",
"9085654ee8022d5cc4ca719380a1dc53e5e3246f" "8ce1f52303aec29e098309347b57c01a1962e221"
] ]
} }